argmin/solver/quasinewton/dfp.rs

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
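
/// DFP (Davidon-Fletcher-Powell) quasi-Newton method.
///
/// Maintains an approximation of the *inverse* Hessian and, after each
/// accepted step, updates it with the rank-two DFP formula
///
/// `H_{k+1} = H_k - (H_k y_k y_k^T H_k) / (y_k^T H_k y_k) + (s_k s_k^T) / (y_k^T s_k)`
///
/// where `s_k = x_{k+1} - x_k` and `y_k = g_{k+1} - g_k` is the change of the
/// gradient. In each iteration the search direction `-H_k g_k` is handed to
/// the supplied line search.
///
/// A minimal usage sketch (not compiled; `MyProblem` and `init_param` are
/// hypothetical placeholders for a type implementing `ArgminOp` and its
/// initial parameter vector):
///
/// ```ignore
/// let linesearch = MoreThuenteLineSearch::new();
/// let init_inv_hessian = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
/// let solver = DFP::new(init_inv_hessian, linesearch).with_tol_grad(1e-6);
/// let res = Executor::new(MyProblem {}, solver, init_param)
///     .max_iters(100)
///     .run()?;
/// ```
///
/// Reference: Nocedal & Wright, "Numerical Optimization", Springer, 2006,
/// Chapter 6 (Quasi-Newton Methods).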
#[derive(Clone, Serialize, Deserialize)]
pub struct DFP<L, H, F> {
    /// Current approximation of the inverse Hessian
    inv_hessian: H,
    /// Line search method
    linesearch: L,
    /// Tolerance for the stopping criterion based on the gradient norm
    tol_grad: F,
}

impl<L, H, F: ArgminFloat> DFP<L, H, F> {
    /// Constructs a new `DFP` instance from an initial approximation of the
    /// inverse Hessian (commonly the identity matrix) and a line search
    /// method. The gradient tolerance defaults to `sqrt(EPS)`.
    pub fn new(init_inverse_hessian: H, linesearch: L) -> Self {
        DFP {
            inv_hessian: init_inverse_hessian,
            linesearch,
            tol_grad: F::epsilon().sqrt(),
        }
    }

    /// Sets the gradient-norm tolerance used by the stopping criterion.
    pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
        self.tol_grad = tol_grad;
        self
    }
}

impl<O, L, H, F> Solver<O> for DFP<L, H, F>
where
    O: ArgminOp<Output = F, Hessian = H, Float = F>,
    O::Param: Clone
        + Default
        + Serialize
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminDot<O::Param, O::Hessian>
        + ArgminScaledAdd<O::Param, F, O::Param>
        + ArgminNorm<O::Float>
        + ArgminMul<O::Float, O::Param>
        + ArgminTranspose,
    O::Hessian: Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Hessian, O::Hessian>
        + ArgminDot<O::Param, O::Param>
        + ArgminDot<O::Hessian, O::Hessian>
        + ArgminAdd<O::Hessian, O::Hessian>
        + ArgminMul<F, O::Hessian>
        + ArgminTranspose
        + ArgminEye,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "DFP";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let cost = state.get_cost();
        // Reuse the cached gradient if available; otherwise recompute it.
        let prev_grad = if let Some(grad) = state.get_grad() {
            grad
        } else {
            op.gradient(&param)?
        };
        // Search direction: p = -H_k * g_k.
        let p = self
            .inv_hessian
            .dot(&prev_grad)
            .mul(&F::from_f64(-1.0).unwrap());

        self.linesearch.set_search_direction(p);
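
        // Run the line search along p in a nested Executor, starting from the
        // current iterate and warm-started with the known cost and gradient.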
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: xk1,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param.clone(),
        )
        .grad(prev_grad.clone())
        .cost(cost)
        .ctrlc(false)
        .run()?;
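
        // Hand the operator back and fold the line search's function
        // evaluation counts into this solver's counters.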
        op.consume_op(line_op);

        let grad = op.gradient(&xk1)?;
        // y_k = g_{k+1} - g_k
        let yk = grad.sub(&prev_grad);

        // s_k = x_{k+1} - x_k
        let sk = xk1.sub(&param);

        // y_k^T s_k (scalar)
        let yksk: F = yk.dot(&sk);

        // s_k * s_k^T (outer product)
        let sksk: O::Hessian = sk.dot(&sk);
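
        // DFP update of the inverse Hessian approximation:
        //   H_{k+1} = H_k - (H y)(H y)^T / (y^T H y) + s s^T / (y^T s)
        // (the outer-product form of the correction uses the symmetry of H).
        // tmp3 first holds H*y_k and tmp4 = y_k^T H y_k; tmp3 then becomes the
        // full correction term (H y)(H y)^T / (y^T H y).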
        let tmp3: O::Param = self.inv_hessian.dot(&yk);
        let tmp4: F = tmp3.dot(&yk);
        let tmp3: O::Hessian = tmp3.dot(&tmp3);
        let tmp3: O::Hessian = tmp3.mul(&(F::from_f64(1.0).unwrap() / tmp4));

        self.inv_hessian = self
            .inv_hessian
            .sub(&tmp3)
            .add(&sksk.mul(&(F::from_f64(1.0).unwrap() / yksk)));

        Ok(ArgminIterData::new().param(xk1).cost(next_cost).grad(grad))
    }

    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        // Stop once the gradient norm drops below the tolerance.
        if state.get_grad().unwrap().norm() < self.tol_grad {
            return TerminationReason::TargetPrecisionReached;
        }
        TerminationReason::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(dfp, DFP<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);

    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let tol: f64 = 1e-4;

        let DFP { tol_grad: t, .. } = DFP::new(init_hessian, linesearch).with_tol_grad(tol);

        assert!((t - tol).abs() < std::f64::EPSILON);
    }
}