// argmin/solver/trustregion/trustregion_method.rs
use crate::prelude::*;
use crate::solver::trustregion::reduction_ratio;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Trust region solver.
///
/// Each iteration delegates the computation of a step to `subproblem` (bounded by the
/// current `radius`), compares the actual decrease of the cost function with the decrease
/// predicted by the quadratic model (the reduction ratio), adapts the radius accordingly and
/// accepts the step only if the ratio exceeds `eta`.
#[derive(Clone, Serialize, Deserialize)]
pub struct TrustRegion<R, F> {
    /// Current trust region radius (defaults to 1.0 in `new`).
    radius: F,
    /// Upper bound the radius may be expanded to (defaults to 100.0 in `new`).
    max_radius: F,
    /// Acceptance threshold: a step is accepted only if the reduction ratio is greater than
    /// `eta` (defaults to 0.125 in `new`).
    eta: F,
    /// Solver used to compute the step inside the trust region.
    subproblem: R,
    /// Cost function value at the current iterate (NaN until `init` runs).
    fxk: F,
    /// Value of the quadratic model at the zero step; always set equal to `fxk`.
    mk0: F,
}
impl<R, F: ArgminFloat> TrustRegion<R, F> {
    /// Creates a new trust region solver that uses `subproblem` to compute steps.
    ///
    /// Defaults: initial radius `1.0`, maximum radius `100.0`, `eta = 0.125`.
    pub fn new(subproblem: R) -> Self {
        TrustRegion {
            radius: F::from_f64(1.0).unwrap(),
            max_radius: F::from_f64(100.0).unwrap(),
            eta: F::from_f64(0.125).unwrap(),
            subproblem,
            // Filled in by `Solver::init` once the initial cost is known.
            fxk: F::nan(),
            mk0: F::nan(),
        }
    }

    /// Sets the initial trust region radius (default `1.0`).
    ///
    /// The value is stored as given; it is not validated.
    pub fn radius(mut self, radius: F) -> Self {
        self.radius = radius;
        self
    }

    /// Sets the maximum radius the trust region may be expanded to (default `100.0`).
    ///
    /// The value is stored as given; it is not validated.
    pub fn max_radius(mut self, max_radius: F) -> Self {
        self.max_radius = max_radius;
        self
    }

    /// Sets the step acceptance threshold `eta` (default `0.125`).
    ///
    /// A step is accepted only if the reduction ratio is greater than `eta`.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::InvalidParameter` unless `0 <= eta < 1/4`. NaN is rejected.
    pub fn eta(mut self, eta: F) -> Result<Self, Error> {
        let valid = F::from_f64(0.0).unwrap()..F::from_f64(0.25).unwrap();
        // A positive range check is used on purpose: every comparison with NaN is false,
        // so the previous negated form (`eta >= 0.25 || eta < 0.0`) let NaN through.
        if !valid.contains(&eta) {
            return Err(ArgminError::InvalidParameter {
                text: "TrustRegion: eta must be in [0, 1/4).".to_string(),
            }
            .into());
        }
        self.eta = eta;
        Ok(self)
    }
}
impl<O, R, F> Solver<O> for TrustRegion<R, F>
where
O: ArgminOp<Output = F, Float = F>,
O::Param: Default
+ Clone
+ Debug
+ Serialize
+ ArgminMul<F, O::Param>
+ ArgminWeightedDot<O::Param, F, O::Hessian>
+ ArgminNorm<F>
+ ArgminDot<O::Param, F>
+ ArgminAdd<O::Param, O::Param>
+ ArgminSub<O::Param, O::Param>
+ ArgminZeroLike
+ ArgminMul<F, O::Param>,
O::Hessian: Default + Clone + Debug + Serialize + ArgminDot<O::Param, O::Param>,
R: ArgminTrustRegion<F> + Solver<OpWrapper<O>>,
F: ArgminFloat,
{
const NAME: &'static str = "Trust region";
fn init(
&mut self,
op: &mut OpWrapper<O>,
state: &IterState<O>,
) -> Result<Option<ArgminIterData<O>>, Error> {
let param = state.get_param();
let grad = op.gradient(¶m)?;
let hessian = op.hessian(¶m)?;
self.fxk = op.apply(¶m)?;
self.mk0 = self.fxk;
Ok(Some(
ArgminIterData::new()
.param(param)
.cost(self.fxk)
.grad(grad)
.hessian(hessian),
))
}
fn next_iter(
&mut self,
op: &mut OpWrapper<O>,
state: &IterState<O>,
) -> Result<ArgminIterData<O>, Error> {
let param = state.get_param();
let grad = state
.get_grad()
.unwrap_or_else(|| op.gradient(¶m).unwrap());
let hessian = state
.get_hessian()
.unwrap_or_else(|| op.hessian(¶m).unwrap());
self.subproblem.set_radius(self.radius);
let ArgminResult {
operator: sub_op,
state: IterState { param: pk, .. },
} = Executor::new(
OpWrapper::new_from_wrapper(op),
self.subproblem.clone(),
param.clone(),
)
.grad(grad.clone())
.hessian(hessian.clone())
.ctrlc(false)
.run()?;
op.consume_op(sub_op);
let new_param = pk.add(¶m);
let fxkpk = op.apply(&new_param)?;
let mkpk =
self.fxk + pk.dot(&grad) + F::from_f64(0.5).unwrap() * pk.weighted_dot(&hessian, &pk);
let rho = reduction_ratio(self.fxk, fxkpk, self.mk0, mkpk);
let pk_norm = pk.norm();
let cur_radius = self.radius;
self.radius = if rho < F::from_f64(0.25).unwrap() {
F::from_f64(0.25).unwrap() * pk_norm
} else if rho > F::from_f64(0.75).unwrap()
&& (pk_norm - self.radius).abs() <= F::from_f64(10.0).unwrap() * F::epsilon()
{
self.max_radius.min(F::from_f64(2.0).unwrap() * self.radius)
} else {
self.radius
};
Ok(if rho > self.eta {
self.fxk = fxkpk;
self.mk0 = fxkpk;
let grad = op.gradient(&new_param)?;
let hessian = op.hessian(&new_param)?;
ArgminIterData::new()
.param(new_param)
.cost(fxkpk)
.grad(grad)
.hessian(hessian)
} else {
ArgminIterData::new().param(param).cost(self.fxk)
}
.kv(make_kv!("radius" => cur_radius;)))
}
fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
TerminationReason::NotTerminated
}
}
#[cfg(test)]
mod tests {
    //! Instantiates `test_trait_impl!` for `TrustRegion` using a `Steihaug` subproblem over
    //! `MinimalNoOperator`.
    use super::*;
    use crate::solver::trustregion::steihaug::Steihaug;
    use crate::test_trait_impl;

    test_trait_impl!(
        trustregion,
        TrustRegion<Steihaug<MinimalNoOperator, f64>, f64>
    );
}