argmin/solver/gaussnewton/
gaussnewton_method.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::default::Default;
16
/// Gauss-Newton method
///
/// Iterative solver that, in every iteration, evaluates the residuals `r` and the Jacobian `J`
/// of the operator at the current parameter vector and moves the parameters by `gamma` times
/// the step `(J^T J)^{-1} J^T r`. The norm of the residuals is reported as the cost.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/gaussnewton.rs)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct GaussNewton<F> {
    /// Step length `gamma`: factor the Gauss-Newton step is multiplied with before it is
    /// subtracted from the current parameter vector. `new` initializes it to `1.0`;
    /// `with_gamma` only accepts values in `(0, 1]`.
    gamma: F,
    /// Tolerance for the stopping criterion based on cost difference: the solver terminates
    /// once the absolute difference between the previous and the current cost is below this
    /// value. `new` initializes it to the square root of the machine epsilon of `F`.
    tol: F,
}
32
33impl<F: ArgminFloat> GaussNewton<F> {
34    /// Constructor
35    pub fn new() -> Self {
36        GaussNewton {
37            gamma: F::from_f64(1.0).unwrap(),
38            tol: F::epsilon().sqrt(),
39        }
40    }
41
42    /// set gamma
43    pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
44        if gamma <= F::from_f64(0.0).unwrap() || gamma > F::from_f64(1.0).unwrap() {
45            return Err(ArgminError::InvalidParameter {
46                text: "Gauss-Newton: gamma must be in  (0, 1].".to_string(),
47            }
48            .into());
49        }
50        self.gamma = gamma;
51        Ok(self)
52    }
53
54    /// Set tolerance for the stopping criterion based on cost difference
55    pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
56        if tol <= F::from_f64(0.0).unwrap() {
57            return Err(ArgminError::InvalidParameter {
58                text: "Gauss-Newton: tol must be positive.".to_string(),
59            }
60            .into());
61        }
62        self.tol = tol;
63        Ok(self)
64    }
65}
66
67impl<F: ArgminFloat> Default for GaussNewton<F> {
68    fn default() -> GaussNewton<F> {
69        GaussNewton::new()
70    }
71}
72
73impl<O, F> Solver<O> for GaussNewton<F>
74where
75    O: ArgminOp<Float = F>,
76    O::Param: Default
77        + ArgminScaledSub<O::Param, O::Float, O::Param>
78        + ArgminSub<O::Param, O::Param>
79        + ArgminMul<O::Float, O::Param>,
80    O::Output: ArgminNorm<O::Float>,
81    O::Jacobian: ArgminTranspose
82        + ArgminInv<O::Jacobian>
83        + ArgminDot<O::Jacobian, O::Jacobian>
84        + ArgminDot<O::Output, O::Param>
85        + ArgminDot<O::Param, O::Param>,
86    O::Hessian: Default,
87    F: ArgminFloat,
88{
89    const NAME: &'static str = "Gauss-Newton method";
90
91    fn next_iter(
92        &mut self,
93        op: &mut OpWrapper<O>,
94        state: &IterState<O>,
95    ) -> Result<ArgminIterData<O>, Error> {
96        let param = state.get_param();
97        let residuals = op.apply(&param)?;
98        let jacobian = op.jacobian(&param)?;
99
100        let p = jacobian
101            .clone()
102            .t()
103            .dot(&jacobian)
104            .inv()?
105            .dot(&jacobian.t().dot(&residuals));
106
107        let new_param = param.sub(&p.mul(&self.gamma));
108
109        Ok(ArgminIterData::new()
110            .param(new_param)
111            .cost(residuals.norm()))
112    }
113
114    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
115        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
116            return TerminationReason::NoChangeInCost;
117        }
118        TerminationReason::NotTerminated
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(gauss_newton_method, GaussNewton<f64>);

    /// `with_tol` stores the supplied tolerance.
    #[test]
    fn test_tolerance() {
        let expected: f64 = 1e-4;

        let solver = GaussNewton::new().with_tol(expected).unwrap();

        assert!((solver.tol - expected).abs() < std::f64::EPSILON);
    }

    /// `with_gamma` stores the supplied step length.
    #[test]
    fn test_gamma() {
        let expected: f64 = 0.5;

        let solver = GaussNewton::new().with_gamma(expected).unwrap();

        assert!((solver.gamma - expected).abs() < std::f64::EPSILON);
    }
}