argmin/solver/trustregion/
cauchypoint.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
17/// The Cauchy point is the minimum of the quadratic approximation of the cost function within the
18/// trust region along the direction given by the first derivative.
19///
20/// # References:
21///
22/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
23/// Springer. ISBN 0-387-30303-0.
24#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
25pub struct CauchyPoint<F> {
26    /// Radius
27    radius: F,
28}
29
30impl<F: ArgminFloat> CauchyPoint<F> {
31    /// Constructor
32    pub fn new() -> Self {
33        CauchyPoint { radius: F::nan() }
34    }
35}
36
37impl<O, F> Solver<O> for CauchyPoint<F>
38where
39    O: ArgminOp<Output = F, Float = F>,
40    O::Param: Debug
41        + Clone
42        + Serialize
43        + ArgminMul<O::Float, O::Param>
44        + ArgminWeightedDot<O::Param, F, O::Hessian>
45        + ArgminNorm<O::Float>,
46    O::Hessian: Clone + Serialize,
47    F: ArgminFloat,
48{
49    const NAME: &'static str = "Cauchy Point";
50
51    fn next_iter(
52        &mut self,
53        op: &mut OpWrapper<O>,
54        state: &IterState<O>,
55    ) -> Result<ArgminIterData<O>, Error> {
56        let param = state.get_param();
57        let grad = state
58            .get_grad()
59            .unwrap_or_else(|| op.gradient(&param).unwrap());
60        let grad_norm = grad.norm();
61        let hessian = state
62            .get_hessian()
63            .unwrap_or_else(|| op.hessian(&param).unwrap());
64
65        let wdp = grad.weighted_dot(&hessian, &grad);
66        let tau: F = if wdp <= F::from_f64(0.0).unwrap() {
67            F::from_f64(1.0).unwrap()
68        } else {
69            F::from_f64(1.0)
70                .unwrap()
71                .min(grad_norm.powi(3) / (self.radius * wdp))
72        };
73
74        let new_param = grad.mul(&(-tau * self.radius / grad_norm));
75        Ok(ArgminIterData::new().param(new_param))
76    }
77
78    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
79        if state.get_iter() >= 1 {
80            TerminationReason::MaxItersReached
81        } else {
82            TerminationReason::NotTerminated
83        }
84    }
85}
86
87impl<F: ArgminFloat> ArgminTrustRegion<F> for CauchyPoint<F> {
88    fn set_radius(&mut self, radius: F) {
89        self.radius = radius;
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiates the crate's shared trait-implementation tests for `CauchyPoint<f64>`.
    crate::test_trait_impl!(cauchypoint, CauchyPoint<f64>);
}