argmin/solver/quasinewton/
sr1_trustregion.rs

// Copyright 2019-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18/// SR1 Trust Region method
19///
20/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/sr1_trustregion.rs)
21///
22/// # References:
23///
24/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
25/// Springer. ISBN 0-387-30303-0.
26#[derive(Clone, Serialize, Deserialize)]
27pub struct SR1TrustRegion<B, R, F> {
28    /// parameter for skipping rule
29    r: F,
30    /// Inverse Hessian
31    init_hessian: Option<B>,
32    /// subproblem
33    subproblem: R,
34    /// Radius
35    radius: F,
36    /// eta \in [0, 1/4)
37    eta: F,
38    /// Tolerance for the stopping criterion based on the change of the norm on the gradient
39    tol_grad: F,
40}
41
42impl<B, R, F: ArgminFloat> SR1TrustRegion<B, R, F> {
43    /// Constructor
44    pub fn new(subproblem: R) -> Self {
45        SR1TrustRegion {
46            r: F::from_f64(1e-8).unwrap(),
47            init_hessian: None,
48            subproblem,
49            radius: F::from_f64(1.0).unwrap(),
50            eta: F::from_f64(0.5 * 1e-3).unwrap(),
51            tol_grad: F::from_f64(1e-3).unwrap(),
52        }
53    }
54
55    /// provide initial Hessian (if not provided, the algorithm will try to compute it using the
56    /// `hessian` method of `ArgminOp`.
57    pub fn hessian(mut self, init_hessian: B) -> Self {
58        self.init_hessian = Some(init_hessian);
59        self
60    }
61
62    /// Set r
63    pub fn r(mut self, r: F) -> Result<Self, Error> {
64        if r <= F::from_f64(0.0).unwrap() || r >= F::from_f64(1.0).unwrap() {
65            Err(ArgminError::InvalidParameter {
66                text: "SR1: r must be in (0, 1).".to_string(),
67            }
68            .into())
69        } else {
70            self.r = r;
71            Ok(self)
72        }
73    }
74
75    /// set radius
76    pub fn radius(mut self, radius: F) -> Self {
77        self.radius = radius.abs();
78        self
79    }
80
81    /// Set eta
82    pub fn eta(mut self, eta: F) -> Result<Self, Error> {
83        if eta >= F::from_f64(10e-3).unwrap() || eta <= F::from_f64(0.0).unwrap() {
84            return Err(ArgminError::InvalidParameter {
85                text: "SR1TrustRegion: eta must be in (0, 10^-3).".to_string(),
86            }
87            .into());
88        }
89        self.eta = eta;
90        Ok(self)
91    }
92
93    /// Sets tolerance for the stopping criterion based on the change of the norm on the gradient
94    pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
95        self.tol_grad = tol_grad;
96        self
97    }
98}
99
100impl<O, B, R, F> Solver<O> for SR1TrustRegion<B, R, F>
101where
102    O: ArgminOp<Output = F, Hessian = B, Float = F>,
103    O::Param: Debug
104        + Clone
105        + Default
106        + Serialize
107        + ArgminSub<O::Param, O::Param>
108        + ArgminAdd<O::Param, O::Param>
109        + ArgminDot<O::Param, O::Float>
110        + ArgminDot<O::Param, O::Hessian>
111        + ArgminNorm<O::Float>
112        + ArgminZeroLike
113        + ArgminMul<F, O::Param>,
114    O::Hessian: Debug
115        + Clone
116        + Default
117        + Serialize
118        + DeserializeOwned
119        + ArgminSub<O::Hessian, O::Hessian>
120        + ArgminDot<O::Param, O::Param>
121        + ArgminDot<O::Hessian, O::Hessian>
122        + ArgminAdd<O::Hessian, O::Hessian>
123        + ArgminMul<F, O::Hessian>,
124    R: ArgminTrustRegion<F> + Solver<OpWrapper<O>>,
125    F: ArgminFloat + ArgminNorm<O::Float>,
126{
127    const NAME: &'static str = "SR1 Trust Region";
128
129    fn init(
130        &mut self,
131        op: &mut OpWrapper<O>,
132        state: &IterState<O>,
133    ) -> Result<Option<ArgminIterData<O>>, Error> {
134        let param = state.get_param();
135        let cost = op.apply(&param)?;
136        let grad = op.gradient(&param)?;
137        let hessian = state
138            .get_hessian()
139            .unwrap_or_else(|| op.hessian(&param).unwrap());
140        Ok(Some(
141            ArgminIterData::new()
142                .param(param)
143                .cost(cost)
144                .grad(grad)
145                .hessian(hessian),
146        ))
147    }
148
149    fn next_iter(
150        &mut self,
151        op: &mut OpWrapper<O>,
152        state: &IterState<O>,
153    ) -> Result<ArgminIterData<O>, Error> {
154        let xk = state.get_param();
155        let cost = state.get_cost();
156        let prev_grad = state
157            .get_grad()
158            .unwrap_or_else(|| op.gradient(&xk).unwrap());
159        let hessian: O::Hessian = state.get_hessian().unwrap();
160
161        self.subproblem.set_radius(self.radius);
162
163        let ArgminResult {
164            operator: sub_op,
165            state: IterState { param: sk, .. },
166        } = Executor::new(
167            OpWrapper::new_from_wrapper(op),
168            self.subproblem.clone(),
169            // xk.clone(),
170            xk.zero_like(),
171        )
172        .cost(cost)
173        .grad(prev_grad.clone())
174        .hessian(hessian.clone())
175        .ctrlc(false)
176        .run()?;
177
178        op.consume_op(sub_op);
179
180        let xksk = xk.add(&sk);
181        let dfk1 = op.gradient(&xksk)?;
182        let yk = dfk1.sub(&prev_grad);
183        let fk1 = op.apply(&xksk)?;
184
185        let ared = cost - fk1;
186        let tmp1: F = prev_grad.dot(&sk);
187        let tmp2: F = sk.weighted_dot(&hessian, &sk);
188        let tmp2: F = tmp2.mul(F::from_f64(0.5).unwrap());
189        let pred = -tmp1 - tmp2;
190        let ap = ared / pred;
191
192        let (xk1, fk1, dfk1) = if ap > self.eta {
193            (xksk, fk1, dfk1)
194        } else {
195            (xk, cost, prev_grad)
196        };
197
198        self.radius = if ap > F::from_f64(0.75).unwrap() {
199            if sk.norm() <= F::from_f64(0.8).unwrap() * self.radius {
200                self.radius
201            } else {
202                F::from_f64(2.0).unwrap() * self.radius
203            }
204        } else if ap <= F::from_f64(0.75).unwrap() && ap >= F::from_f64(0.1).unwrap() {
205            self.radius
206        } else {
207            F::from_f64(0.5).unwrap() * self.radius
208        };
209
210        let bksk = hessian.dot(&sk);
211        let ykbksk = yk.sub(&bksk);
212        let skykbksk: F = sk.dot(&ykbksk);
213
214        let hessian_update = skykbksk.abs() >= self.r * sk.norm() * skykbksk.norm();
215        let hessian = if hessian_update {
216            let a: O::Hessian = ykbksk.dot(&ykbksk);
217            let b: F = sk.dot(&ykbksk);
218            hessian.add(&a.mul(&(F::from_f64(1.0).unwrap() / b)))
219        } else {
220            hessian
221        };
222
223        Ok(ArgminIterData::new()
224            .param(xk1)
225            .cost(fk1)
226            .grad(dfk1)
227            .hessian(hessian)
228            .kv(make_kv!["ared" => ared;
229                         "pred" => pred;
230                         "ap" => ap;
231                         "radius" => self.radius;
232                         "hessian_update" => hessian_update;]))
233    }
234
235    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
236        if state.get_grad().unwrap().norm() < self.tol_grad {
237            return TerminationReason::TargetPrecisionReached;
238        }
239        // if (state.get_prev_cost() - state.get_cost()).abs() < std::f64::EPSILON {
240        //     return TerminationReason::NoChangeInCost;
241        // }
242        TerminationReason::NotTerminated
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::trustregion::CauchyPoint;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(sr1, SR1TrustRegion<Operator, CauchyPoint<f64>, f64>);

    /// `with_tol_grad` must store the supplied gradient tolerance.
    #[test]
    fn test_tolerances() {
        let expected: f64 = 1e-4;

        let solver: SR1TrustRegion<MinimalNoOperator, CauchyPoint<f64>, f64> =
            SR1TrustRegion::new(CauchyPoint::new()).with_tol_grad(expected);

        let deviation = (solver.tol_grad - expected).abs();
        assert!(deviation < std::f64::EPSILON);
    }
}