// argmin/solver/trustregion/steihaug.rs
use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
/// Steihaug's conjugate-gradient method for the trust-region subproblem.
///
/// Approximately minimizes the quadratic model `m(p) = g^T p + 0.5 * p^T H p` subject to
/// `||p|| <= radius`, where `g` and `H` are the gradient and Hessian taken from the solver
/// state. Iteration stops when the residual becomes small relative to its initial norm,
/// when a direction of non-positive curvature is encountered, or when the next iterate
/// would leave the trust region.
#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, PartialOrd, Default)]
pub struct Steihaug<P, F> {
    /// Trust-region radius. `new` initializes it to NaN; it is updated via `set_radius`.
    radius: F,
    /// Tolerance used for termination: the initial residual norm is compared against it
    /// directly, later residual norms against `epsilon * r_0_norm`.
    /// `new` sets it to `10e-10` (which is `1e-9`).
    epsilon: F,
    /// Current approximate solution of the subproblem (the step); starts at zero.
    p: P,
    /// Current residual; initialized to the gradient.
    r: P,
    /// Squared norm of the current residual, `r^T r`.
    rtr: F,
    /// Norm of the initial residual (i.e. of the gradient).
    r_0_norm: F,
    /// Current search direction; initialized to the negative gradient.
    d: P,
    /// Maximum number of iterations before `terminate` reports `MaxItersReached`.
    max_iters: u64,
}
impl<P, F> Steihaug<P, F>
where
    P: Default + Clone + ArgminMul<F, P> + ArgminDot<P, F> + ArgminAdd<P, P>,
    F: ArgminFloat,
{
    /// Create a new Steihaug solver.
    ///
    /// The radius is initialized to NaN and is expected to be set through
    /// `ArgminTrustRegion::set_radius`. `epsilon` defaults to `10e-10` (that is, `1e-9`)
    /// and `max_iters` to `u64::MAX`.
    pub fn new() -> Self {
        Steihaug {
            radius: F::nan(),
            // NOTE(review): `10e-10` equals `1e-9`, not `1e-10`; kept to preserve behavior.
            epsilon: F::from_f64(10e-10).unwrap(),
            p: P::default(),
            r: P::default(),
            rtr: F::nan(),
            r_0_norm: F::nan(),
            d: P::default(),
            max_iters: std::u64::MAX,
        }
    }

    /// Set the termination tolerance on the residual norm.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::InvalidParameter` if `epsilon` is not strictly positive
    /// (this includes NaN, which a plain `<= 0` comparison would silently accept).
    pub fn epsilon(mut self, epsilon: F) -> Result<Self, Error> {
        if epsilon.is_nan() || epsilon <= F::from_f64(0.0).unwrap() {
            return Err(ArgminError::InvalidParameter {
                text: "Steihaug: epsilon must be > 0.0.".to_string(),
            }
            .into());
        }
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Set the maximum number of iterations.
    pub fn max_iters(mut self, iters: u64) -> Self {
        self.max_iters = iters;
        self
    }

    /// Evaluate the quadratic model `m(p) = g^T p + 0.5 * p^T H p`.
    ///
    /// The constant term `f(x)` is omitted; only differences between model values matter.
    fn eval_m<H>(&self, p: &P, g: &P, h: &H) -> F
    where
        P: ArgminWeightedDot<P, F, H>,
    {
        g.dot(p) + F::from_f64(0.5).unwrap() * p.weighted_dot(h, p)
    }

    /// Step length `tau` such that `p + tau * d` lies on the trust-region boundary.
    ///
    /// Solves `||p + tau * d||^2 = radius^2`, i.e. `b tau^2 + 2 c tau + (a - radius^2) = 0`
    /// with `a = p^T p`, `b = d^T d` and `c = p^T d`. If either quadratic root is not
    /// finite, the solution of the linear equation `2 c tau = radius^2 - a` is added as an
    /// extra candidate.
    ///
    /// Only finite candidates for which `filter_func` returns `true` are admissible. If
    /// `eval` is `true`, the admissible candidate with the smallest finite model value
    /// `m(p + tau * d)` is returned (the first one on ties). Otherwise, the largest
    /// admissible candidate is returned.
    ///
    /// # Panics
    ///
    /// Panics if no admissible candidate remains.
    #[allow(clippy::many_single_char_names)] // a, b, c, t mirror the quadratic formula
    fn tau<G, H>(&self, filter_func: G, eval: bool, g: &P, h: &H) -> F
    where
        G: Fn(F) -> bool,
        H: ArgminDot<P, P>,
    {
        let a = self.p.dot(&self.p);
        let b = self.d.dot(&self.d);
        let c = self.p.dot(&self.d);
        let delta = self.radius.powi(2);
        let t1 = (-a * b + b * delta + c.powi(2)).sqrt();
        let tau1 = -(t1 + c) / b;
        let tau2 = (t1 - c) / b;
        let mut t = vec![tau1, tau2];
        // A non-finite root arises when `b` vanishes (division by zero) or the
        // discriminant is negative (NaN from `sqrt`); fall back to the linear solution.
        if !tau1.is_finite() || !tau2.is_finite() {
            t.push((delta - a) / (F::from_f64(2.0).unwrap() * c));
        }
        // Admissible candidates: finite and accepted by the caller's predicate. Every
        // value compared below is finite, so `partial_cmp(..).unwrap()` cannot panic.
        let candidates = t
            .into_iter()
            .filter(|&tau| tau.is_finite() && filter_func(tau));
        let best = if eval {
            candidates
                .filter_map(|tau| {
                    let m = self.eval_m(&self.p.add(&self.d.mul(&tau)), g, h);
                    if m.is_finite() {
                        Some((tau, m))
                    } else {
                        None
                    }
                })
                // `min_by` keeps the first of equal minima, like a stable ascending sort.
                .min_by(|x, y| x.1.partial_cmp(&y.1).unwrap())
                .map(|(tau, _)| tau)
        } else {
            candidates.max_by(|x, y| x.partial_cmp(y).unwrap())
        };
        best.expect("Steihaug: no admissible step length to the trust-region boundary")
    }
}
impl<P, O, F> Solver<O> for Steihaug<P, F>
where
O: ArgminOp<Param = P, Output = F, Float = F>,
P: Clone
+ Serialize
+ DeserializeOwned
+ Default
+ ArgminMul<F, P>
+ ArgminWeightedDot<P, F, O::Hessian>
+ ArgminNorm<F>
+ ArgminDot<P, F>
+ ArgminAdd<P, P>
+ ArgminSub<P, P>
+ ArgminZeroLike
+ ArgminMul<F, P>,
O::Hessian: ArgminDot<P, P>,
F: ArgminFloat,
{
const NAME: &'static str = "Steihaug";
fn init(
&mut self,
_op: &mut OpWrapper<O>,
state: &IterState<O>,
) -> Result<Option<ArgminIterData<O>>, Error> {
self.r = state.get_grad().unwrap();
self.r_0_norm = self.r.norm();
self.rtr = self.r.dot(&self.r);
self.d = self.r.mul(&F::from_f64(-1.0).unwrap());
self.p = self.r.zero_like();
Ok(if self.r_0_norm < self.epsilon {
Some(
ArgminIterData::new()
.param(self.p.clone())
.termination_reason(TerminationReason::TargetPrecisionReached),
)
} else {
None
})
}
fn next_iter(
&mut self,
_op: &mut OpWrapper<O>,
state: &IterState<O>,
) -> Result<ArgminIterData<O>, Error> {
let grad = state.get_grad().unwrap();
let h = state.get_hessian().unwrap();
let dhd = self.d.weighted_dot(&h, &self.d);
if dhd <= F::from_f64(0.0).unwrap() {
let tau = self.tau(|_| true, true, &grad, &h);
return Ok(ArgminIterData::new()
.param(self.p.add(&self.d.mul(&tau)))
.termination_reason(TerminationReason::TargetPrecisionReached));
}
let alpha = self.rtr / dhd;
let p_n = self.p.add(&self.d.mul(&alpha));
if p_n.norm() >= self.radius {
let tau = self.tau(|x| x >= F::from_f64(0.0).unwrap(), false, &grad, &h);
return Ok(ArgminIterData::new()
.param(self.p.add(&self.d.mul(&tau)))
.termination_reason(TerminationReason::TargetPrecisionReached));
}
let r_n = self.r.add(&h.dot(&self.d).mul(&alpha));
if r_n.norm() < self.epsilon * self.r_0_norm {
return Ok(ArgminIterData::new()
.param(p_n)
.termination_reason(TerminationReason::TargetPrecisionReached));
}
let rjtrj = r_n.dot(&r_n);
let beta = rjtrj / self.rtr;
self.d = r_n.mul(&F::from_f64(-1.0).unwrap()).add(&self.d.mul(&beta));
self.r = r_n;
self.p = p_n;
self.rtr = rjtrj;
Ok(ArgminIterData::new()
.param(self.p.clone())
.cost(self.rtr)
.grad(grad)
.hessian(h))
}
fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
if state.get_iter() >= self.max_iters {
TerminationReason::MaxItersReached
} else {
TerminationReason::NotTerminated
}
}
}
/// Allows the trust-region radius used by the Steihaug solver to be set from outside.
impl<P, F> ArgminTrustRegion<F> for Steihaug<P, F>
where
    P: Clone + Serialize,
    F: ArgminFloat,
{
    /// Store `radius` as the trust-region radius for subsequent iterations.
    fn set_radius(&mut self, radius: F) {
        self.radius = radius;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;
    // Instantiates the crate-wide `test_trait_impl!` checks for `Steihaug` with the
    // parameter type `MinimalNoOperator` and float type `f64`; the generated checks are
    // defined by that macro elsewhere in the crate.
    test_trait_impl!(steihaug, Steihaug<MinimalNoOperator, f64>);
}