// argmin/solver/quasinewton/sr1.rs
use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// SR1 (Symmetric Rank-One) quasi-Newton method.
///
/// Maintains an approximation of the inverse Hessian which is corrected by a
/// rank-one update in every iteration. The step length along the quasi-Newton
/// search direction is determined by the supplied line search.
#[derive(Clone, Serialize, Deserialize)]
pub struct SR1<L, H, F> {
    // Denominator tolerance: the rank-one update is applied only when
    // |(s - H y)ᵀ y| >= r * ||y|| * ||s - H y||, guarding against division
    // by a near-zero denominator (default 1e-8).
    r: F,
    // Current approximation of the inverse Hessian.
    inv_hessian: H,
    // Line search method used to compute the step length.
    linesearch: L,
    // Stopping tolerance on the gradient norm (default: sqrt(EPSILON)).
    tol_grad: F,
    // Stopping tolerance on the change in cost (default: EPSILON).
    tol_cost: F,
}
39
40impl<L, H, F: ArgminFloat> SR1<L, H, F> {
41 pub fn new(init_inverse_hessian: H, linesearch: L) -> Self {
43 SR1 {
44 r: F::from_f64(1e-8).unwrap(),
45 inv_hessian: init_inverse_hessian,
46 linesearch,
47 tol_grad: F::epsilon().sqrt(),
48 tol_cost: F::epsilon(),
49 }
50 }
51
52 pub fn r(mut self, r: F) -> Result<Self, Error> {
54 if r < F::from_f64(0.0).unwrap() || r > F::from_f64(1.0).unwrap() {
55 Err(ArgminError::InvalidParameter {
56 text: "SR1: r must be between 0 and 1.".to_string(),
57 }
58 .into())
59 } else {
60 self.r = r;
61 Ok(self)
62 }
63 }
64
65 pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
67 self.tol_grad = tol_grad;
68 self
69 }
70
71 pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
73 self.tol_cost = tol_cost;
74 self
75 }
76}
77
impl<O, L, H, F> Solver<O> for SR1<L, H, F>
where
    O: ArgminOp<Output = F, Hessian = H, Float = F>,
    O::Param: Debug
        + Clone
        + Default
        + Serialize
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminDot<O::Param, O::Hessian>
        + ArgminNorm<O::Float>
        + ArgminMul<O::Float, O::Param>,
    O::Hessian: Debug
        + Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Hessian, O::Hessian>
        + ArgminDot<O::Param, O::Param>
        + ArgminDot<O::Hessian, O::Hessian>
        + ArgminAdd<O::Hessian, O::Hessian>
        + ArgminMul<F, O::Hessian>,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "SR1";

    /// Evaluates cost and gradient at the initial parameter vector so that
    /// the first `next_iter` call finds both in `state`.
    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    /// Performs one SR1 iteration: a line search along the quasi-Newton
    /// direction `-H * grad`, followed by a rank-one correction of the
    /// inverse Hessian approximation `H`.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let cost = state.get_cost();
        // Reuse the gradient from the previous iteration when available;
        // recompute only if the state does not carry one.
        let prev_grad = if let Some(grad) = state.get_grad() {
            grad
        } else {
            op.gradient(&param)?
        };

        // Search direction: p = -H * grad.
        let p = self
            .inv_hessian
            .dot(&prev_grad)
            .mul(&F::from_f64(-1.0).unwrap());

        self.linesearch.set_search_direction(p);

        // Run the line search as a nested Executor to obtain the next
        // iterate xk1 and its cost. ctrlc(false) leaves signal handling
        // to the outer executor.
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: xk1,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param.clone(),
        )
        .grad(prev_grad.clone())
        .cost(cost)
        .ctrlc(false)
        .run()?;

        // Fold the line search's function-evaluation counts back into the
        // outer operator wrapper.
        op.consume_op(line_op);

        let grad = op.gradient(&xk1)?;
        // yk: change in gradient; sk: step taken.
        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

        // SR1 update: H += (s - H y)(s - H y)^T / ((s - H y)^T y).
        let skmhkyk: O::Param = sk.sub(&self.inv_hessian.dot(&yk));
        let a: O::Hessian = skmhkyk.dot(&skmhkyk);
        let b: F = skmhkyk.dot(&yk);

        // Skip the update when the denominator b is too small relative to
        // r * ||y|| * ||s - H y||, which would make it numerically unstable.
        let hessian_update = b.abs() >= self.r * yk.norm() * skmhkyk.norm();

        if hessian_update {
            self.inv_hessian = self
                .inv_hessian
                .add(&a.mul(&(F::from_f64(1.0).unwrap() / b)));
        }

        Ok(ArgminIterData::new()
            .param(xk1)
            .cost(next_cost)
            .grad(grad)
            .kv(make_kv!["denom" => b; "hessian_update" => hessian_update;]))
    }

    /// Stops when the gradient norm drops below `tol_grad` or the change in
    /// cost drops below `tol_cost`.
    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        // NOTE(review): unwrap assumes `init`/`next_iter` always stored a
        // gradient in the state, which the visible code does.
        if state.get_grad().unwrap().norm() < self.tol_grad {
            return TerminationReason::TargetPrecisionReached;
        }
        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
            return TerminationReason::NoChangeInCost;
        }
        TerminationReason::NotTerminated
    }
}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(sr1, SR1<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);

    /// Verifies that the `with_tol_*` builder methods store exactly the
    /// tolerances they are given.
    #[test]
    fn test_tolerances() {
        let ls: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let identity: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let grad_tol: f64 = 1e-4;
        let cost_tol: f64 = 1e-2;

        let solver = SR1::new(identity, ls)
            .with_tol_grad(grad_tol)
            .with_tol_cost(cost_tol);

        assert!((solver.tol_grad - grad_tol).abs() < f64::EPSILON);
        assert!((solver.tol_cost - cost_tol).abs() < f64::EPSILON);
    }
}