argmin/solver/trustregion/
cauchypoint.rs1use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
17#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
25pub struct CauchyPoint<F> {
26 radius: F,
28}
29
30impl<F: ArgminFloat> CauchyPoint<F> {
31 pub fn new() -> Self {
33 CauchyPoint { radius: F::nan() }
34 }
35}
36
37impl<O, F> Solver<O> for CauchyPoint<F>
38where
39 O: ArgminOp<Output = F, Float = F>,
40 O::Param: Debug
41 + Clone
42 + Serialize
43 + ArgminMul<O::Float, O::Param>
44 + ArgminWeightedDot<O::Param, F, O::Hessian>
45 + ArgminNorm<O::Float>,
46 O::Hessian: Clone + Serialize,
47 F: ArgminFloat,
48{
49 const NAME: &'static str = "Cauchy Point";
50
51 fn next_iter(
52 &mut self,
53 op: &mut OpWrapper<O>,
54 state: &IterState<O>,
55 ) -> Result<ArgminIterData<O>, Error> {
56 let param = state.get_param();
57 let grad = state
58 .get_grad()
59 .unwrap_or_else(|| op.gradient(¶m).unwrap());
60 let grad_norm = grad.norm();
61 let hessian = state
62 .get_hessian()
63 .unwrap_or_else(|| op.hessian(¶m).unwrap());
64
65 let wdp = grad.weighted_dot(&hessian, &grad);
66 let tau: F = if wdp <= F::from_f64(0.0).unwrap() {
67 F::from_f64(1.0).unwrap()
68 } else {
69 F::from_f64(1.0)
70 .unwrap()
71 .min(grad_norm.powi(3) / (self.radius * wdp))
72 };
73
74 let new_param = grad.mul(&(-tau * self.radius / grad_norm));
75 Ok(ArgminIterData::new().param(new_param))
76 }
77
78 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
79 if state.get_iter() >= 1 {
80 TerminationReason::MaxItersReached
81 } else {
82 TerminationReason::NotTerminated
83 }
84 }
85}
86
87impl<F: ArgminFloat> ArgminTrustRegion<F> for CauchyPoint<F> {
88 fn set_radius(&mut self, radius: F) {
89 self.radius = radius;
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    // Generates the crate's standard trait-implementation tests for `CauchyPoint<f64>`.
    // NOTE(review): the exact set of traits checked is defined by the macro elsewhere.
    crate::test_trait_impl!(cauchypoint, CauchyPoint<f64>);
}