// argmin/solver/trustregion/dogleg.rs
use crate::prelude::*;
use serde::{Deserialize, Serialize};
/// Dogleg method for the trust region subproblem.
///
/// Each call to `next_iter` combines a steepest-descent step and a Newton step
/// (`-H^{-1} g`) into a single step bounded by the trust region radius.
///
/// NOTE(review): `new()` initialises `radius` to NaN, whereas the derived `Default`
/// initialises it to `F::default()` (zero for `f32`/`f64`). Confirm which is intended;
/// in either case the radius is expected to be supplied via `set_radius` before use.
#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
pub struct Dogleg<F> {
// Trust region radius; overwritten by `ArgminTrustRegion::set_radius`.
radius: F,
}
impl<F: ArgminFloat> Dogleg<F> {
    /// Constructs a `Dogleg` solver.
    ///
    /// The radius starts out as NaN and has to be provided through
    /// `ArgminTrustRegion::set_radius` before a step is computed.
    pub fn new() -> Self {
        let radius = F::nan();
        Self { radius }
    }
}
impl<O, F> Solver<O> for Dogleg<F>
where
    O: ArgminOp<Output = F, Float = F>,
    O::Param: std::fmt::Debug
        + ArgminMul<F, O::Param>
        + ArgminWeightedDot<O::Param, O::Float, O::Hessian>
        + ArgminNorm<F>
        + ArgminDot<O::Param, O::Float>
        + ArgminAdd<O::Param, O::Param>
        + ArgminSub<O::Param, O::Param>,
    O::Hessian: ArgminInv<O::Hessian> + ArgminDot<O::Param, O::Param>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Dogleg";

    /// Computes one dogleg step for the trust region subproblem
    /// `min_p g^T p + 1/2 p^T H p` subject to `||p|| <= radius`.
    ///
    /// The dogleg path is `tau * pu` for `tau` in `[0, 1]` and
    /// `pu + (tau - 1) * (pb - pu)` for `tau` in `[1, 2]`, where
    /// `pb = -H^{-1} g` (Newton step) and `pu = -(g^T g)/(g^T H g) * g`.
    ///
    /// * If `pb` lies inside the trust region, `pb` is returned.
    /// * Else, if `||pu|| >= radius`, `pu` is scaled onto the boundary.
    /// * Otherwise, the point where the segment from `pu` to `pb` crosses the boundary
    ///   is returned.
    ///
    /// The gradient and Hessian are taken from `state` when available and otherwise
    /// computed via `op`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `op.gradient`, `op.hessian` and the Hessian inversion.
    /// Returns `ArgminError::ImpossibleError` if no valid `tau` can be determined
    /// (for example when the radius is NaN because it was never set).
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        // Reuse cached derivatives if present; otherwise compute them and propagate
        // failures to the caller instead of panicking.
        let g = match state.get_grad() {
            Some(grad) => grad,
            None => op.gradient(&param)?,
        };
        let h = match state.get_hessian() {
            Some(hessian) => hessian,
            None => op.hessian(&param)?,
        };

        let zero = F::from_f64(0.0).unwrap();
        let one = F::from_f64(1.0).unwrap();
        let two = F::from_f64(2.0).unwrap();

        // Full Newton step: pb = -H^{-1} g
        let pb = (h.inv()?).dot(&g).mul(&F::from_f64(-1.0).unwrap());

        let pstar = if pb.norm() <= self.radius {
            pb
        } else {
            // Minimiser of the quadratic model along the steepest-descent direction.
            let pu = g.mul(&(-g.dot(&g) / g.weighted_dot(&h, &g)));
            let pu_norm = pu.norm();

            if pu_norm >= self.radius {
                // First leg (tau in [0, 1]): the boundary is hit before reaching pu,
                // so the step is pu scaled to length `radius`.
                pu.mul(&(self.radius / pu_norm))
            } else {
                // Second leg (tau in [1, 2]): solve
                // ||pu + (tau - 1)(pb - pu)||^2 = radius^2 for tau.
                let utu = pu.dot(&pu);
                let btb = pb.dot(&pb);
                let utb = pu.dot(&pb);
                let delta = self.radius.powi(2);
                let t1 = F::from_f64(3.0).unwrap() * utb - btb - two * utu;
                let t2 = (utb.powi(2) - two * utb * delta + delta * btb - btb * utu
                    + delta * utu)
                    .sqrt();
                // t3 = ||pb - pu||^2, the quadratic coefficient.
                let t3 = F::from_f64(-2.0).unwrap() * utb + btb + utu;
                let tau1: F = -(t1 + t2) / t3;
                let tau2: F = -(t1 - t2) / t3;
                // With t3 > 0 and ||pu|| < radius the two roots lie on opposite sides of
                // tau = 1, so the larger one is the root on the path.
                let mut tau = tau1.max(tau2);
                // If t3 vanished, use the solution of the remaining linear equation.
                if tau.is_nan() || tau.is_infinite() {
                    tau = (delta + btb - two * utu) / (btb - utu);
                }

                if tau >= zero && tau < one {
                    // Expected only through rounding when ||pu|| is very close to radius.
                    pu.mul(&tau)
                } else if tau >= one && tau <= two {
                    pu.add(&pb.sub(&pu).mul(&(tau - one)))
                } else {
                    return Err(ArgminError::ImpossibleError {
                        text: "tau is bigger than 2, this is not supposed to happen.".to_string(),
                    }
                    .into());
                }
            }
        };

        Ok(ArgminIterData::new().param(pstar))
    }

    /// Stops after the first iteration: `next_iter` computes the step directly,
    /// without any iterative refinement.
    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        if state.get_iter() >= 1 {
            TerminationReason::MaxItersReached
        } else {
            TerminationReason::NotTerminated
        }
    }
}
impl<F: ArgminFloat> ArgminTrustRegion<F> for Dogleg<F> {
fn set_radius(&mut self, radius: F) {
self.radius = radius;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(dogleg, Dogleg<f64>);

    /// `new()` leaves the radius unset (NaN) until `set_radius` is called.
    #[test]
    fn test_new_radius_is_nan() {
        let solver: Dogleg<f64> = Dogleg::new();
        assert!(solver.radius.is_nan());
    }

    /// `set_radius` overwrites the stored radius.
    #[test]
    fn test_set_radius() {
        let mut solver: Dogleg<f64> = Dogleg::new();
        solver.set_radius(2.5);
        assert!((solver.radius - 2.5).abs() < f64::EPSILON);
    }
}