argmin/solver/quasinewton/
sr1_trustregion.rs1use crate::prelude::*;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18#[derive(Clone, Serialize, Deserialize)]
27pub struct SR1TrustRegion<B, R, F> {
28 r: F,
30 init_hessian: Option<B>,
32 subproblem: R,
34 radius: F,
36 eta: F,
38 tol_grad: F,
40}
41
42impl<B, R, F: ArgminFloat> SR1TrustRegion<B, R, F> {
43 pub fn new(subproblem: R) -> Self {
45 SR1TrustRegion {
46 r: F::from_f64(1e-8).unwrap(),
47 init_hessian: None,
48 subproblem,
49 radius: F::from_f64(1.0).unwrap(),
50 eta: F::from_f64(0.5 * 1e-3).unwrap(),
51 tol_grad: F::from_f64(1e-3).unwrap(),
52 }
53 }
54
55 pub fn hessian(mut self, init_hessian: B) -> Self {
58 self.init_hessian = Some(init_hessian);
59 self
60 }
61
62 pub fn r(mut self, r: F) -> Result<Self, Error> {
64 if r <= F::from_f64(0.0).unwrap() || r >= F::from_f64(1.0).unwrap() {
65 Err(ArgminError::InvalidParameter {
66 text: "SR1: r must be in (0, 1).".to_string(),
67 }
68 .into())
69 } else {
70 self.r = r;
71 Ok(self)
72 }
73 }
74
75 pub fn radius(mut self, radius: F) -> Self {
77 self.radius = radius.abs();
78 self
79 }
80
81 pub fn eta(mut self, eta: F) -> Result<Self, Error> {
83 if eta >= F::from_f64(10e-3).unwrap() || eta <= F::from_f64(0.0).unwrap() {
84 return Err(ArgminError::InvalidParameter {
85 text: "SR1TrustRegion: eta must be in (0, 10^-3).".to_string(),
86 }
87 .into());
88 }
89 self.eta = eta;
90 Ok(self)
91 }
92
93 pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
95 self.tol_grad = tol_grad;
96 self
97 }
98}
99
100impl<O, B, R, F> Solver<O> for SR1TrustRegion<B, R, F>
101where
102 O: ArgminOp<Output = F, Hessian = B, Float = F>,
103 O::Param: Debug
104 + Clone
105 + Default
106 + Serialize
107 + ArgminSub<O::Param, O::Param>
108 + ArgminAdd<O::Param, O::Param>
109 + ArgminDot<O::Param, O::Float>
110 + ArgminDot<O::Param, O::Hessian>
111 + ArgminNorm<O::Float>
112 + ArgminZeroLike
113 + ArgminMul<F, O::Param>,
114 O::Hessian: Debug
115 + Clone
116 + Default
117 + Serialize
118 + DeserializeOwned
119 + ArgminSub<O::Hessian, O::Hessian>
120 + ArgminDot<O::Param, O::Param>
121 + ArgminDot<O::Hessian, O::Hessian>
122 + ArgminAdd<O::Hessian, O::Hessian>
123 + ArgminMul<F, O::Hessian>,
124 R: ArgminTrustRegion<F> + Solver<OpWrapper<O>>,
125 F: ArgminFloat + ArgminNorm<O::Float>,
126{
127 const NAME: &'static str = "SR1 Trust Region";
128
129 fn init(
130 &mut self,
131 op: &mut OpWrapper<O>,
132 state: &IterState<O>,
133 ) -> Result<Option<ArgminIterData<O>>, Error> {
134 let param = state.get_param();
135 let cost = op.apply(¶m)?;
136 let grad = op.gradient(¶m)?;
137 let hessian = state
138 .get_hessian()
139 .unwrap_or_else(|| op.hessian(¶m).unwrap());
140 Ok(Some(
141 ArgminIterData::new()
142 .param(param)
143 .cost(cost)
144 .grad(grad)
145 .hessian(hessian),
146 ))
147 }
148
149 fn next_iter(
150 &mut self,
151 op: &mut OpWrapper<O>,
152 state: &IterState<O>,
153 ) -> Result<ArgminIterData<O>, Error> {
154 let xk = state.get_param();
155 let cost = state.get_cost();
156 let prev_grad = state
157 .get_grad()
158 .unwrap_or_else(|| op.gradient(&xk).unwrap());
159 let hessian: O::Hessian = state.get_hessian().unwrap();
160
161 self.subproblem.set_radius(self.radius);
162
163 let ArgminResult {
164 operator: sub_op,
165 state: IterState { param: sk, .. },
166 } = Executor::new(
167 OpWrapper::new_from_wrapper(op),
168 self.subproblem.clone(),
169 xk.zero_like(),
171 )
172 .cost(cost)
173 .grad(prev_grad.clone())
174 .hessian(hessian.clone())
175 .ctrlc(false)
176 .run()?;
177
178 op.consume_op(sub_op);
179
180 let xksk = xk.add(&sk);
181 let dfk1 = op.gradient(&xksk)?;
182 let yk = dfk1.sub(&prev_grad);
183 let fk1 = op.apply(&xksk)?;
184
185 let ared = cost - fk1;
186 let tmp1: F = prev_grad.dot(&sk);
187 let tmp2: F = sk.weighted_dot(&hessian, &sk);
188 let tmp2: F = tmp2.mul(F::from_f64(0.5).unwrap());
189 let pred = -tmp1 - tmp2;
190 let ap = ared / pred;
191
192 let (xk1, fk1, dfk1) = if ap > self.eta {
193 (xksk, fk1, dfk1)
194 } else {
195 (xk, cost, prev_grad)
196 };
197
198 self.radius = if ap > F::from_f64(0.75).unwrap() {
199 if sk.norm() <= F::from_f64(0.8).unwrap() * self.radius {
200 self.radius
201 } else {
202 F::from_f64(2.0).unwrap() * self.radius
203 }
204 } else if ap <= F::from_f64(0.75).unwrap() && ap >= F::from_f64(0.1).unwrap() {
205 self.radius
206 } else {
207 F::from_f64(0.5).unwrap() * self.radius
208 };
209
210 let bksk = hessian.dot(&sk);
211 let ykbksk = yk.sub(&bksk);
212 let skykbksk: F = sk.dot(&ykbksk);
213
214 let hessian_update = skykbksk.abs() >= self.r * sk.norm() * skykbksk.norm();
215 let hessian = if hessian_update {
216 let a: O::Hessian = ykbksk.dot(&ykbksk);
217 let b: F = sk.dot(&ykbksk);
218 hessian.add(&a.mul(&(F::from_f64(1.0).unwrap() / b)))
219 } else {
220 hessian
221 };
222
223 Ok(ArgminIterData::new()
224 .param(xk1)
225 .cost(fk1)
226 .grad(dfk1)
227 .hessian(hessian)
228 .kv(make_kv!["ared" => ared;
229 "pred" => pred;
230 "ap" => ap;
231 "radius" => self.radius;
232 "hessian_update" => hessian_update;]))
233 }
234
235 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
236 if state.get_grad().unwrap().norm() < self.tol_grad {
237 return TerminationReason::TargetPrecisionReached;
238 }
239 TerminationReason::NotTerminated
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::trustregion::CauchyPoint;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(sr1, SR1TrustRegion<Operator, CauchyPoint<f64>, f64>);

    /// Solver with default settings, used as the starting point of every test below.
    fn default_solver() -> SR1TrustRegion<Operator, CauchyPoint<f64>, f64> {
        SR1TrustRegion::new(CauchyPoint::new())
    }

    #[test]
    fn test_tolerances() {
        let tol: f64 = 1e-4;

        let SR1TrustRegion { tol_grad: t, .. } = default_solver().with_tol_grad(tol);

        assert!((t - tol).abs() < std::f64::EPSILON);
    }

    #[test]
    fn test_r_must_be_in_open_unit_interval() {
        assert!(default_solver().r(0.0).is_err());
        assert!(default_solver().r(1.0).is_err());
        assert!(default_solver().r(-0.5).is_err());

        let SR1TrustRegion { r, .. } = default_solver().r(0.5).unwrap();
        assert!((r - 0.5).abs() < std::f64::EPSILON);
    }

    #[test]
    fn test_eta_bounds() {
        assert!(default_solver().eta(0.0).is_err());
        assert!(default_solver().eta(-1e-4).is_err());
        assert!(default_solver().eta(2e-2).is_err());

        let SR1TrustRegion { eta, .. } = default_solver().eta(5e-4).unwrap();
        assert!((eta - 5e-4).abs() < std::f64::EPSILON);
    }

    #[test]
    fn test_radius_is_stored_as_absolute_value() {
        let SR1TrustRegion { radius, .. } = default_solver().radius(-2.0);
        assert!((radius - 2.0).abs() < std::f64::EPSILON);
    }
}