argmin/solver/trustregion/
steihaug.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16
17/// The Steihaug method is a conjugate gradients based approach for finding an approximate solution
18/// to the second order approximation of the cost function within the trust region.
19///
20/// # References:
21///
22/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
23/// Springer. ISBN 0-387-30303-0.
24#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
25pub struct Steihaug<P, F> {
26    /// Radius
27    radius: F,
28    /// epsilon
29    epsilon: F,
30    /// p
31    p: P,
32    /// residual
33    r: P,
34    /// r^Tr
35    rtr: F,
36    /// initial residual
37    r_0_norm: F,
38    /// direction
39    d: P,
40    /// max iters
41    max_iters: u64,
42}
43
44impl<P, F> Steihaug<P, F>
45where
46    P: Default + Clone + ArgminMul<F, P> + ArgminDot<P, F> + ArgminAdd<P, P>,
47    F: ArgminFloat,
48{
49    /// Constructor
50    pub fn new() -> Self {
51        Steihaug {
52            radius: F::nan(),
53            epsilon: F::from_f64(10e-10).unwrap(),
54            p: P::default(),
55            r: P::default(),
56            rtr: F::nan(),
57            r_0_norm: F::nan(),
58            d: P::default(),
59            max_iters: std::u64::MAX,
60        }
61    }
62
63    /// Set epsilon
64    pub fn epsilon(mut self, epsilon: F) -> Result<Self, Error> {
65        if epsilon <= F::from_f64(0.0).unwrap() {
66            return Err(ArgminError::InvalidParameter {
67                text: "Steihaug: epsilon must be > 0.0.".to_string(),
68            }
69            .into());
70        }
71        self.epsilon = epsilon;
72        Ok(self)
73    }
74
75    /// set maximum number of iterations
76    pub fn max_iters(mut self, iters: u64) -> Self {
77        self.max_iters = iters;
78        self
79    }
80
81    /// evaluate m(p) (without considering f_init because it is not available)
82    fn eval_m<H>(&self, p: &P, g: &P, h: &H) -> F
83    where
84        P: ArgminWeightedDot<P, F, H>,
85    {
86        // self.cur_grad().dot(&p) + 0.5 * p.weighted_dot(&self.cur_hessian(), &p)
87        g.dot(&p) + F::from_f64(0.5).unwrap() * p.weighted_dot(&h, &p)
88    }
89
90    /// calculate all possible step lengths
91    #[allow(clippy::many_single_char_names)]
92    fn tau<G, H>(&self, filter_func: G, eval: bool, g: &P, h: &H) -> F
93    where
94        G: Fn(F) -> bool,
95        H: ArgminDot<P, P>,
96    {
97        let a = self.p.dot(&self.p);
98        let b = self.d.dot(&self.d);
99        let c = self.p.dot(&self.d);
100        let delta = self.radius.powi(2);
101        let t1 = (-a * b + b * delta + c.powi(2)).sqrt();
102        let tau1 = -(t1 + c) / b;
103        let tau2 = (t1 - c) / b;
104        let mut t = vec![tau1, tau2];
105        // Maybe calculating tau3 should only be done if b is close to zero?
106        if tau1.is_nan() || tau2.is_nan() || tau1.is_infinite() || tau2.is_infinite() {
107            let tau3 = (delta - a) / (F::from_f64(2.0).unwrap() * c);
108            t.push(tau3);
109        }
110        let v = if eval {
111            // remove NAN taus and calculate m (without f_init) for all taus, then sort them based
112            // on their result and return the tau which corresponds to the lowest m
113            let mut v = t
114                .iter()
115                .cloned()
116                .enumerate()
117                .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
118                .map(|(i, tau)| {
119                    let p = self.p.add(&self.d.mul(&tau));
120                    (i, self.eval_m(&p, g, h))
121                })
122                .filter(|(_, m)| !m.is_nan() || !m.is_infinite())
123                .collect::<Vec<(usize, F)>>();
124            v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
125            v
126        } else {
127            let mut v = t
128                .iter()
129                .cloned()
130                .enumerate()
131                .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
132                .collect::<Vec<(usize, F)>>();
133            v.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
134            v
135        };
136
137        t[v[0].0]
138    }
139}
140
141impl<P, O, F> Solver<O> for Steihaug<P, F>
142where
143    O: ArgminOp<Param = P, Output = F, Float = F>,
144    P: Clone
145        + Serialize
146        + DeserializeOwned
147        + Default
148        + ArgminMul<F, P>
149        + ArgminWeightedDot<P, F, O::Hessian>
150        + ArgminNorm<F>
151        + ArgminDot<P, F>
152        + ArgminAdd<P, P>
153        + ArgminSub<P, P>
154        + ArgminZeroLike
155        + ArgminMul<F, P>,
156    O::Hessian: ArgminDot<P, P>,
157    F: ArgminFloat,
158{
159    const NAME: &'static str = "Steihaug";
160
161    fn init(
162        &mut self,
163        _op: &mut OpWrapper<O>,
164        state: &IterState<O>,
165    ) -> Result<Option<ArgminIterData<O>>, Error> {
166        self.r = state.get_grad().unwrap();
167
168        self.r_0_norm = self.r.norm();
169        self.rtr = self.r.dot(&self.r);
170        self.d = self.r.mul(&F::from_f64(-1.0).unwrap());
171        self.p = self.r.zero_like();
172
173        Ok(if self.r_0_norm < self.epsilon {
174            Some(
175                ArgminIterData::new()
176                    .param(self.p.clone())
177                    .termination_reason(TerminationReason::TargetPrecisionReached),
178            )
179        } else {
180            None
181        })
182    }
183
184    fn next_iter(
185        &mut self,
186        _op: &mut OpWrapper<O>,
187        state: &IterState<O>,
188    ) -> Result<ArgminIterData<O>, Error> {
189        let grad = state.get_grad().unwrap();
190        let h = state.get_hessian().unwrap();
191        let dhd = self.d.weighted_dot(&h, &self.d);
192
193        // Current search direction d is a direction of zero curvature or negative curvature
194        if dhd <= F::from_f64(0.0).unwrap() {
195            let tau = self.tau(|_| true, true, &grad, &h);
196            return Ok(ArgminIterData::new()
197                .param(self.p.add(&self.d.mul(&tau)))
198                .termination_reason(TerminationReason::TargetPrecisionReached));
199        }
200
201        let alpha = self.rtr / dhd;
202        let p_n = self.p.add(&self.d.mul(&alpha));
203
204        // new p violates trust region bound
205        if p_n.norm() >= self.radius {
206            let tau = self.tau(|x| x >= F::from_f64(0.0).unwrap(), false, &grad, &h);
207            return Ok(ArgminIterData::new()
208                .param(self.p.add(&self.d.mul(&tau)))
209                .termination_reason(TerminationReason::TargetPrecisionReached));
210        }
211
212        let r_n = self.r.add(&h.dot(&self.d).mul(&alpha));
213
214        if r_n.norm() < self.epsilon * self.r_0_norm {
215            return Ok(ArgminIterData::new()
216                .param(p_n)
217                .termination_reason(TerminationReason::TargetPrecisionReached));
218        }
219
220        let rjtrj = r_n.dot(&r_n);
221        let beta = rjtrj / self.rtr;
222        self.d = r_n.mul(&F::from_f64(-1.0).unwrap()).add(&self.d.mul(&beta));
223        self.r = r_n;
224        self.p = p_n;
225        self.rtr = rjtrj;
226
227        Ok(ArgminIterData::new()
228            .param(self.p.clone())
229            .cost(self.rtr)
230            .grad(grad)
231            .hessian(h))
232    }
233
234    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
235        if state.get_iter() >= self.max_iters {
236            TerminationReason::MaxItersReached
237        } else {
238            TerminationReason::NotTerminated
239        }
240    }
241}
242
243impl<P: Clone + Serialize, F: ArgminFloat> ArgminTrustRegion<F> for Steihaug<P, F> {
244    fn set_radius(&mut self, radius: F) {
245        self.radius = radius;
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // Expands the crate's `test_trait_impl!` macro for `Steihaug<MinimalNoOperator, f64>`.
    // NOTE(review): presumably this generates a test named `steihaug` which checks that the
    // solver implements a set of required traits -- confirm against the macro definition.
    test_trait_impl!(steihaug, Steihaug<MinimalNoOperator, f64>);
}