// argmin/solver/trustregion/cauchypoint.rs
use crate::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Trust-region subproblem solver that computes the Cauchy point.
///
/// The only state carried is the trust-region radius, which is supplied through
/// `ArgminTrustRegion::set_radius`. Note that `CauchyPoint::new` starts with a NaN
/// radius, whereas the derived `Default` starts with `F::default()`.
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct CauchyPoint<F> {
    /// Trust-region radius used to scale the computed step.
    radius: F,
}
impl<F: ArgminFloat> CauchyPoint<F> {
    /// Creates a Cauchy point solver whose radius is not yet set (NaN).
    ///
    /// Provide the radius via `ArgminTrustRegion::set_radius` before running the
    /// solver; otherwise the NaN radius enters the step computation.
    pub fn new() -> Self {
        let radius = F::nan();
        Self { radius }
    }
}
/// Single-iteration solver returning the Cauchy point of the trust-region subproblem.
///
/// With gradient `g`, Hessian `B` and trust-region radius `Δ`, the step is
///
/// ```text
/// p   = -τ · (Δ / ‖g‖) · g
/// τ   = 1                              if gᵀ B g ≤ 0
/// τ   = min(‖g‖³ / (Δ · gᵀ B g), 1)    otherwise
/// ```
///
/// If the gradient is exactly zero the step is the zero vector.
impl<O, F> Solver<O> for CauchyPoint<F>
where
    O: ArgminOp<Output = F, Float = F>,
    O::Param: Debug
        + Clone
        + Serialize
        + ArgminMul<O::Float, O::Param>
        + ArgminWeightedDot<O::Param, F, O::Hessian>
        + ArgminNorm<O::Float>,
    O::Hessian: Clone + Serialize,
    F: ArgminFloat,
{
    const NAME: &'static str = "Cauchy Point";

    /// Computes the Cauchy step at the current parameter.
    ///
    /// The gradient and Hessian are read from `state` when available and are
    /// otherwise evaluated through `op`. The returned parameter is the step `p`
    /// described above.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `op.gradient` or `op.hessian`.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let zero = F::from_f64(0.0).unwrap();
        let one = F::from_f64(1.0).unwrap();

        let param = state.get_param();
        // Evaluation failures are returned to the caller rather than panicking.
        let grad = match state.get_grad() {
            Some(grad) => grad,
            None => op.gradient(&param)?,
        };
        let grad_norm = grad.norm();

        // With a zero gradient the steepest-descent direction vanishes; the Cauchy
        // step is zero. Dividing by `grad_norm` below would otherwise yield NaN.
        if grad_norm == zero {
            return Ok(ArgminIterData::new().param(grad.mul(&zero)));
        }

        let hessian = match state.get_hessian() {
            Some(hessian) => hessian,
            None => op.hessian(&param)?,
        };

        // Curvature of the model along the gradient direction: gᵀ B g.
        let wdp = grad.weighted_dot(&hessian, &grad);
        let tau = if wdp <= zero {
            // Non-positive curvature: step all the way to the trust-region boundary.
            one
        } else {
            one.min(grad_norm.powi(3) / (self.radius * wdp))
        };

        let new_param = grad.mul(&(-tau * self.radius / grad_norm));
        Ok(ArgminIterData::new().param(new_param))
    }

    /// Stops after the first iteration, reporting `MaxItersReached`.
    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        if state.get_iter() >= 1 {
            TerminationReason::MaxItersReached
        } else {
            TerminationReason::NotTerminated
        }
    }
}
/// Trust-region interface: allows the radius of `CauchyPoint` to be set externally.
impl<F: ArgminFloat> ArgminTrustRegion<F> for CauchyPoint<F> {
    /// Stores `radius` as the trust-region radius used to scale the next Cauchy step.
    /// No validation is performed on the value.
    fn set_radius(&mut self, radius: F) {
        self.radius = radius;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;
    // Instantiates the crate's shared `test_trait_impl!` checks for `CauchyPoint<f64>`
    // under the test name `cauchypoint`. NOTE(review): which traits are checked is
    // defined by the macro elsewhere in the crate — confirm there.
    test_trait_impl!(cauchypoint, CauchyPoint<f64>);
}