// argmin/solver/gaussnewton/gaussnewton_linesearch.rs
use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::default::Default;
16
17#[derive(Clone, Serialize, Deserialize)]
26pub struct GaussNewtonLS<L, F> {
27 linesearch: L,
29 tol: F,
31}
32
33impl<L, F: ArgminFloat> GaussNewtonLS<L, F> {
34 pub fn new(linesearch: L) -> Self {
36 GaussNewtonLS {
37 linesearch,
38 tol: F::epsilon().sqrt(),
39 }
40 }
41
42 pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
44 if tol <= F::from_f64(0.0).unwrap() {
45 return Err(ArgminError::InvalidParameter {
46 text: "Gauss-Newton-Linesearch: tol must be positive.".to_string(),
47 }
48 .into());
49 }
50 self.tol = tol;
51 Ok(self)
52 }
53}
54
55impl<O, L, F> Solver<O> for GaussNewtonLS<L, F>
56where
57 O: ArgminOp<Float = F>,
58 O::Param: Default
59 + std::fmt::Debug
60 + ArgminScaledSub<O::Param, O::Float, O::Param>
61 + ArgminSub<O::Param, O::Param>
62 + ArgminMul<O::Float, O::Param>,
63 O::Output: ArgminNorm<O::Float>,
64 O::Jacobian: ArgminTranspose
65 + ArgminInv<O::Jacobian>
66 + ArgminDot<O::Jacobian, O::Jacobian>
67 + ArgminDot<O::Output, O::Param>
68 + ArgminDot<O::Param, O::Param>,
69 O::Hessian: Default,
70 L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<LineSearchOP<O>>>,
71 F: ArgminFloat,
72{
73 const NAME: &'static str = "Gauss-Newton method with Linesearch";
74
75 fn next_iter(
76 &mut self,
77 op: &mut OpWrapper<O>,
78 state: &IterState<O>,
79 ) -> Result<ArgminIterData<O>, Error> {
80 let param = state.get_param();
81 let residuals = op.apply(¶m)?;
82 let jacobian = op.jacobian(¶m)?;
83 let jacobian_t = jacobian.clone().t();
84 let grad = jacobian_t.dot(&residuals);
85
86 let p = jacobian_t.dot(&jacobian).inv()?.dot(&grad);
87
88 self.linesearch
89 .set_search_direction(p.mul(&(F::from_f64(-1.0).unwrap())));
90
91 let line_op = OpWrapper::new(LineSearchOP {
93 op: op.take_op().unwrap(),
94 });
95
96 let ArgminResult {
98 operator: mut line_op,
99 state:
100 IterState {
101 param: next_param,
102 cost: next_cost,
103 ..
104 },
105 } = Executor::new(line_op, self.linesearch.clone(), param)
106 .grad(grad)
107 .cost(residuals.norm())
108 .ctrlc(false)
109 .run()?;
110
111 op.op = Some(line_op.take_op().unwrap().op);
115 op.consume_func_counts(line_op);
116
117 Ok(ArgminIterData::new().param(next_param).cost(next_cost))
118 }
119
120 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
121 if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
122 return TerminationReason::NoChangeInCost;
123 }
124 TerminationReason::NotTerminated
125 }
126}
127
128#[doc(hidden)]
129#[derive(Clone, Default, Serialize, Deserialize)]
130pub struct LineSearchOP<O> {
131 pub op: O,
132}
133
134impl<O> ArgminOp for LineSearchOP<O>
135where
136 O: ArgminOp,
137 O::Jacobian: ArgminTranspose + ArgminDot<O::Output, O::Param>,
138 O::Output: ArgminNorm<O::Float>,
139{
140 type Param = O::Param;
141 type Output = O::Float;
142 type Hessian = O::Hessian;
143 type Jacobian = O::Jacobian;
144 type Float = O::Float;
145
146 fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
147 Ok(self.op.apply(p)?.norm())
148 }
149
150 fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
151 Ok(self.op.jacobian(p)?.t().dot(&self.op.apply(p)?))
152 }
153
154 fn hessian(&self, p: &Self::Param) -> Result<Self::Hessian, Error> {
155 self.op.hessian(p)
156 }
157
158 fn jacobian(&self, p: &Self::Param) -> Result<Self::Jacobian, Error> {
159 self.op.jacobian(p)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    // Crate-provided macro that generates a trait-implementation test for the
    // given solver type (see the macro definition for the exact checks).
    test_trait_impl!(
        gauss_newton_linesearch_method,
        GaussNewtonLS<MoreThuenteLineSearch<Vec<f64>, f64>, f64>
    );

    /// `with_tol` must store the tolerance it is given.
    #[test]
    fn test_tolerance() {
        let expected: f64 = 1e-4;

        let ls: MoreThuenteLineSearch<Vec<f64>, f64> = MoreThuenteLineSearch::new();
        let solver = GaussNewtonLS::new(ls).with_tol(expected).unwrap();

        assert!((solver.tol - expected).abs() < std::f64::EPSILON);
    }
}