argmin/solver/quasinewton/lbfgs.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fmt::Debug;

/// L-BFGS method
///
/// L-BFGS is a limited-memory quasi-Newton method: instead of storing a full approximation of
/// the inverse Hessian, it keeps only the `m` most recent correction pairs `(s, y)` and
/// reconstructs the search direction from them via a two-loop recursion.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/lbfgs.rs)
///
/// TODO: Implement compact representation of BFGS updating (Nocedal/Wright p.230)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
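///
/// # Usage (sketch)
///
/// A rough usage sketch, not taken from the linked example: the `Rosenbrock` problem below is
/// defined inline purely for illustration, and the choice of line search, memory size and
/// iteration count is arbitrary.
///
/// ```no_run
/// use argmin::prelude::*;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
/// use argmin::solver::quasinewton::LBFGS;
/// use serde::{Deserialize, Serialize};
///
/// // Hypothetical cost function: the 2D Rosenbrock function.
/// #[derive(Clone, Default, Serialize, Deserialize)]
/// struct Rosenbrock {}
///
/// impl ArgminOp for Rosenbrock {
///     type Param = Vec<f64>;
///     type Output = f64;
///     type Hessian = ();
///     type Jacobian = ();
///     type Float = f64;
///
///     fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok((1.0 - p[0]).powi(2) + 100.0 * (p[1] - p[0].powi(2)).powi(2))
///     }
///
///     fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
///         Ok(vec![
///             -2.0 * (1.0 - p[0]) - 400.0 * p[0] * (p[1] - p[0].powi(2)),
///             200.0 * (p[1] - p[0].powi(2)),
///         ])
///     }
/// }
///
/// fn main() -> Result<(), Error> {
///     // Line search and memory size (here: the last 7 correction pairs) are free choices.
///     let linesearch = MoreThuenteLineSearch::new();
///     let solver = LBFGS::new(linesearch, 7);
///
///     let res = Executor::new(Rosenbrock {}, solver, vec![-1.2, 1.0])
///         .max_iters(100)
///         .run()?;
///     println!("{}", res);
///     Ok(())
/// }
/// ```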
#[derive(Clone, Serialize, Deserialize)]
pub struct LBFGS<L, P, F> {
    /// line search
    linesearch: L,
    /// Number of correction pairs to keep (the "memory" of the method)
    m: usize,
    /// Stored parameter differences s_k = x_{k+1} - x_k
    s: VecDeque<P>,
    /// Stored gradient differences y_k = g_{k+1} - g_k
    y: VecDeque<P>,
    /// Tolerance for the stopping criterion based on the norm of the gradient
    tol_grad: F,
    /// Tolerance for the stopping criterion based on the change of the cost function value
    tol_cost: F,
}

impl<L, P, F: ArgminFloat> LBFGS<L, P, F> {
    /// Constructor
    pub fn new(linesearch: L, m: usize) -> Self {
        LBFGS {
            linesearch,
            m,
            s: VecDeque::with_capacity(m),
            y: VecDeque::with_capacity(m),
            tol_grad: F::epsilon().sqrt(),
            tol_cost: F::epsilon(),
        }
    }

    /// Sets the tolerance for the stopping criterion based on the norm of the gradient
    pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
        self.tol_grad = tol_grad;
        self
    }

    /// Sets the tolerance for the stopping criterion based on the change of the cost function value
    pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
        self.tol_cost = tol_cost;
        self
    }
}

impl<O, L, P, F> Solver<O> for LBFGS<L, P, F>
where
    O: ArgminOp<Param = P, Output = F, Float = F>,
    O::Param: Clone
        + Serialize
        + DeserializeOwned
        + Debug
        + Default
        + ArgminSub<O::Param, O::Param>
        + ArgminAdd<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminScaledAdd<O::Param, O::Float, O::Param>
        + ArgminNorm<O::Float>
        + ArgminMul<O::Float, O::Param>,
    O::Hessian: Clone + Default + Serialize + DeserializeOwned,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "L-BFGS";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let cur_cost = state.get_cost();
        let prev_grad = state.get_grad().unwrap();
        // .unwrap_or_else(|| op.gradient(&param).unwrap());

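        // Scaling for the initial inverse Hessian approximation H_0 = gamma * I, computed from
        // the most recent correction pair: gamma = s^T y / y^T y (cf. Nocedal & Wright 2006).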
        let gamma: F = if let (Some(ref sk), Some(ref yk)) = (self.s.back(), self.y.back()) {
            sk.dot(*yk) / yk.dot(*yk)
        } else {
            F::from_f64(1.0).unwrap()
        };

        // L-BFGS two-loop recursion
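        // (Nocedal & Wright 2006, chapter 7): computes r = H_k * grad_k without ever forming
        // H_k explicitly, using only the stored correction pairs. The first loop walks the
        // history from the newest pair to the oldest, the second loop from oldest to newest.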
        let mut q = prev_grad.clone();
        let cur_m = self.s.len();
        let mut alpha: Vec<F> = vec![F::from_f64(0.0).unwrap(); cur_m];
        let mut rho: Vec<F> = vec![F::from_f64(0.0).unwrap(); cur_m];
        for (i, (ref sk, ref yk)) in self.s.iter().rev().zip(self.y.iter().rev()).enumerate() {
            let sk = *sk;
            let yk = *yk;
            let yksk: F = yk.dot(sk);
            let rho_t = F::from_f64(1.0).unwrap() / yksk;
            let skq: F = sk.dot(&q);
            let alpha_t = skq.mul(rho_t);
            q = q.sub(&yk.mul(&alpha_t));
            rho[cur_m - i - 1] = rho_t;
            alpha[cur_m - i - 1] = alpha_t;
        }
        let mut r = q.mul(&gamma);
        for (i, (ref sk, ref yk)) in self.s.iter().zip(self.y.iter()).enumerate() {
            let sk = *sk;
            let yk = *yk;
            let beta = yk.dot(&r).mul(rho[i]);
            r = r.add(&sk.mul(&(alpha[i] - beta)));
        }

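        // The quasi-Newton step direction is p_k = -H_k * grad_k; `r` holds H_k * grad_k,
        // so it is negated before being handed to the line search.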
        self.linesearch
            .set_search_direction(r.mul(&F::from_f64(-1.0).unwrap()));

        // Run the line search
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: xk1,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param.clone(),
        )
        .grad(prev_grad.clone())
        .cost(cur_cost)
        .ctrlc(false)
        .run()?;

        // Take back the operator and carry over the line search's function evaluation counts
        op.consume_op(line_op);

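        // Keep at most m correction pairs: once the history is full, drop the oldest pair
        // before the pair computed in this iteration is pushed below.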
        if state.get_iter() >= self.m as u64 {
            self.s.pop_front();
            self.y.pop_front();
        }

        let grad = op.gradient(&xk1)?;

        self.s.push_back(xk1.sub(&param));
        self.y.push_back(grad.sub(&prev_grad));

        Ok(ArgminIterData::new()
            .param(xk1)
            .cost(next_cost)
            .grad(grad)
            .kv(make_kv!("gamma" => gamma;)))
    }

    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        if state.get_grad().unwrap().norm() < self.tol_grad {
            return TerminationReason::TargetPrecisionReached;
        }
        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
            return TerminationReason::NoChangeInCost;
        }
        TerminationReason::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(lbfgs, LBFGS<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);

    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();

        let tol1 = 1e-4;
        let tol2 = 1e-2;

        let LBFGS {
            tol_grad: t1,
            tol_cost: t2,
            ..
        }: LBFGS<_, f64, f64> = LBFGS::new(linesearch, 7)
            .with_tol_grad(tol1)
            .with_tol_cost(tol2);

        assert!((t1 - tol1).abs() < std::f64::EPSILON);
        assert!((t2 - tol2).abs() < std::f64::EPSILON);
    }
}