argmin/solver/trustregion/
dogleg.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15
16/// The Dogleg method computes the intersection of the trust region boundary with a path given by
17/// the unconstraind minimum along the steepest descent direction and the optimum of the quadratic
18/// approximation of the cost function at the current point.
19///
20/// # References:
21///
22/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
23/// Springer. ISBN 0-387-30303-0.
24#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
25pub struct Dogleg<F> {
26    /// Radius
27    radius: F,
28}
29
30impl<F: ArgminFloat> Dogleg<F> {
31    /// Constructor
32    pub fn new() -> Self {
33        Dogleg { radius: F::nan() }
34    }
35}
36
37impl<O, F> Solver<O> for Dogleg<F>
38where
39    O: ArgminOp<Output = F, Float = F>,
40    O::Param: std::fmt::Debug
41        + ArgminMul<F, O::Param>
42        + ArgminWeightedDot<O::Param, O::Float, O::Hessian>
43        + ArgminNorm<F>
44        + ArgminDot<O::Param, O::Float>
45        + ArgminAdd<O::Param, O::Param>
46        + ArgminSub<O::Param, O::Param>,
47    O::Hessian: ArgminInv<O::Hessian> + ArgminDot<O::Param, O::Param>,
48    F: ArgminFloat,
49{
50    const NAME: &'static str = "Dogleg";
51
52    fn next_iter(
53        &mut self,
54        op: &mut OpWrapper<O>,
55        state: &IterState<O>,
56    ) -> Result<ArgminIterData<O>, Error> {
57        let param = state.get_param();
58        let g = state
59            .get_grad()
60            .unwrap_or_else(|| op.gradient(&param).unwrap());
61        let h = state
62            .get_hessian()
63            .unwrap_or_else(|| op.hessian(&param).unwrap());
64        let pstar;
65
66        // pb = -H^-1g
67        let pb = (h.inv()?).dot(&g).mul(&F::from_f64(-1.0).unwrap());
68
69        if pb.norm() <= self.radius {
70            pstar = pb;
71        } else {
72            // pu = - (g^Tg)/(g^THg) * g
73            let pu = g.mul(&(-g.dot(&g) / g.weighted_dot(&h, &g)));
74            // println!("pb: {:?}, pu: {:?}", pb, pu);
75
76            let utu = pu.dot(&pu);
77            let btb = pb.dot(&pb);
78            let utb = pu.dot(&pb);
79
80            // compute tau
81            let delta = self.radius.powi(2);
82            let t1 = F::from_f64(3.0).unwrap() * utb - btb - F::from_f64(2.0).unwrap() * utu;
83            let t2 = (utb.powi(2) - F::from_f64(2.0).unwrap() * utb * delta + delta * btb
84                - btb * utu
85                + delta * utu)
86                .sqrt();
87            let t3 = F::from_f64(-2.0).unwrap() * utb + btb + utu;
88            let tau1: F = -(t1 + t2) / t3;
89            let tau2: F = -(t1 - t2) / t3;
90
91            // pick maximum value of both -- not sure if this is the proper way
92            let mut tau = tau1.max(tau2);
93
94            // if calculation failed because t3 is too small, use the third option
95            if tau.is_nan() || tau.is_infinite() {
96                tau = (delta + btb - F::from_f64(2.0).unwrap() * utu) / (btb - utu);
97            }
98
99            if tau >= F::from_f64(0.0).unwrap() && tau < F::from_f64(1.0).unwrap() {
100                pstar = pu.mul(&tau);
101            } else if tau >= F::from_f64(1.0).unwrap() && tau <= F::from_f64(2.0).unwrap() {
102                pstar = pu.add(&pb.sub(&pu).mul(&(tau - F::from_f64(1.0).unwrap())));
103            } else {
104                return Err(ArgminError::ImpossibleError {
105                    text: "tau is bigger than 2, this is not supposed to happen.".to_string(),
106                }
107                .into());
108            }
109        }
110        let out = ArgminIterData::new().param(pstar);
111        Ok(out)
112    }
113
114    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
115        if state.get_iter() >= 1 {
116            TerminationReason::MaxItersReached
117        } else {
118            TerminationReason::NotTerminated
119        }
120    }
121}
122
123impl<F: ArgminFloat> ArgminTrustRegion<F> for Dogleg<F> {
124    fn set_radius(&mut self, radius: F) {
125        self.radius = radius;
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiates the crate-level `test_trait_impl!` macro for `Dogleg<f64>` under the
    // name `dogleg`. NOTE(review): what exactly the macro checks is defined elsewhere in
    // the crate.
    crate::test_trait_impl!(dogleg, Dogleg<f64>);
}