argmin/solver/quasinewton/bfgs.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// BFGS method
///
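/// BFGS is a quasi-Newton method which iteratively refines an approximation `H_k` of the
/// *inverse* Hessian using only gradient evaluations. Each iteration performs a line search
/// along the direction `p_k = -H_k * g_k` to obtain the next parameter vector `x_{k+1}`, and
/// then updates the inverse Hessian approximation with the rank-two formula
///
/// `H_{k+1} = (I - rho_k * s_k * y_k^T) * H_k * (I - rho_k * y_k * s_k^T) + rho_k * s_k * s_k^T`
///
/// where `s_k = x_{k+1} - x_k`, `y_k = g_{k+1} - g_k` and `rho_k = 1 / (y_k^T * s_k)` (see [0]).
///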
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/bfgs.rs)
///
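/// A minimal construction sketch (the `Vec`-based inverse Hessian and the
/// `MoreThuenteLineSearch` used here are just one possible choice of types; the linked
/// example shows a full optimization run):
///
/// ```
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
/// use argmin::solver::quasinewton::BFGS;
///
/// // Line search satisfying the strong Wolfe conditions.
/// let linesearch: MoreThuenteLineSearch<Vec<f64>, f64> =
///     MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
/// // Identity matrix as initial guess for the inverse Hessian.
/// let init_inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
/// let tol_grad: f64 = 1e-6;
/// let tol_cost: f64 = 1e-12;
/// let solver = BFGS::new(init_inv_hessian, linesearch)
///     .with_tol_grad(tol_grad)
///     .with_tol_cost(tol_cost);
/// ```
///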
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct BFGS<L, H, F> {
    /// Inverse Hessian approximation
    inv_hessian: H,
    /// Line search
    linesearch: L,
    /// Tolerance for the stopping criterion based on the norm of the gradient
    tol_grad: F,
    /// Tolerance for the stopping criterion based on the change of the cost function value
    tol_cost: F,
}

impl<L, H, F: ArgminFloat> BFGS<L, H, F> {
    /// Constructor
    pub fn new(init_inverse_hessian: H, linesearch: L) -> Self {
        BFGS {
            inv_hessian: init_inverse_hessian,
            linesearch,
            tol_grad: F::epsilon().sqrt(),
            tol_cost: F::epsilon(),
        }
    }

    /// Sets the tolerance for the stopping criterion based on the norm of the gradient
    pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
        self.tol_grad = tol_grad;
        self
    }

    /// Sets the tolerance for the stopping criterion based on the change of the cost function value
    pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
        self.tol_cost = tol_cost;
        self
    }
}

impl<O, L, H, F> Solver<O> for BFGS<L, H, F>
where
    O: ArgminOp<Output = F, Hessian = H, Float = F>,
    O::Param: Debug
        + Default
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminDot<O::Param, O::Hessian>
        + ArgminScaledAdd<O::Param, O::Float, O::Param>
        + ArgminNorm<O::Float>
        + ArgminMul<O::Float, O::Param>,
    O::Hessian: Clone
        + Default
        + Debug
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Hessian, O::Hessian>
        + ArgminDot<O::Param, O::Param>
        + ArgminDot<O::Hessian, O::Hessian>
        + ArgminAdd<O::Hessian, O::Hessian>
        + ArgminMul<O::Float, O::Hessian>
        + ArgminTranspose
        + ArgminEye,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "BFGS";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let cur_cost = state.get_cost();
        let prev_grad = state.get_grad().unwrap();

        // Search direction p_k = -H_k * g_k
        let p = self
            .inv_hessian
            .dot(&prev_grad)
            .mul(&F::from_f64(-1.0).unwrap());

        self.linesearch.set_search_direction(p);

        // Run line search along p_k to obtain the next parameter vector x_{k+1}
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: xk1,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param.clone(),
        )
        .grad(prev_grad.clone())
        .cost(cur_cost)
        .ctrlc(false)
        .run()?;

        // Take care of function evaluation counts
        op.consume_op(line_op);

        let grad = op.gradient(&xk1)?;

        // Gradient difference y_k and step s_k
        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

        // rho_k = 1 / (y_k^T s_k)
        let yksk: F = yk.dot(&sk);
        let rhok = F::from_f64(1.0).unwrap() / yksk;

        let e = self.inv_hessian.eye_like();
        let mat1: O::Hessian = sk.dot(&yk);
        let mat1 = mat1.mul(&rhok);

        let mat2 = mat1.clone().t();

        let tmp1 = e.sub(&mat1);
        let tmp2 = e.sub(&mat2);

        let sksk: O::Hessian = sk.dot(&sk);
        let sksk = sksk.mul(&rhok);

        // Note: [0] suggests scaling the initial inverse Hessian estimate by
        // (y_k^T s_k) / (y_k^T y_k) before the first update; this is not done here.

        // BFGS update of the inverse Hessian approximation (see [0]):
        // H_{k+1} = (I - rho_k s_k y_k^T) H_k (I - rho_k y_k s_k^T) + rho_k s_k s_k^T
        self.inv_hessian = tmp1.dot(&self.inv_hessian.dot(&tmp2)).add(&sksk);

        Ok(ArgminIterData::new().param(xk1).cost(next_cost).grad(grad))
    }

    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        if state.get_grad().unwrap().norm() < self.tol_grad {
            return TerminationReason::TargetPrecisionReached;
        }
        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
            return TerminationReason::NoChangeInCost;
        }
        TerminationReason::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(bfgs, BFGS<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);

    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let tol1: f64 = 1e-4;
        let tol2: f64 = 1e-2;

        let BFGS {
            tol_grad: t1,
            tol_cost: t2,
            ..
        } = BFGS::new(init_hessian, linesearch)
            .with_tol_grad(tol1)
            .with_tol_cost(tol2);

        assert!((t1 - tol1).abs() < std::f64::EPSILON);
        assert!((t2 - tol2).abs() < std::f64::EPSILON);
    }
}