argmin/solver/gaussnewton/gaussnewton_linesearch.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::default::Default;
16
17/// Gauss-Newton method with linesearch
18///
19/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/gaussnewton_linesearch.rs)
20///
21/// # References:
22///
23/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
24/// Springer. ISBN 0-387-30303-0.
25#[derive(Clone, Serialize, Deserialize)]
26pub struct GaussNewtonLS<L, F> {
27    /// linesearch
28    linesearch: L,
29    /// Tolerance for the stopping criterion based on cost difference
30    tol: F,
31}
32
33impl<L, F: ArgminFloat> GaussNewtonLS<L, F> {
34    /// Constructor
35    pub fn new(linesearch: L) -> Self {
36        GaussNewtonLS {
37            linesearch,
38            tol: F::epsilon().sqrt(),
39        }
40    }
41
42    /// Set tolerance for the stopping criterion based on cost difference
43    pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
44        if tol <= F::from_f64(0.0).unwrap() {
45            return Err(ArgminError::InvalidParameter {
46                text: "Gauss-Newton-Linesearch: tol must be positive.".to_string(),
47            }
48            .into());
49        }
50        self.tol = tol;
51        Ok(self)
52    }
53}
54
55impl<O, L, F> Solver<O> for GaussNewtonLS<L, F>
56where
57    O: ArgminOp<Float = F>,
58    O::Param: Default
59        + std::fmt::Debug
60        + ArgminScaledSub<O::Param, O::Float, O::Param>
61        + ArgminSub<O::Param, O::Param>
62        + ArgminMul<O::Float, O::Param>,
63    O::Output: ArgminNorm<O::Float>,
64    O::Jacobian: ArgminTranspose
65        + ArgminInv<O::Jacobian>
66        + ArgminDot<O::Jacobian, O::Jacobian>
67        + ArgminDot<O::Output, O::Param>
68        + ArgminDot<O::Param, O::Param>,
69    O::Hessian: Default,
70    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<LineSearchOP<O>>>,
71    F: ArgminFloat,
72{
73    const NAME: &'static str = "Gauss-Newton method with Linesearch";
74
75    fn next_iter(
76        &mut self,
77        op: &mut OpWrapper<O>,
78        state: &IterState<O>,
79    ) -> Result<ArgminIterData<O>, Error> {
80        let param = state.get_param();
81        let residuals = op.apply(&param)?;
82        let jacobian = op.jacobian(&param)?;
83        let jacobian_t = jacobian.clone().t();
84        let grad = jacobian_t.dot(&residuals);
85
86        let p = jacobian_t.dot(&jacobian).inv()?.dot(&grad);
87
88        self.linesearch
89            .set_search_direction(p.mul(&(F::from_f64(-1.0).unwrap())));
90
91        // create operator for linesearch
92        let line_op = OpWrapper::new(LineSearchOP {
93            op: op.take_op().unwrap(),
94        });
95
96        // perform linesearch
97        let ArgminResult {
98            operator: mut line_op,
99            state:
100                IterState {
101                    param: next_param,
102                    cost: next_cost,
103                    ..
104                },
105        } = Executor::new(line_op, self.linesearch.clone(), param)
106            .grad(grad)
107            .cost(residuals.norm())
108            .ctrlc(false)
109            .run()?;
110
111        // Here we cannot use `consume_op` because the operator we need is hidden inside a
112        // `LineSearchOP` hidden inside a `OpWrapper`. Therefore we have to split this in two
113        // separate tasks: first getting the operator, then dealing with the function counts.
114        op.op = Some(line_op.take_op().unwrap().op);
115        op.consume_func_counts(line_op);
116
117        Ok(ArgminIterData::new().param(next_param).cost(next_cost))
118    }
119
120    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
121        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
122            return TerminationReason::NoChangeInCost;
123        }
124        TerminationReason::NotTerminated
125    }
126}
127
128#[doc(hidden)]
129#[derive(Clone, Default, Serialize, Deserialize)]
130pub struct LineSearchOP<O> {
131    pub op: O,
132}
133
134impl<O> ArgminOp for LineSearchOP<O>
135where
136    O: ArgminOp,
137    O::Jacobian: ArgminTranspose + ArgminDot<O::Output, O::Param>,
138    O::Output: ArgminNorm<O::Float>,
139{
140    type Param = O::Param;
141    type Output = O::Float;
142    type Hessian = O::Hessian;
143    type Jacobian = O::Jacobian;
144    type Float = O::Float;
145
146    fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
147        Ok(self.op.apply(p)?.norm())
148    }
149
150    fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
151        Ok(self.op.jacobian(p)?.t().dot(&self.op.apply(p)?))
152    }
153
154    fn hessian(&self, p: &Self::Param) -> Result<Self::Hessian, Error> {
155        self.op.hessian(p)
156    }
157
158    fn jacobian(&self, p: &Self::Param) -> Result<Self::Jacobian, Error> {
159        self.op.jacobian(p)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    test_trait_impl!(
        gauss_newton_linesearch_method,
        GaussNewtonLS<MoreThuenteLineSearch<Vec<f64>, f64>, f64>
    );

    /// A valid tolerance is stored unchanged.
    #[test]
    fn test_tolerance() {
        let tol1: f64 = 1e-4;

        let linesearch: MoreThuenteLineSearch<Vec<f64>, f64> = MoreThuenteLineSearch::new();
        let GaussNewtonLS { tol: t1, .. } = GaussNewtonLS::new(linesearch).with_tol(tol1).unwrap();

        assert!((t1 - tol1).abs() < std::f64::EPSILON);
    }

    /// Zero and negative tolerances are rejected with an error.
    #[test]
    fn test_tolerance_must_be_positive() {
        for &tol in &[0.0f64, -1e-4] {
            let linesearch: MoreThuenteLineSearch<Vec<f64>, f64> = MoreThuenteLineSearch::new();
            assert!(GaussNewtonLS::new(linesearch).with_tol(tol).is_err());
        }
    }
}