argmin/solver/conjugategradient/nonlinear_cg.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Important TODO: Find out which line search should be the default choice. Also try to replicate
//! CG_DESCENT.
//!
//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::default::Default;

/// The nonlinear conjugate gradient is a generalization of the conjugate gradient method for
/// nonlinear optimization problems.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/nonlinear_cg.rs)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
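///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest): `MyProblem` and `init_param` are
/// placeholders for a user-defined `ArgminOp` implementation and an initial parameter vector,
/// the import paths assume the re-exports of the `solver::conjugategradient` module, and the
/// chosen line search and beta update method are just one possible combination.
///
/// ```ignore
/// use argmin::prelude::*;
/// use argmin::solver::conjugategradient::beta::PolakRibiere;
/// use argmin::solver::conjugategradient::NonlinearConjugateGradient;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
///
/// // Line search used to determine the step length along the search direction
/// let linesearch = MoreThuenteLineSearch::new();
/// // Update method for beta (here: Polak-Ribiere)
/// let beta_method = PolakRibiere::new();
///
/// let solver = NonlinearConjugateGradient::new(linesearch, beta_method)?
///     .restart_iters(10)
///     .restart_orthogonality(0.1);
///
/// let res = Executor::new(MyProblem {}, solver, init_param)
///     .max_iters(20)
///     .run()?;
/// ```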
#[derive(Clone, Serialize, Deserialize)]
pub struct NonlinearConjugateGradient<P, L, B, F> {
    /// Search direction p
    p: P,
    /// beta (factor used in the search direction update)
    beta: F,
    /// Line search
    linesearch: L,
    /// Beta update method
    beta_method: B,
    /// Number of iterations after which a restart is performed
    restart_iter: u64,
    /// Restart threshold based on the orthogonality of consecutive gradients
    restart_orthogonality: Option<F>,
}

impl<P, L, B, F> NonlinearConjugateGradient<P, L, B, F>
where
    P: Default,
    F: ArgminFloat,
{
    /// Constructor. The `linesearch` is used to compute the step length and `beta_method`
    /// defines how beta is updated in each iteration (for instance Polak-Ribiere).
    pub fn new(linesearch: L, beta_method: B) -> Result<Self, Error> {
        Ok(NonlinearConjugateGradient {
            p: P::default(),
            beta: F::nan(),
            linesearch,
            beta_method,
            restart_iter: std::u64::MAX,
            restart_orthogonality: None,
        })
    }

    /// Specify the number of iterations after which a restart should be performed.
    /// This allows the algorithm to "forget" previous information which may not be helpful
    /// anymore.
    pub fn restart_iters(mut self, iters: u64) -> Self {
        self.restart_iter = iters;
        self
    }

    /// Set the value for the orthogonality measure.
    /// Setting this parameter leads to a restart of the algorithm (setting beta = 0) whenever
    /// two consecutive gradients are no longer sufficiently orthogonal, i.e. whenever this
    /// condition is met:
    ///
    /// `|\nabla f_k^T * \nabla f_{k-1}| / ||\nabla f_k||^2 >= v`
    ///
    /// A typical value for `v` is 0.1.
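    ///
    /// A minimal sketch of enabling this restart strategy; `linesearch` and `beta_method` are
    /// placeholders for an already constructed line search and beta update method:
    ///
    /// ```ignore
    /// let solver = NonlinearConjugateGradient::new(linesearch, beta_method)?
    ///     .restart_orthogonality(0.1);
    /// ```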
    pub fn restart_orthogonality(mut self, v: F) -> Self {
        self.restart_orthogonality = Some(v);
        self
    }
}

impl<O, P, L, B, F> Solver<O> for NonlinearConjugateGradient<P, L, B, F>
where
    O: ArgminOp<Param = P, Output = F, Float = F>,
    P: Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminScaledAdd<O::Param, O::Float, O::Param>
        + ArgminAdd<O::Param, O::Param>
        + ArgminMul<F, O::Param>
        + ArgminNorm<O::Float>,
    O::Hessian: Default,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    B: ArgminNLCGBetaUpdate<O::Param, O::Float>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Nonlinear Conjugate Gradient";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
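        // The initial search direction is the steepest descent direction p_0 = -\nabla f(x_0)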
        self.p = grad.mul(&(F::from_f64(-1.0).unwrap()));
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let xk = state.get_param();
        let grad = if let Some(grad) = state.get_grad() {
            grad
        } else {
            op.gradient(&xk)?
        };
        let cur_cost = state.get_cost();

        // Line search along the current search direction p
        self.linesearch.set_search_direction(self.p.clone());

        // Run the line search as a nested solver to obtain the next iterate
        let ArgminResult {
            operator: line_op,
            state: line_state,
        } = Executor::new(OpWrapper::new_from_wrapper(op), self.linesearch.clone(), xk)
            .grad(grad.clone())
            .cost(cur_cost)
            .ctrlc(false)
            .run()?;

        // Take over the function evaluation counts of the line search run
        op.consume_op(line_op);

        let xk1 = line_state.get_param();

        // Compute the gradient at the new iterate for the update of beta
        let new_grad = op.gradient(&xk1)?;

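        // Restart the search direction (beta = 0) if consecutive gradients are no longer
        // sufficiently orthogonal or every `restart_iter` iterations; otherwise compute beta
        // with the configured update method.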
        let restart_orthogonality = match self.restart_orthogonality {
            Some(v) => new_grad.dot(&grad).abs() / new_grad.norm().powi(2) >= v,
            None => false,
        };

        let restart_iter: bool =
            (state.get_iter() % self.restart_iter == 0) && state.get_iter() != 0;

        if restart_iter || restart_orthogonality {
            self.beta = F::from_f64(0.0).unwrap();
        } else {
            self.beta = self.beta_method.update(&grad, &new_grad, &self.p);
        }

        // Update of p
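        // p_{k+1} = -\nabla f(x_{k+1}) + beta * p_k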
        self.p = new_grad
            .mul(&(F::from_f64(-1.0).unwrap()))
            .add(&self.p.mul(&self.beta));

        // Compute the cost at the new iterate
        let cost = op.apply(&xk1)?;

        Ok(ArgminIterData::new()
            .param(xk1)
            .cost(cost)
            .grad(new_grad)
            .kv(make_kv!("beta" => self.beta;
             "restart_iter" => restart_iter;
             "restart_orthogonality" => restart_orthogonality;
            )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::solver::conjugategradient::beta::PolakRibiere;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    test_trait_impl!(
        nonlinear_cg,
        NonlinearConjugateGradient<
            MinimalNoOperator,
            MoreThuenteLineSearch<MinimalNoOperator, f64>,
            PolakRibiere,
            f64
        >
    );
}