argmin/solver/gaussnewton/
gaussnewton_method.rs1use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::default::Default;
16
17#[derive(Clone, Serialize, Deserialize)]
26pub struct GaussNewton<F> {
27 gamma: F,
29 tol: F,
31}
32
33impl<F: ArgminFloat> GaussNewton<F> {
34 pub fn new() -> Self {
36 GaussNewton {
37 gamma: F::from_f64(1.0).unwrap(),
38 tol: F::epsilon().sqrt(),
39 }
40 }
41
42 pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
44 if gamma <= F::from_f64(0.0).unwrap() || gamma > F::from_f64(1.0).unwrap() {
45 return Err(ArgminError::InvalidParameter {
46 text: "Gauss-Newton: gamma must be in (0, 1].".to_string(),
47 }
48 .into());
49 }
50 self.gamma = gamma;
51 Ok(self)
52 }
53
54 pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
56 if tol <= F::from_f64(0.0).unwrap() {
57 return Err(ArgminError::InvalidParameter {
58 text: "Gauss-Newton: tol must be positive.".to_string(),
59 }
60 .into());
61 }
62 self.tol = tol;
63 Ok(self)
64 }
65}
66
67impl<F: ArgminFloat> Default for GaussNewton<F> {
68 fn default() -> GaussNewton<F> {
69 GaussNewton::new()
70 }
71}
72
73impl<O, F> Solver<O> for GaussNewton<F>
74where
75 O: ArgminOp<Float = F>,
76 O::Param: Default
77 + ArgminScaledSub<O::Param, O::Float, O::Param>
78 + ArgminSub<O::Param, O::Param>
79 + ArgminMul<O::Float, O::Param>,
80 O::Output: ArgminNorm<O::Float>,
81 O::Jacobian: ArgminTranspose
82 + ArgminInv<O::Jacobian>
83 + ArgminDot<O::Jacobian, O::Jacobian>
84 + ArgminDot<O::Output, O::Param>
85 + ArgminDot<O::Param, O::Param>,
86 O::Hessian: Default,
87 F: ArgminFloat,
88{
89 const NAME: &'static str = "Gauss-Newton method";
90
91 fn next_iter(
92 &mut self,
93 op: &mut OpWrapper<O>,
94 state: &IterState<O>,
95 ) -> Result<ArgminIterData<O>, Error> {
96 let param = state.get_param();
97 let residuals = op.apply(¶m)?;
98 let jacobian = op.jacobian(¶m)?;
99
100 let p = jacobian
101 .clone()
102 .t()
103 .dot(&jacobian)
104 .inv()?
105 .dot(&jacobian.t().dot(&residuals));
106
107 let new_param = param.sub(&p.mul(&self.gamma));
108
109 Ok(ArgminIterData::new()
110 .param(new_param)
111 .cost(residuals.norm()))
112 }
113
114 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
115 if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
116 return TerminationReason::NoChangeInCost;
117 }
118 TerminationReason::NotTerminated
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(gauss_newton_method, GaussNewton<f64>);

    /// `new()` uses gamma = 1 and tol = sqrt(machine epsilon).
    #[test]
    fn test_default() {
        let GaussNewton { gamma, tol, .. } = GaussNewton::<f64>::new();

        assert!((gamma - 1.0).abs() < f64::EPSILON);
        assert!((tol - f64::EPSILON.sqrt()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tolerance() {
        let tol1: f64 = 1e-4;

        let GaussNewton { tol: t, .. } = GaussNewton::new().with_tol(tol1).unwrap();

        assert!((t - tol1).abs() < f64::EPSILON);
    }

    /// `with_tol` only accepts strictly positive values.
    #[test]
    fn test_tolerance_rejects_non_positive() {
        assert!(GaussNewton::<f64>::new().with_tol(0.0).is_err());
        assert!(GaussNewton::<f64>::new().with_tol(-1e-4).is_err());
    }

    #[test]
    fn test_gamma() {
        let gamma: f64 = 0.5;

        let GaussNewton { gamma: g, .. } = GaussNewton::new().with_gamma(gamma).unwrap();

        assert!((g - gamma).abs() < f64::EPSILON);
    }

    /// `with_gamma` accepts exactly the half-open interval (0, 1].
    #[test]
    fn test_gamma_rejects_out_of_range() {
        assert!(GaussNewton::<f64>::new().with_gamma(0.0).is_err());
        assert!(GaussNewton::<f64>::new().with_gamma(-0.5).is_err());
        assert!(GaussNewton::<f64>::new().with_gamma(1.5).is_err());
        // The upper bound is inclusive.
        assert!(GaussNewton::<f64>::new().with_gamma(1.0).is_ok());
    }
}
146}