// argmin/solver/linesearch/backtracking.rs
use crate::prelude::*;
use crate::solver::linesearch::condition::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
/// Backtracking line search.
///
/// Trial points are generated from `init_param` along `search_direction` using the step
/// length `alpha`; `alpha` is contracted by the factor `rho` as the iterations proceed, and
/// the search stops once `condition` accepts the current trial point.
#[derive(Serialize, Deserialize, Clone)]
pub struct BacktrackingLineSearch<P, L, F> {
/// Parameter vector at the start of the line search (taken from the solver state in `init`).
init_param: P,
/// Cost at `init_param`; `infinity` until `init` has run.
init_cost: F,
/// Gradient at `init_param` (taken from the state, or evaluated in `init`).
init_grad: P,
/// Direction along which trial points are taken; `None` until set, and `init` fails if unset.
search_direction: Option<P>,
/// Contraction factor applied to `alpha`; `rho()` only accepts values in (0, 1).
rho: F,
/// Acceptance condition evaluated in `terminate`.
condition: Box<L>,
/// Current step length.
alpha: F,
}
impl<P: Default, L, F: ArgminFloat> BacktrackingLineSearch<P, L, F> {
    /// Creates a backtracking line search that uses `condition` as its acceptance test.
    ///
    /// Defaults: contraction factor `rho = 0.9` and initial step length `alpha = 1.0`.
    /// The search direction starts out unset and must be provided (via
    /// `set_search_direction`) before the solver is run, otherwise `init` fails.
    pub fn new(condition: L) -> Self {
        BacktrackingLineSearch {
            init_param: P::default(),
            init_cost: F::infinity(),
            init_grad: P::default(),
            search_direction: None,
            rho: F::from_f64(0.9).unwrap(),
            condition: Box::new(condition),
            alpha: F::from_f64(1.0).unwrap(),
        }
    }

    /// Sets the contraction factor `rho` by which the step length is multiplied when
    /// backtracking.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::InvalidParameter` unless `0 < rho < 1`. NaN is rejected too.
    pub fn rho(mut self, rho: F) -> Result<Self, Error> {
        // Phrased as a positive range test: `rho <= 0 || rho >= 1` is false for NaN,
        // which used to let a NaN contraction factor through.
        let in_range = rho > F::from_f64(0.0).unwrap() && rho < F::from_f64(1.0).unwrap();
        if !in_range {
            return Err(ArgminError::InvalidParameter {
                text: "BacktrackingLineSearch: Contraction factor rho must be in (0, 1)."
                    .to_string(),
            }
            .into());
        }
        self.rho = rho;
        Ok(self)
    }
}
/// Line-search configuration hooks used by outer solvers.
///
/// Only the bounds needed to satisfy the trait are required here. The previously listed
/// arithmetic bounds (`ArgminSub`, `ArgminDot<P, f64>`, `ArgminScaledAdd<P, f64, P>`) were
/// never used by these methods. They also hard-coded `f64` while the rest of the solver is
/// generic over `F`, which kept this impl from applying to non-`f64` setups.
impl<P, L, F> ArgminLineSearch<P, F> for BacktrackingLineSearch<P, L, F>
where
    P: Clone + Serialize,
    L: LineSearchCondition<P, F>,
    F: ArgminFloat + Serialize + DeserializeOwned,
{
    /// Stores the search direction used by subsequent iterations, replacing any previous one.
    fn set_search_direction(&mut self, search_direction: P) {
        self.search_direction = Some(search_direction);
    }

    /// Sets the step length tried first.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::InvalidParameter` if `alpha` is not strictly positive.
    /// NaN is rejected as well.
    fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
        // `alpha <= 0` is false for NaN, so test for positivity instead.
        let is_positive = alpha > F::from_f64(0.0).unwrap();
        if !is_positive {
            return Err(ArgminError::InvalidParameter {
                text: "LineSearch: Inital alpha must be > 0.".to_string(),
            }
            .into());
        }
        self.alpha = alpha;
        Ok(())
    }
}
impl<O, P, L, F> Solver<O> for BacktrackingLineSearch<P, L, F>
where
    P: Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<P, P>
        + ArgminDot<P, F>
        + ArgminScaledAdd<P, F, P>,
    O: ArgminOp<Param = P, Output = F, Float = F>,
    L: LineSearchCondition<P, F>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Backtracking Line search";

    /// Captures the starting point of the line search.
    ///
    /// Stores the current parameter vector, its cost and its gradient. The cost is
    /// evaluated with `op.apply` only if the state's cost is still infinity. The gradient
    /// is evaluated with `op.gradient` only if the state carries none.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::NotInitialized` if no search direction has been set, and
    /// propagates errors from evaluating the cost or gradient.
    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        // Fail fast, before spending any cost/gradient evaluations.
        if self.search_direction.is_none() {
            return Err(ArgminError::NotInitialized {
                text: "BacktrackingLineSearch: search_direction must be set.".to_string(),
            }
            .into());
        }

        self.init_param = state.get_param();

        let cost = state.get_cost();
        self.init_cost = if cost == F::infinity() {
            op.apply(&self.init_param)?
        } else {
            cost
        };

        // `unwrap_or(op.gradient(..)?)` evaluated the gradient eagerly, even when the
        // state already supplied one; only compute it when it is actually missing.
        self.init_grad = match state.get_grad() {
            Some(grad) => grad,
            None => op.gradient(&self.init_param)?,
        };

        Ok(None)
    }

    /// Evaluates one trial point along the search direction.
    ///
    /// The first call uses the initial `alpha`; each later call first multiplies `alpha`
    /// by `rho`. Contracting *before* stepping keeps `self.alpha` equal to the step length
    /// of the returned parameter vector, which is the value `terminate` passes to the line
    /// search condition. Previously `alpha` was contracted after stepping, so the condition
    /// saw `rho * alpha` for a point computed with `alpha`.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::NotInitialized` if no search direction is set, and propagates
    /// errors from evaluating the cost or gradient.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let search_direction =
            self.search_direction
                .as_ref()
                .ok_or_else(|| ArgminError::NotInitialized {
                    text: "BacktrackingLineSearch: search_direction must be set.".to_string(),
                })?;

        // NOTE(review): assumes the executor's iteration counter is 0 on the first call to
        // `next_iter` and is incremented after every iteration (fresh `IterState`).
        if state.get_iter() > 0 {
            self.alpha = self.alpha * self.rho;
        }

        let new_param = self.init_param.scaled_add(&self.alpha, search_direction);
        let cur_cost = op.apply(&new_param)?;

        // Compute the gradient (if the condition needs it) before moving `new_param` into
        // the result, which avoids cloning the parameter vector.
        let cur_grad = if self.condition.requires_cur_grad() {
            Some(op.gradient(&new_param)?)
        } else {
            None
        };

        let mut out = ArgminIterData::new().param(new_param).cost(cur_cost);
        if let Some(grad) = cur_grad {
            out = out.grad(grad);
        }
        Ok(out)
    }

    /// Returns `LineSearchConditionMet` once `condition` accepts the current state,
    /// otherwise `NotTerminated`.
    ///
    /// # Panics
    ///
    /// Panics if no search direction is set (`init` rejects that case before iterating).
    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        let search_direction = self
            .search_direction
            .as_ref()
            .expect("BacktrackingLineSearch: search_direction must be set before terminate()")
            .clone();

        let condition_met = self.condition.eval(
            state.get_cost(),
            state.get_grad().unwrap_or_default(),
            self.init_cost,
            self.init_grad.clone(),
            search_direction,
            self.alpha,
        );

        if condition_met {
            TerminationReason::LineSearchConditionMet
        } else {
            TerminationReason::NotTerminated
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{core::MinimalNoOperator, test_trait_impl};

    // Instantiates the crate's shared `test_trait_impl!` smoke test for this solver type
    // (see the macro definition for exactly which trait impls it checks).
    test_trait_impl!(
        backtrackinglinesearch,
        BacktrackingLineSearch<MinimalNoOperator, ArmijoCondition<f64>, f64>
    );
}