argmin/solver/trustregion/
steihaug.rs1use crate::prelude::*;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16
17#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
25pub struct Steihaug<P, F> {
26 radius: F,
28 epsilon: F,
30 p: P,
32 r: P,
34 rtr: F,
36 r_0_norm: F,
38 d: P,
40 max_iters: u64,
42}
43
44impl<P, F> Steihaug<P, F>
45where
46 P: Default + Clone + ArgminMul<F, P> + ArgminDot<P, F> + ArgminAdd<P, P>,
47 F: ArgminFloat,
48{
49 pub fn new() -> Self {
51 Steihaug {
52 radius: F::nan(),
53 epsilon: F::from_f64(10e-10).unwrap(),
54 p: P::default(),
55 r: P::default(),
56 rtr: F::nan(),
57 r_0_norm: F::nan(),
58 d: P::default(),
59 max_iters: std::u64::MAX,
60 }
61 }
62
63 pub fn epsilon(mut self, epsilon: F) -> Result<Self, Error> {
65 if epsilon <= F::from_f64(0.0).unwrap() {
66 return Err(ArgminError::InvalidParameter {
67 text: "Steihaug: epsilon must be > 0.0.".to_string(),
68 }
69 .into());
70 }
71 self.epsilon = epsilon;
72 Ok(self)
73 }
74
75 pub fn max_iters(mut self, iters: u64) -> Self {
77 self.max_iters = iters;
78 self
79 }
80
81 fn eval_m<H>(&self, p: &P, g: &P, h: &H) -> F
83 where
84 P: ArgminWeightedDot<P, F, H>,
85 {
86 g.dot(&p) + F::from_f64(0.5).unwrap() * p.weighted_dot(&h, &p)
88 }
89
90 #[allow(clippy::many_single_char_names)]
92 fn tau<G, H>(&self, filter_func: G, eval: bool, g: &P, h: &H) -> F
93 where
94 G: Fn(F) -> bool,
95 H: ArgminDot<P, P>,
96 {
97 let a = self.p.dot(&self.p);
98 let b = self.d.dot(&self.d);
99 let c = self.p.dot(&self.d);
100 let delta = self.radius.powi(2);
101 let t1 = (-a * b + b * delta + c.powi(2)).sqrt();
102 let tau1 = -(t1 + c) / b;
103 let tau2 = (t1 - c) / b;
104 let mut t = vec![tau1, tau2];
105 if tau1.is_nan() || tau2.is_nan() || tau1.is_infinite() || tau2.is_infinite() {
107 let tau3 = (delta - a) / (F::from_f64(2.0).unwrap() * c);
108 t.push(tau3);
109 }
110 let v = if eval {
111 let mut v = t
114 .iter()
115 .cloned()
116 .enumerate()
117 .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
118 .map(|(i, tau)| {
119 let p = self.p.add(&self.d.mul(&tau));
120 (i, self.eval_m(&p, g, h))
121 })
122 .filter(|(_, m)| !m.is_nan() || !m.is_infinite())
123 .collect::<Vec<(usize, F)>>();
124 v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
125 v
126 } else {
127 let mut v = t
128 .iter()
129 .cloned()
130 .enumerate()
131 .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
132 .collect::<Vec<(usize, F)>>();
133 v.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
134 v
135 };
136
137 t[v[0].0]
138 }
139}
140
141impl<P, O, F> Solver<O> for Steihaug<P, F>
142where
143 O: ArgminOp<Param = P, Output = F, Float = F>,
144 P: Clone
145 + Serialize
146 + DeserializeOwned
147 + Default
148 + ArgminMul<F, P>
149 + ArgminWeightedDot<P, F, O::Hessian>
150 + ArgminNorm<F>
151 + ArgminDot<P, F>
152 + ArgminAdd<P, P>
153 + ArgminSub<P, P>
154 + ArgminZeroLike
155 + ArgminMul<F, P>,
156 O::Hessian: ArgminDot<P, P>,
157 F: ArgminFloat,
158{
159 const NAME: &'static str = "Steihaug";
160
161 fn init(
162 &mut self,
163 _op: &mut OpWrapper<O>,
164 state: &IterState<O>,
165 ) -> Result<Option<ArgminIterData<O>>, Error> {
166 self.r = state.get_grad().unwrap();
167
168 self.r_0_norm = self.r.norm();
169 self.rtr = self.r.dot(&self.r);
170 self.d = self.r.mul(&F::from_f64(-1.0).unwrap());
171 self.p = self.r.zero_like();
172
173 Ok(if self.r_0_norm < self.epsilon {
174 Some(
175 ArgminIterData::new()
176 .param(self.p.clone())
177 .termination_reason(TerminationReason::TargetPrecisionReached),
178 )
179 } else {
180 None
181 })
182 }
183
184 fn next_iter(
185 &mut self,
186 _op: &mut OpWrapper<O>,
187 state: &IterState<O>,
188 ) -> Result<ArgminIterData<O>, Error> {
189 let grad = state.get_grad().unwrap();
190 let h = state.get_hessian().unwrap();
191 let dhd = self.d.weighted_dot(&h, &self.d);
192
193 if dhd <= F::from_f64(0.0).unwrap() {
195 let tau = self.tau(|_| true, true, &grad, &h);
196 return Ok(ArgminIterData::new()
197 .param(self.p.add(&self.d.mul(&tau)))
198 .termination_reason(TerminationReason::TargetPrecisionReached));
199 }
200
201 let alpha = self.rtr / dhd;
202 let p_n = self.p.add(&self.d.mul(&alpha));
203
204 if p_n.norm() >= self.radius {
206 let tau = self.tau(|x| x >= F::from_f64(0.0).unwrap(), false, &grad, &h);
207 return Ok(ArgminIterData::new()
208 .param(self.p.add(&self.d.mul(&tau)))
209 .termination_reason(TerminationReason::TargetPrecisionReached));
210 }
211
212 let r_n = self.r.add(&h.dot(&self.d).mul(&alpha));
213
214 if r_n.norm() < self.epsilon * self.r_0_norm {
215 return Ok(ArgminIterData::new()
216 .param(p_n)
217 .termination_reason(TerminationReason::TargetPrecisionReached));
218 }
219
220 let rjtrj = r_n.dot(&r_n);
221 let beta = rjtrj / self.rtr;
222 self.d = r_n.mul(&F::from_f64(-1.0).unwrap()).add(&self.d.mul(&beta));
223 self.r = r_n;
224 self.p = p_n;
225 self.rtr = rjtrj;
226
227 Ok(ArgminIterData::new()
228 .param(self.p.clone())
229 .cost(self.rtr)
230 .grad(grad)
231 .hessian(h))
232 }
233
234 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
235 if state.get_iter() >= self.max_iters {
236 TerminationReason::MaxItersReached
237 } else {
238 TerminationReason::NotTerminated
239 }
240 }
241}
242
243impl<P: Clone + Serialize, F: ArgminFloat> ArgminTrustRegion<F> for Steihaug<P, F> {
244 fn set_radius(&mut self, radius: F) {
245 self.radius = radius;
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // Generates a test named `steihaug` for the given instantiation.
    // NOTE(review): presumably checks common trait implementations — confirm against
    // the definition of `test_trait_impl!`.
    test_trait_impl!(steihaug, Steihaug<MinimalNoOperator, f64>);
}