argmin/solver/quasinewton/sr1.rs

// Copyright 2019-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// SR1 method (broken!)
///
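/// The symmetric rank-one (SR1) method is a quasi-Newton method which maintains an
/// approximation of the inverse Hessian and applies a rank-one correction to it in
/// every iteration. The update is skipped whenever its denominator gets too close to
/// zero; this is controlled by the parameter `r` (see the skipping rule in the
/// reference below).
///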
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/sr1.rs)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct SR1<L, H, F> {
    /// parameter for skipping rule
    r: F,
    /// Inverse Hessian
    inv_hessian: H,
    /// line search
    linesearch: L,
    /// Tolerance for the stopping criterion based on the norm of the gradient
    tol_grad: F,
    /// Tolerance for the stopping criterion based on the change of the cost function value
    tol_cost: F,
}

impl<L, H, F: ArgminFloat> SR1<L, H, F> {
    /// Constructor
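    ///
    /// A minimal construction sketch; the `MoreThuenteLineSearch` line search and the
    /// plain `Vec<Vec<f64>>` initial inverse Hessian follow the types used in the
    /// tests below and are only one possible choice:
    ///
    /// ```ignore
    /// let linesearch: MoreThuenteLineSearch<f64, f64> = MoreThuenteLineSearch::new();
    /// let init_inverse_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    /// let solver: SR1<_, _, f64> = SR1::new(init_inverse_hessian, linesearch);
    /// ```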
    pub fn new(init_inverse_hessian: H, linesearch: L) -> Self {
        SR1 {
            r: F::from_f64(1e-8).unwrap(),
            inv_hessian: init_inverse_hessian,
            linesearch,
            tol_grad: F::epsilon().sqrt(),
            tol_cost: F::epsilon(),
        }
    }

    /// Set the parameter `r` of the skipping rule (must be in [0, 1])
    pub fn r(mut self, r: F) -> Result<Self, Error> {
        if r < F::from_f64(0.0).unwrap() || r > F::from_f64(1.0).unwrap() {
            Err(ArgminError::InvalidParameter {
                text: "SR1: r must be between 0 and 1.".to_string(),
            }
            .into())
        } else {
            self.r = r;
            Ok(self)
        }
    }

    /// Sets tolerance for the stopping criterion based on the norm of the gradient
    pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
        self.tol_grad = tol_grad;
        self
    }

    /// Sets tolerance for the stopping criterion based on the change of the cost function value
    pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
        self.tol_cost = tol_cost;
        self
    }
}

impl<O, L, H, F> Solver<O> for SR1<L, H, F>
where
    O: ArgminOp<Output = F, Hessian = H, Float = F>,
    O::Param: Debug
        + Clone
        + Default
        + Serialize
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminDot<O::Param, O::Hessian>
        + ArgminNorm<O::Float>
        + ArgminMul<O::Float, O::Param>,
    O::Hessian: Debug
        + Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Hessian, O::Hessian>
        + ArgminDot<O::Param, O::Param>
        + ArgminDot<O::Hessian, O::Hessian>
        + ArgminAdd<O::Hessian, O::Hessian>
        + ArgminMul<F, O::Hessian>,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "SR1";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
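        // Evaluate cost and gradient at the initial parameter vector so that
        // `next_iter` has them available from the first iteration on.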
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let cost = state.get_cost();
        let prev_grad = if let Some(grad) = state.get_grad() {
            grad
        } else {
            op.gradient(&param)?
        };

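        // Search direction: p = -H_k * g_k, where g_k is the current gradient
        // and H_k the current approximation of the inverse Hessian.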
        let p = self
            .inv_hessian
            .dot(&prev_grad)
            .mul(&F::from_f64(-1.0).unwrap());

        self.linesearch.set_search_direction(p);

        // Run the line search along the search direction
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: xk1,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param.clone(),
        )
        .grad(prev_grad.clone())
        .cost(cost)
        .ctrlc(false)
        .run()?;

        // take care of function eval counts
        op.consume_op(line_op);

        let grad = op.gradient(&xk1)?;
        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

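        // Quantities of the rank-one correction: (s_k - H_k * y_k), its outer
        // product `a`, and the denominator `b = (s_k - H_k * y_k)^T * y_k`.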
        let skmhkyk: O::Param = sk.sub(&self.inv_hessian.dot(&yk));
        let a: O::Hessian = skmhkyk.dot(&skmhkyk);
        let b: F = skmhkyk.dot(&yk);

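        // Skipping rule (Nocedal & Wright): only apply the update if
        // |(s_k - H_k * y_k)^T * y_k| >= r * ||y_k|| * ||s_k - H_k * y_k||,
        // which guards against a near-zero denominator.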
        let hessian_update = b.abs() >= self.r * yk.norm() * skmhkyk.norm();

        // An attempt to see whether a skipping rule based on B_k makes any difference (it seems not):
        // let bk = self.inv_hessian.inv()?;
        // let ykmbksk = yk.sub(&bk.dot(&sk));
        // let tmp: f64 = sk.dot(&ykmbksk);
        // let sksk: f64 = sk.dot(&sk);
        // let blah: f64 = ykmbksk.dot(&ykmbksk);
        // let hessian_update = tmp.abs() >= self.r * sksk.sqrt() * blah.sqrt();

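        // SR1 update of the inverse Hessian:
        // H_{k+1} = H_k + (s_k - H_k * y_k) * (s_k - H_k * y_k)^T / ((s_k - H_k * y_k)^T * y_k)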
        if hessian_update {
            self.inv_hessian = self
                .inv_hessian
                .add(&a.mul(&(F::from_f64(1.0).unwrap() / b)));
        }

        Ok(ArgminIterData::new()
            .param(xk1)
            .cost(next_cost)
            .grad(grad)
            .kv(make_kv!["denom" => b; "hessian_update" => hessian_update;]))
    }

    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
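        // Two independent stopping criteria: a sufficiently small gradient norm
        // and a (nearly) unchanged cost between consecutive iterations.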
        if state.get_grad().unwrap().norm() < self.tol_grad {
            return TerminationReason::TargetPrecisionReached;
        }
        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
            return TerminationReason::NoChangeInCost;
        }
        TerminationReason::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(sr1, SR1<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);

    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let tol1: f64 = 1e-4;
        let tol2: f64 = 1e-2;

        let SR1 {
            tol_grad: t1,
            tol_cost: t2,
            ..
        } = SR1::new(init_hessian, linesearch)
            .with_tol_grad(tol1)
            .with_tol_cost(tol2);

        assert!((t1 - tol1).abs() < std::f64::EPSILON);
        assert!((t2 - tol2).abs() < std::f64::EPSILON);
    }
}
233}