argmin/solver/trustregion/
dogleg.rs1use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15
16#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
25pub struct Dogleg<F> {
26 radius: F,
28}
29
30impl<F: ArgminFloat> Dogleg<F> {
31 pub fn new() -> Self {
33 Dogleg { radius: F::nan() }
34 }
35}
36
37impl<O, F> Solver<O> for Dogleg<F>
38where
39 O: ArgminOp<Output = F, Float = F>,
40 O::Param: std::fmt::Debug
41 + ArgminMul<F, O::Param>
42 + ArgminWeightedDot<O::Param, O::Float, O::Hessian>
43 + ArgminNorm<F>
44 + ArgminDot<O::Param, O::Float>
45 + ArgminAdd<O::Param, O::Param>
46 + ArgminSub<O::Param, O::Param>,
47 O::Hessian: ArgminInv<O::Hessian> + ArgminDot<O::Param, O::Param>,
48 F: ArgminFloat,
49{
50 const NAME: &'static str = "Dogleg";
51
52 fn next_iter(
53 &mut self,
54 op: &mut OpWrapper<O>,
55 state: &IterState<O>,
56 ) -> Result<ArgminIterData<O>, Error> {
57 let param = state.get_param();
58 let g = state
59 .get_grad()
60 .unwrap_or_else(|| op.gradient(¶m).unwrap());
61 let h = state
62 .get_hessian()
63 .unwrap_or_else(|| op.hessian(¶m).unwrap());
64 let pstar;
65
66 let pb = (h.inv()?).dot(&g).mul(&F::from_f64(-1.0).unwrap());
68
69 if pb.norm() <= self.radius {
70 pstar = pb;
71 } else {
72 let pu = g.mul(&(-g.dot(&g) / g.weighted_dot(&h, &g)));
74 let utu = pu.dot(&pu);
77 let btb = pb.dot(&pb);
78 let utb = pu.dot(&pb);
79
80 let delta = self.radius.powi(2);
82 let t1 = F::from_f64(3.0).unwrap() * utb - btb - F::from_f64(2.0).unwrap() * utu;
83 let t2 = (utb.powi(2) - F::from_f64(2.0).unwrap() * utb * delta + delta * btb
84 - btb * utu
85 + delta * utu)
86 .sqrt();
87 let t3 = F::from_f64(-2.0).unwrap() * utb + btb + utu;
88 let tau1: F = -(t1 + t2) / t3;
89 let tau2: F = -(t1 - t2) / t3;
90
91 let mut tau = tau1.max(tau2);
93
94 if tau.is_nan() || tau.is_infinite() {
96 tau = (delta + btb - F::from_f64(2.0).unwrap() * utu) / (btb - utu);
97 }
98
99 if tau >= F::from_f64(0.0).unwrap() && tau < F::from_f64(1.0).unwrap() {
100 pstar = pu.mul(&tau);
101 } else if tau >= F::from_f64(1.0).unwrap() && tau <= F::from_f64(2.0).unwrap() {
102 pstar = pu.add(&pb.sub(&pu).mul(&(tau - F::from_f64(1.0).unwrap())));
103 } else {
104 return Err(ArgminError::ImpossibleError {
105 text: "tau is bigger than 2, this is not supposed to happen.".to_string(),
106 }
107 .into());
108 }
109 }
110 let out = ArgminIterData::new().param(pstar);
111 Ok(out)
112 }
113
114 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
115 if state.get_iter() >= 1 {
116 TerminationReason::MaxItersReached
117 } else {
118 TerminationReason::NotTerminated
119 }
120 }
121}
122
123impl<F: ArgminFloat> ArgminTrustRegion<F> for Dogleg<F> {
124 fn set_radius(&mut self, radius: F) {
125 self.radius = radius;
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // Generates the crate's shared trait-implementation test for
    // `Dogleg<f64>`; the checks themselves live in the macro definition.
    test_trait_impl!(dogleg, Dogleg<f64>);
}