// argmin/solver/newton/newton_cg.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! TODO: Stop when search direction is close to 0
9//!
10//! # References:
11//!
12//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
13//! Springer. ISBN 0-387-30303-0.
14
15use crate::prelude::*;
16use crate::solver::conjugategradient::ConjugateGradient;
17use serde::{Deserialize, Serialize};
18
/// The Newton-CG method (also called truncated Newton method) uses a modified CG to solve the
/// Newton equations approximately. After a search direction is found, a line search is performed.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/newton_cg.rs)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct NewtonCG<L, F> {
    /// Line search used to determine the step length along the CG-computed search direction
    linesearch: L,
    /// The inner CG iteration is truncated when the curvature `p^T H p` drops to or below
    /// this threshold (defaults to 0)
    curvature_threshold: F,
    /// Tolerance for the stopping criterion based on cost difference (defaults to `F::epsilon()`)
    tol: F,
}
37
38impl<L, F: ArgminFloat> NewtonCG<L, F> {
39    /// Constructor
40    pub fn new(linesearch: L) -> Self {
41        NewtonCG {
42            linesearch,
43            curvature_threshold: F::from_f64(0.0).unwrap(),
44            tol: F::epsilon(),
45        }
46    }
47
48    /// Set curvature threshold
49    pub fn curvature_threshold(mut self, threshold: F) -> Self {
50        self.curvature_threshold = threshold;
51        self
52    }
53
54    /// Set tolerance for the stopping criterion based on cost difference
55    pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
56        if tol <= F::from_f64(0.0).unwrap() {
57            return Err(ArgminError::InvalidParameter {
58                text: "Newton-CG: tol must be positive.".to_string(),
59            }
60            .into());
61        }
62        self.tol = tol;
63        Ok(self)
64    }
65}
66
67impl<O, L, F> Solver<O> for NewtonCG<L, F>
68where
69    O: ArgminOp<Output = F, Float = F>,
70    O::Param: Send
71        + Sync
72        + Clone
73        + Serialize
74        + Default
75        + ArgminSub<O::Param, O::Param>
76        + ArgminAdd<O::Param, O::Param>
77        + ArgminDot<O::Param, O::Float>
78        + ArgminScaledAdd<O::Param, O::Float, O::Param>
79        + ArgminMul<F, O::Param>
80        + ArgminConj
81        + ArgminZeroLike
82        + ArgminNorm<O::Float>,
83    O::Hessian: Send
84        + Sync
85        + Default
86        + Clone
87        + Serialize
88        + Default
89        + ArgminInv<O::Hessian>
90        + ArgminDot<O::Param, O::Param>,
91    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
92    F: ArgminFloat + Default + ArgminDiv<O::Float, O::Float> + ArgminNorm<O::Float> + ArgminConj,
93{
94    const NAME: &'static str = "Newton-CG";
95
96    fn next_iter(
97        &mut self,
98        op: &mut OpWrapper<O>,
99        state: &IterState<O>,
100    ) -> Result<ArgminIterData<O>, Error> {
101        let param = state.get_param();
102        let grad = op.gradient(&param)?;
103        let hessian = op.hessian(&param)?;
104
105        // Solve CG subproblem
106        let cg_op: CGSubProblem<O::Param, O::Hessian, O::Float> =
107            CGSubProblem::new(hessian.clone());
108        let mut cg_op = OpWrapper::new(cg_op);
109
110        let mut x_p = param.zero_like();
111        let mut x: O::Param = param.zero_like();
112        let mut cg = ConjugateGradient::new(grad.mul(&(F::from_f64(-1.0).unwrap())))?;
113
114        let mut cg_state = IterState::new(x_p.clone());
115        cg.init(&mut cg_op, &cg_state)?;
116        let grad_norm = grad.norm();
117        for iter in 0.. {
118            let data = cg.next_iter(&mut cg_op, &cg_state)?;
119            x = data.get_param().unwrap();
120            let p = cg.p_prev();
121            let curvature = p.dot(&hessian.dot(&p));
122            if curvature <= self.curvature_threshold {
123                if iter == 0 {
124                    x = grad.mul(&(F::from_f64(-1.0).unwrap()));
125                    break;
126                } else {
127                    x = x_p;
128                    break;
129                }
130            }
131            if data.get_cost().unwrap()
132                <= F::from_f64(0.5).unwrap().min(grad_norm.sqrt()) * grad_norm
133            {
134                break;
135            }
136            cg_state.param(x.clone());
137            cg_state.cost(data.get_cost().unwrap());
138            x_p = x.clone();
139        }
140
141        // perform line search
142        self.linesearch.set_search_direction(x);
143
144        // Run solver
145        let ArgminResult {
146            operator: line_op,
147            state:
148                IterState {
149                    param: next_param,
150                    cost: next_cost,
151                    ..
152                },
153        } = Executor::new(
154            OpWrapper::new_from_wrapper(op),
155            self.linesearch.clone(),
156            param,
157        )
158        .grad(grad)
159        .cost(state.get_cost())
160        .ctrlc(false)
161        .run()?;
162
163        op.consume_op(line_op);
164
165        Ok(ArgminIterData::new().param(next_param).cost(next_cost))
166    }
167
168    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
169        if (state.get_cost() - state.get_prev_cost()).abs() < self.tol {
170            TerminationReason::NoChangeInCost
171        } else {
172            TerminationReason::NotTerminated
173        }
174    }
175}
176
/// Operator for the CG subproblem: applying it to a vector `p` yields the
/// Hessian-vector product `H p`.
#[derive(Clone, Default, Serialize, Deserialize)]
struct CGSubProblem<T, H, F> {
    /// Hessian supplied by the outer Newton-CG iteration
    hessian: H,
    /// Marker tying the parameter (vector) type `T` to this operator
    phantom: std::marker::PhantomData<T>,
    /// Marker tying the float type `F` to this operator
    float: std::marker::PhantomData<F>,
}
183
184impl<T, H, F> CGSubProblem<T, H, F>
185where
186    T: Clone + Send + Sync,
187    H: Clone + Default + ArgminDot<T, T> + Send + Sync,
188{
189    /// constructor
190    pub fn new(hessian: H) -> Self {
191        CGSubProblem {
192            hessian,
193            phantom: std::marker::PhantomData,
194            float: std::marker::PhantomData,
195        }
196    }
197}
198
199impl<T, H, F> ArgminOp for CGSubProblem<T, H, F>
200where
201    T: Clone + Default + Send + Sync + Serialize + serde::de::DeserializeOwned,
202    H: Clone + Default + ArgminDot<T, T> + Send + Sync + Serialize + serde::de::DeserializeOwned,
203    F: ArgminFloat,
204{
205    type Param = T;
206    type Output = T;
207    type Hessian = ();
208    type Jacobian = ();
209    type Float = F;
210
211    fn apply(&self, p: &T) -> Result<T, Error> {
212        Ok(self.hessian.dot(&p))
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    test_trait_impl!(
        newton_cg,
        NewtonCG<MoreThuenteLineSearch<Vec<f64>, f64>, f64>
    );

    test_trait_impl!(cg_subproblem, CGSubProblem<Vec<f64>, Vec<Vec<f64>>, f64>);

    #[test]
    fn test_tolerance() {
        // `with_tol` must store the supplied tolerance unchanged.
        let tol: f64 = 1e-4;

        let linesearch: MoreThuenteLineSearch<Vec<f64>, f64> = MoreThuenteLineSearch::new();

        let solver: NewtonCG<MoreThuenteLineSearch<Vec<f64>, f64>, f64> =
            NewtonCG::new(linesearch).with_tol(tol).unwrap();

        assert!((solver.tol - tol).abs() < std::f64::EPSILON);
    }
}