argmin/solver/quasinewton/lbfgs.rs

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fmt::Debug;
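/// Limited-memory BFGS (L-BFGS) method.
///
/// Instead of storing a full (inverse) Hessian approximation, L-BFGS keeps only
/// the `m` most recent parameter differences (`s`) and gradient differences
/// (`y`) and reconstructs the search direction with the two-loop recursion,
/// followed by a line search along that direction.
///
/// A minimal usage sketch (not compiled; `MyProblem` is a placeholder for a
/// user-defined type implementing `ArgminOp` with `apply` and `gradient`):
///
/// ```ignore
/// let linesearch = MoreThuenteLineSearch::new();
/// let solver = LBFGS::new(linesearch, 7);
/// let res = Executor::new(MyProblem {}, solver, init_param)
///     .max_iters(100)
///     .run()?;
/// ```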
#[derive(Clone, Serialize, Deserialize)]
pub struct LBFGS<L, P, F> {
/// Line search method
linesearch: L,
/// Number of stored parameter/gradient difference pairs
m: usize,
/// Stored parameter differences s_k = x_{k+1} - x_k
s: VecDeque<P>,
/// Stored gradient differences y_k = grad_{k+1} - grad_k
y: VecDeque<P>,
/// Tolerance on the gradient norm (stopping criterion)
tol_grad: F,
/// Tolerance on the change in cost between iterations (stopping criterion)
tol_cost: F,
}
impl<L, P, F: ArgminFloat> LBFGS<L, P, F> {
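/// Construct a new L-BFGS solver from a line search method and the number `m`
/// of stored parameter/gradient difference pairs. The gradient tolerance
/// defaults to `sqrt(EPSILON)` and the cost tolerance to `EPSILON`.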
pub fn new(linesearch: L, m: usize) -> Self {
LBFGS {
linesearch,
m,
s: VecDeque::with_capacity(m),
y: VecDeque::with_capacity(m),
tol_grad: F::epsilon().sqrt(),
tol_cost: F::epsilon(),
}
}
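/// Set the tolerance on the gradient norm used as a stopping criterion.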
pub fn with_tol_grad(mut self, tol_grad: F) -> Self {
self.tol_grad = tol_grad;
self
}
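/// Set the tolerance on the change in cost between iterations used as a
/// stopping criterion.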
pub fn with_tol_cost(mut self, tol_cost: F) -> Self {
self.tol_cost = tol_cost;
self
}
}
impl<O, L, P, F> Solver<O> for LBFGS<L, P, F>
where
O: ArgminOp<Param = P, Output = F, Float = F>,
O::Param: Clone
+ Serialize
+ DeserializeOwned
+ Debug
+ Default
+ ArgminSub<O::Param, O::Param>
+ ArgminAdd<O::Param, O::Param>
+ ArgminDot<O::Param, O::Float>
+ ArgminScaledAdd<O::Param, O::Float, O::Param>
+ ArgminNorm<O::Float>
+ ArgminMul<O::Float, O::Param>,
O::Hessian: Clone + Default + Serialize + DeserializeOwned,
L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
F: ArgminFloat,
{
const NAME: &'static str = "L-BFGS";
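/// Evaluate cost and gradient at the initial parameter vector.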
fn init(
&mut self,
op: &mut OpWrapper<O>,
state: &IterState<O>,
) -> Result<Option<ArgminIterData<O>>, Error> {
let param = state.get_param();
let cost = op.apply(&param)?;
let grad = op.gradient(&param)?;
Ok(Some(
ArgminIterData::new().param(param).cost(cost).grad(grad),
))
}
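/// Perform one L-BFGS iteration: compute the search direction with the
/// two-loop recursion, run the line search along it, and update the stored
/// `s`/`y` pairs.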
fn next_iter(
&mut self,
op: &mut OpWrapper<O>,
state: &IterState<O>,
) -> Result<ArgminIterData<O>, Error> {
let param = state.get_param();
let cur_cost = state.get_cost();
let prev_grad = state.get_grad().unwrap();
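// Scaling factor for the initial inverse Hessian approximation H_0 = gamma * I:
// gamma = s_k^T y_k / (y_k^T y_k), or 1 if no pairs are stored yet.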
let gamma: F = if let (Some(ref sk), Some(ref yk)) = (self.s.back(), self.y.back()) {
sk.dot(*yk) / yk.dot(*yk)
} else {
F::from_f64(1.0).unwrap()
};
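// Two-loop recursion, starting from the current gradient.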
let mut q = prev_grad.clone();
let cur_m = self.s.len();
let mut alpha: Vec<F> = vec![F::from_f64(0.0).unwrap(); cur_m];
let mut rho: Vec<F> = vec![F::from_f64(0.0).unwrap(); cur_m];
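// First loop, newest to oldest pair: rho_i = 1 / (y_i^T s_i),
// alpha_i = rho_i * s_i^T q, q <- q - alpha_i * y_i.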
for (i, (ref sk, ref yk)) in self.s.iter().rev().zip(self.y.iter().rev()).enumerate() {
let sk = *sk;
let yk = *yk;
let yksk: F = yk.dot(sk);
let rho_t = F::from_f64(1.0).unwrap() / yksk;
let skq: F = sk.dot(&q);
let alpha_t = skq.mul(rho_t);
q = q.sub(&yk.mul(&alpha_t));
rho[cur_m - i - 1] = rho_t;
alpha[cur_m - i - 1] = alpha_t;
}
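// Apply the initial scaling, then run the second loop oldest to newest:
// beta = rho_i * y_i^T r, r <- r + s_i * (alpha_i - beta).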
let mut r = q.mul(&gamma);
for (i, (ref sk, ref yk)) in self.s.iter().zip(self.y.iter()).enumerate() {
let sk = *sk;
let yk = *yk;
let beta = yk.dot(&r).mul(rho[i]);
r = r.add(&sk.mul(&(alpha[i] - beta)));
}
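// r approximates the inverse Hessian applied to the gradient; search along -r.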
self.linesearch
.set_search_direction(r.mul(&F::from_f64(-1.0).unwrap()));
let ArgminResult {
operator: line_op,
state:
IterState {
param: xk1,
cost: next_cost,
..
},
} = Executor::new(
OpWrapper::new_from_wrapper(op),
self.linesearch.clone(),
param.clone(),
)
.grad(prev_grad.clone())
.cost(cur_cost)
.ctrlc(false)
.run()?;
op.consume_op(line_op);
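// Limited memory: once m pairs have been stored, drop the oldest pair
// before adding the new one.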
if state.get_iter() >= self.m as u64 {
self.s.pop_front();
self.y.pop_front();
}
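// Store the new parameter difference s_k = x_{k+1} - x_k and
// gradient difference y_k = grad_{k+1} - grad_k.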
let grad = op.gradient(&xk1)?;
self.s.push_back(xk1.sub(&param));
self.y.push_back(grad.sub(&prev_grad));
Ok(ArgminIterData::new()
.param(xk1)
.cost(next_cost)
.grad(grad)
.kv(make_kv!("gamma" => gamma;)))
}
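/// Terminate when the gradient norm drops below `tol_grad` or the change in
/// cost falls below `tol_cost`.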
fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
if state.get_grad().unwrap().norm() < self.tol_grad {
return TerminationReason::TargetPrecisionReached;
}
if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
return TerminationReason::NoChangeInCost;
}
TerminationReason::NotTerminated
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::solver::linesearch::MoreThuenteLineSearch;
use crate::test_trait_impl;
type Operator = MinimalNoOperator;
test_trait_impl!(lbfgs, LBFGS<Operator, MoreThuenteLineSearch<Operator, f64>, f64>);
#[test]
fn test_tolerances() {
let linesearch: MoreThuenteLineSearch<f64, f64> =
MoreThuenteLineSearch::new().c(1e-4, 0.9).unwrap();
let tol1 = 1e-4;
let tol2 = 1e-2;
let LBFGS {
tol_grad: t1,
tol_cost: t2,
..
}: LBFGS<_, f64, f64> = LBFGS::new(linesearch, 7)
.with_tol_grad(tol1)
.with_tol_cost(tol2);
assert!((t1 - tol1).abs() < std::f64::EPSILON);
assert!((t2 - tol2).abs() < std::f64::EPSILON);
}
}