argmin/solver/quasinewton/
lbfgs.rs1use crate::prelude::*;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use std::collections::VecDeque;
17use std::fmt::Debug;
18
19#[derive(Clone, Serialize, Deserialize)]
30pub struct LBFGS<L, P, F> {
31 linesearch: L,
33 m: usize,
35 s: VecDeque<P>,
37 y: VecDeque<P>,
39 tol_grad: F,
41 tol_cost: F,
43}
44
45impl<L, P, F: ArgminFloat> LBFGS<L, P, F> {
46 pub fn new(linesearch: L, m: usize) -> Self {
48 LBFGS {
49 linesearch,
50 m,
51 s: VecDeque::with_capacity(m),
52 y: VecDeque::with_capacity(m),
53 tol_grad: F::epsilon().sqrt(),
54 tol_cost: F::epsilon(),
55 }
56 }
57
58 pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
60 self.tol_grad = tol_grad;
61 self
62 }
63
64 pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
66 self.tol_cost = tol_cost;
67 self
68 }
69}
70
71impl<O, L, P, F> Solver<O> for LBFGS<L, P, F>
72where
73 O: ArgminOp<Param = P, Output = F, Float = F>,
74 O::Param: Clone
75 + Serialize
76 + DeserializeOwned
77 + Debug
78 + Default
79 + ArgminSub<O::Param, O::Param>
80 + ArgminAdd<O::Param, O::Param>
81 + ArgminDot<O::Param, O::Float>
82 + ArgminScaledAdd<O::Param, O::Float, O::Param>
83 + ArgminNorm<O::Float>
84 + ArgminMul<O::Float, O::Param>,
85 O::Hessian: Clone + Default + Serialize + DeserializeOwned,
86 L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
87 F: ArgminFloat,
88{
89 const NAME: &'static str = "L-BFGS";
90
91 fn init(
92 &mut self,
93 op: &mut OpWrapper<O>,
94 state: &IterState<O>,
95 ) -> Result<Option<ArgminIterData<O>>, Error> {
96 let param = state.get_param();
97 let cost = op.apply(¶m)?;
98 let grad = op.gradient(¶m)?;
99 Ok(Some(
100 ArgminIterData::new().param(param).cost(cost).grad(grad),
101 ))
102 }
103
104 fn next_iter(
105 &mut self,
106 op: &mut OpWrapper<O>,
107 state: &IterState<O>,
108 ) -> Result<ArgminIterData<O>, Error> {
109 let param = state.get_param();
110 let cur_cost = state.get_cost();
111 let prev_grad = state.get_grad().unwrap();
112 let gamma: F = if let (Some(ref sk), Some(ref yk)) = (self.s.back(), self.y.back()) {
115 sk.dot(*yk) / yk.dot(*yk)
116 } else {
117 F::from_f64(1.0).unwrap()
118 };
119
120 let mut q = prev_grad.clone();
122 let cur_m = self.s.len();
123 let mut alpha: Vec<F> = vec![F::from_f64(0.0).unwrap(); cur_m];
124 let mut rho: Vec<F> = vec![F::from_f64(0.0).unwrap(); cur_m];
125 for (i, (ref sk, ref yk)) in self.s.iter().rev().zip(self.y.iter().rev()).enumerate() {
126 let sk = *sk;
127 let yk = *yk;
128 let yksk: F = yk.dot(sk);
129 let rho_t = F::from_f64(1.0).unwrap() / yksk;
130 let skq: F = sk.dot(&q);
131 let alpha_t = skq.mul(rho_t);
132 q = q.sub(&yk.mul(&alpha_t));
133 rho[cur_m - i - 1] = rho_t;
134 alpha[cur_m - i - 1] = alpha_t;
135 }
136 let mut r = q.mul(&gamma);
137 for (i, (ref sk, ref yk)) in self.s.iter().zip(self.y.iter()).enumerate() {
138 let sk = *sk;
139 let yk = *yk;
140 let beta = yk.dot(&r).mul(rho[i]);
141 r = r.add(&sk.mul(&(alpha[i] - beta)));
142 }
143
144 self.linesearch
145 .set_search_direction(r.mul(&F::from_f64(-1.0).unwrap()));
146
147 let ArgminResult {
149 operator: line_op,
150 state:
151 IterState {
152 param: xk1,
153 cost: next_cost,
154 ..
155 },
156 } = Executor::new(
157 OpWrapper::new_from_wrapper(op),
158 self.linesearch.clone(),
159 param.clone(),
160 )
161 .grad(prev_grad.clone())
162 .cost(cur_cost)
163 .ctrlc(false)
164 .run()?;
165
166 op.consume_op(line_op);
168
169 if state.get_iter() >= self.m as u64 {
170 self.s.pop_front();
171 self.y.pop_front();
172 }
173
174 let grad = op.gradient(&xk1)?;
175
176 self.s.push_back(xk1.sub(¶m));
177 self.y.push_back(grad.sub(&prev_grad));
178
179 Ok(ArgminIterData::new()
180 .param(xk1)
181 .cost(next_cost)
182 .grad(grad)
183 .kv(make_kv!("gamma" => gamma;)))
184 }
185
186 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
187 if state.get_grad().unwrap().norm() < self.tol_grad {
188 return TerminationReason::TargetPrecisionReached;
189 }
190 if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
191 return TerminationReason::NoChangeInCost;
192 }
193 TerminationReason::NotTerminated
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    // Type arguments follow the struct's `<L, P, F>` order: line search
    // first, then the parameter type, then the float type.
    test_trait_impl!(lbfgs, LBFGS<MoreThuenteLineSearch<Operator, f64>, Operator, f64>);

    /// The builder methods must store exactly the tolerances they are given.
    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();

        let tol1 = 1e-4;
        let tol2 = 1e-2;

        let LBFGS {
            tol_grad: t1,
            tol_cost: t2,
            ..
        }: LBFGS<_, f64, f64> = LBFGS::new(linesearch, 7)
            .with_tol_grad(tol1)
            .with_tol_cost(tol2);

        assert!((t1 - tol1).abs() < std::f64::EPSILON);
        assert!((t2 - tol2).abs() < std::f64::EPSILON);
    }
}