argmin/solver/quasinewton/dfp.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

/// DFP (Davidon-Fletcher-Powell) method
///
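/// DFP is a quasi-Newton method that maintains an approximation `H_k` of the *inverse* Hessian.
/// Each iteration performs a line search along the direction `p_k = -H_k g_k` (with `g_k` the
/// current gradient) and then updates the approximation with
///
/// `H_{k+1} = H_k - (H_k y_k)(H_k y_k)^T / (y_k^T H_k y_k) + (s_k s_k^T) / (y_k^T s_k)`,
///
/// where `s_k = x_{k+1} - x_k` and `y_k = g_{k+1} - g_k`.
///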
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/dfp.rs)
///
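/// # Example (sketch)
///
/// A minimal sketch of how `DFP` might be set up and run through an `Executor`, assuming the
/// `ndarray` math backend and a hypothetical 2D Rosenbrock problem type; see the linked example
/// above for a complete, tested version.
///
/// ```ignore
/// use argmin::prelude::*;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
/// use argmin::solver::quasinewton::DFP;
/// use ndarray::{array, Array1, Array2};
/// use serde::{Deserialize, Serialize};
///
/// // Hypothetical problem: 2D Rosenbrock function with a = 1, b = 100.
/// #[derive(Clone, Default, Serialize, Deserialize)]
/// struct Rosenbrock {}
///
/// impl ArgminOp for Rosenbrock {
///     type Param = Array1<f64>;
///     type Output = f64;
///     type Hessian = Array2<f64>;
///     type Jacobian = ();
///     type Float = f64;
///
///     fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok((1.0 - p[0]).powi(2) + 100.0 * (p[1] - p[0].powi(2)).powi(2))
///     }
///
///     fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
///         Ok(array![
///             -2.0 * (1.0 - p[0]) - 400.0 * p[0] * (p[1] - p[0].powi(2)),
///             200.0 * (p[1] - p[0].powi(2)),
///         ])
///     }
/// }
///
/// fn run() -> Result<(), Error> {
///     let linesearch: MoreThuenteLineSearch<Array1<f64>, f64> = MoreThuenteLineSearch::new();
///     // Identity matrix as the initial inverse Hessian estimate.
///     let init_inv_hessian: Array2<f64> = Array2::eye(2);
///     let solver = DFP::new(init_inv_hessian, linesearch);
///     let res = Executor::new(Rosenbrock {}, solver, array![-1.2, 1.0])
///         .max_iters(100)
///         .run()?;
///     println!("{}", res);
///     Ok(())
/// }
/// ```
///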
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct DFP<L, H, F> {
    /// Inverse Hessian
    inv_hessian: H,
    /// Line search
    linesearch: L,
    /// Tolerance for the stopping criterion based on the norm of the gradient
    tol_grad: F,
}

impl<L, H, F: ArgminFloat> DFP<L, H, F> {
    /// Constructor
    pub fn new(init_inverse_hessian: H, linesearch: L) -> Self {
        DFP {
            inv_hessian: init_inverse_hessian,
            linesearch,
            tol_grad: F::epsilon().sqrt(),
        }
    }

    /// Sets the tolerance for the stopping criterion based on the norm of the gradient
    pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
        self.tol_grad = tol_grad;
        self
    }
}

impl<O, L, H, F> Solver<O> for DFP<L, H, F>
where
    O: ArgminOp<Output = F, Hessian = H, Float = F>,
    O::Param: Clone
        + Default
        + Serialize
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminDot<O::Param, O::Hessian>
        + ArgminScaledAdd<O::Param, F, O::Param>
        + ArgminNorm<O::Float>
        + ArgminMul<O::Float, O::Param>
        + ArgminTranspose,
    O::Hessian: Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Hessian, O::Hessian>
        + ArgminDot<O::Param, O::Param>
        + ArgminDot<O::Hessian, O::Hessian>
        + ArgminAdd<O::Hessian, O::Hessian>
        + ArgminMul<F, O::Hessian>
        + ArgminTranspose
        + ArgminEye,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "DFP";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let cost = state.get_cost();
        let prev_grad = if let Some(grad) = state.get_grad() {
            grad
        } else {
            op.gradient(&param)?
        };
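        // Search direction: p = -H_k * g_k, where H_k is the current inverse Hessian estimate
        // and g_k the current gradient.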
        let p = self
            .inv_hessian
            .dot(&prev_grad)
            .mul(&F::from_f64(-1.0).unwrap());

        self.linesearch.set_search_direction(p);

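        // Run the line search along p in an inner Executor to obtain the next iterate `xk1`
        // and its cost.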
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: xk1,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param.clone(),
        )
        .grad(prev_grad.clone())
        .cost(cost)
        .ctrlc(false)
        .run()?;

        // take care of function eval counts
        op.consume_op(line_op);

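        // Gradient difference y_k = g_{k+1} - g_k and step s_k = x_{k+1} - x_k for the update.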
        let grad = op.gradient(&xk1)?;
        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

        let yksk: F = yk.dot(&sk);

        let sksk: O::Hessian = sk.dot(&sk);

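        // First correction term: (H_k y_k)(H_k y_k)^T / (y_k^T H_k y_k).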
        let tmp3: O::Param = self.inv_hessian.dot(&yk);
        let tmp4: F = tmp3.dot(&yk);
        let tmp3: O::Hessian = tmp3.dot(&tmp3);
        let tmp3: O::Hessian = tmp3.mul(&(F::from_f64(1.0).unwrap() / tmp4));

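        // DFP update of the inverse Hessian:
        // H_{k+1} = H_k - (H_k y_k)(H_k y_k)^T / (y_k^T H_k y_k) + (s_k s_k^T) / (y_k^T s_k)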
        self.inv_hessian = self
            .inv_hessian
            .sub(&tmp3)
            .add(&sksk.mul(&(F::from_f64(1.0).unwrap() / yksk)));

        Ok(ArgminIterData::new().param(xk1).cost(next_cost).grad(grad))
    }

    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        if state.get_grad().unwrap().norm() < self.tol_grad {
            return TerminationReason::TargetPrecisionReached;
        }
        TerminationReason::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    type Operator = MinimalNoOperator;

    test_trait_impl!(dfp, DFP<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);

    #[test]
    fn test_tolerances() {
        let linesearch: MoreThuenteLineSearch<f64, f64> =
            MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
        let init_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let tol: f64 = 1e-4;

        let DFP { tol_grad: t, .. } = DFP::new(init_hessian, linesearch).with_tol_grad(tol);

        assert!((t - tol).abs() < std::f64::EPSILON);
    }
}