argmin/solver/quasinewton/
bfgs.rs1use crate::prelude::*;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18#[derive(Clone, Serialize, Deserialize)]
27pub struct BFGS<L, H, F> {
28 inv_hessian: H,
30 linesearch: L,
32 tol_grad: F,
34 tol_cost: F,
36}
37
38impl<L, H, F: ArgminFloat> BFGS<L, H, F> {
39 pub fn new(init_inverse_hessian: H, linesearch: L) -> Self {
41 BFGS {
42 inv_hessian: init_inverse_hessian,
43 linesearch,
44 tol_grad: F::epsilon().sqrt(),
45 tol_cost: F::epsilon(),
46 }
47 }
48
49 pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
51 self.tol_grad = tol_grad;
52 self
53 }
54
55 pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
57 self.tol_cost = tol_cost;
58 self
59 }
60}
61
62impl<O, L, H, F> Solver<O> for BFGS<L, H, F>
63where
64 O: ArgminOp<Output = F, Hessian = H, Float = F>,
65 O::Param: Debug
66 + Default
67 + ArgminSub<O::Param, O::Param>
68 + ArgminDot<O::Param, O::Float>
69 + ArgminDot<O::Param, O::Hessian>
70 + ArgminScaledAdd<O::Param, O::Float, O::Param>
71 + ArgminNorm<O::Float>
72 + ArgminMul<O::Float, O::Param>,
73 O::Hessian: Clone
74 + Default
75 + Debug
76 + Serialize
77 + DeserializeOwned
78 + ArgminSub<O::Hessian, O::Hessian>
79 + ArgminDot<O::Param, O::Param>
80 + ArgminDot<O::Hessian, O::Hessian>
81 + ArgminAdd<O::Hessian, O::Hessian>
82 + ArgminMul<O::Float, O::Hessian>
83 + ArgminTranspose
84 + ArgminEye,
85 L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
86 F: ArgminFloat,
87{
88 const NAME: &'static str = "BFGS";
89
90 fn init(
91 &mut self,
92 op: &mut OpWrapper<O>,
93 state: &IterState<O>,
94 ) -> Result<Option<ArgminIterData<O>>, Error> {
95 let param = state.get_param();
96 let cost = op.apply(¶m)?;
97 let grad = op.gradient(¶m)?;
98 Ok(Some(
99 ArgminIterData::new().param(param).cost(cost).grad(grad),
100 ))
101 }
102
103 fn next_iter(
104 &mut self,
105 op: &mut OpWrapper<O>,
106 state: &IterState<O>,
107 ) -> Result<ArgminIterData<O>, Error> {
108 let param = state.get_param();
109 let cur_cost = state.get_cost();
110 let prev_grad = state.get_grad().unwrap();
111
112 let p = self
113 .inv_hessian
114 .dot(&prev_grad)
115 .mul(&F::from_f64(-1.0).unwrap());
116
117 self.linesearch.set_search_direction(p);
118
119 let ArgminResult {
121 operator: line_op,
122 state:
123 IterState {
124 param: xk1,
125 cost: next_cost,
126 ..
127 },
128 } = Executor::new(
129 OpWrapper::new_from_wrapper(op),
130 self.linesearch.clone(),
131 param.clone(),
132 )
133 .grad(prev_grad.clone())
134 .cost(cur_cost)
135 .ctrlc(false)
136 .run()?;
137
138 op.consume_op(line_op);
140
141 let grad = op.gradient(&xk1)?;
142
143 let yk = grad.sub(&prev_grad);
144
145 let sk = xk1.sub(¶m);
146
147 let yksk: F = yk.dot(&sk);
148 let rhok = F::from_f64(1.0).unwrap() / yksk;
149
150 let e = self.inv_hessian.eye_like();
151 let mat1: O::Hessian = sk.dot(&yk);
152 let mat1 = mat1.mul(&rhok);
153
154 let mat2 = mat1.clone().t();
155
156 let tmp1 = e.sub(&mat1);
157 let tmp2 = e.sub(&mat2);
158
159 let sksk: O::Hessian = sk.dot(&sk);
160 let sksk = sksk.mul(&rhok);
161
162 self.inv_hessian = tmp1.dot(&self.inv_hessian.dot(&tmp2)).add(&sksk);
169
170 Ok(ArgminIterData::new().param(xk1).cost(next_cost).grad(grad))
171 }
172
173 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
174 if state.get_grad().unwrap().norm() < self.tol_grad {
175 return TerminationReason::TargetPrecisionReached;
176 }
177 if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
178 return TerminationReason::NoChangeInCost;
179 }
180 TerminationReason::NotTerminated
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    // `BFGS<L, H, F>`: the line search type comes first, the inverse-Hessian type second.
    test_trait_impl!(bfgs, BFGS<MoreThuenteLineSearch<Operator, f64>, Operator, f64>);

    /// The builder methods must store the requested tolerances unchanged.
    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let tol1: f64 = 1e-4;
        let tol2: f64 = 1e-2;

        let BFGS {
            tol_grad: t1,
            tol_cost: t2,
            ..
        } = BFGS::new(init_hessian, linesearch)
            .with_tol_grad(tol1)
            .with_tol_cost(tol2);

        assert!((t1 - tol1).abs() < std::f64::EPSILON);
        assert!((t2 - tol2).abs() < std::f64::EPSILON);
    }

    /// `new` must default to `sqrt(eps)` for the gradient and `eps` for the cost.
    #[test]
    fn test_default_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let solver: BFGS<MoreThuenteLineSearch<f64, f64>, Vec<Vec<f64>>, f64> =
            BFGS::new(init_hessian, linesearch);

        assert!((solver.tol_grad - std::f64::EPSILON.sqrt()).abs() < std::f64::EPSILON);
        assert!((solver.tol_cost - std::f64::EPSILON).abs() < std::f64::EPSILON);
    }
}