argmin/solver/gaussnewton/
gaussnewton_method.rsuse crate::prelude::*;
use serde::{Deserialize, Serialize};
use std::default::Default;
/// Gauss-Newton method for nonlinear least-squares problems.
///
/// Each iteration evaluates the residuals `r` and the Jacobian `J` of the operator at the
/// current parameter, computes the step `p = (JᵀJ)⁻¹ Jᵀ r`, and moves to `param - gamma * p`.
/// Iteration stops once the absolute change in cost between two iterations drops below `tol`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GaussNewton<F> {
    /// Step length that scales every Gauss-Newton update; expected to lie in `(0, 1]`
    /// (see `with_gamma`). `1` means a full Gauss-Newton step.
    gamma: F,
    /// Termination threshold on the absolute change in cost between consecutive iterations.
    tol: F,
}
impl<F: ArgminFloat> GaussNewton<F> {
    /// Creates a Gauss-Newton solver with default settings.
    ///
    /// Defaults: step length `gamma = 1` (full Gauss-Newton step) and termination
    /// tolerance `tol = sqrt(F::epsilon())`.
    pub fn new() -> Self {
        GaussNewton {
            gamma: F::from_f64(1.0).unwrap(),
            tol: F::epsilon().sqrt(),
        }
    }

    /// Sets the step length `gamma` that scales each Gauss-Newton update.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::InvalidParameter` if `gamma` is not in `(0, 1]`.
    /// NaN is rejected as well.
    pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
        let zero = F::from_f64(0.0).unwrap();
        let one = F::from_f64(1.0).unwrap();
        // Phrased as a positive membership test: every comparison involving NaN is
        // false, so NaN fails this check instead of slipping past `<=`/`>` guards.
        let in_range = gamma > zero && gamma <= one;
        if !in_range {
            return Err(ArgminError::InvalidParameter {
                text: "Gauss-Newton: gamma must be in (0, 1].".to_string(),
            }
            .into());
        }
        self.gamma = gamma;
        Ok(self)
    }

    /// Sets the termination tolerance on the absolute change in cost.
    ///
    /// # Errors
    ///
    /// Returns `ArgminError::InvalidParameter` if `tol` is not strictly positive.
    /// NaN is rejected as well.
    pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
        // Positive test so that NaN (for which `tol <= 0` is false) is rejected too.
        let positive = tol > F::from_f64(0.0).unwrap();
        if !positive {
            return Err(ArgminError::InvalidParameter {
                text: "Gauss-Newton: tol must be positive.".to_string(),
            }
            .into());
        }
        self.tol = tol;
        Ok(self)
    }
}
impl<F: ArgminFloat> Default for GaussNewton<F> {
    /// Same as `GaussNewton::new`: `gamma = 1` and `tol = sqrt(F::epsilon())`.
    fn default() -> Self {
        Self::new()
    }
}
impl<O, F> Solver<O> for GaussNewton<F>
where
    O: ArgminOp<Float = F>,
    O::Param: Default
        + ArgminScaledSub<O::Param, O::Float, O::Param>
        + ArgminSub<O::Param, O::Param>
        + ArgminMul<O::Float, O::Param>,
    O::Output: ArgminNorm<O::Float>,
    O::Jacobian: ArgminTranspose
        + ArgminInv<O::Jacobian>
        + ArgminDot<O::Jacobian, O::Jacobian>
        + ArgminDot<O::Output, O::Param>
        + ArgminDot<O::Param, O::Param>,
    O::Hessian: Default,
    F: ArgminFloat,
{
    const NAME: &'static str = "Gauss-Newton method";

    /// Performs one Gauss-Newton iteration.
    ///
    /// With `r` the residuals (`op.apply`) and `J` the Jacobian (`op.jacobian`) at the
    /// current parameter, this computes the step `p = (JᵀJ)⁻¹ Jᵀ r` and returns
    /// `param - gamma * p` as the new parameter.
    ///
    /// The reported cost is the norm of the residuals at the *incoming* parameter,
    /// i.e. it is evaluated before the step is applied.
    ///
    /// # Errors
    ///
    /// Propagates any error from `op.apply`, `op.jacobian`, or from inverting `JᵀJ`.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param = state.get_param();
        let residuals = op.apply(&param)?;
        let jacobian = op.jacobian(&param)?;

        // Transpose once and reuse it for both `JᵀJ` and `Jᵀr`.
        let jacobian_t = jacobian.clone().t();
        let p = jacobian_t
            .dot(&jacobian)
            .inv()?
            .dot(&jacobian_t.dot(&residuals));

        let new_param = param.sub(&p.mul(&self.gamma));

        Ok(ArgminIterData::new()
            .param(new_param)
            .cost(residuals.norm()))
    }

    /// Stops with `NoChangeInCost` once `|prev_cost - cost| < tol`; otherwise the
    /// solver keeps iterating.
    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
            return TerminationReason::NoChangeInCost;
        }
        TerminationReason::NotTerminated
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(gauss_newton_method, GaussNewton<f64>);

    #[test]
    fn test_new_defaults() {
        let GaussNewton { gamma, tol, .. } = GaussNewton::<f64>::new();
        assert!((gamma - 1.0).abs() < std::f64::EPSILON);
        assert!((tol - std::f64::EPSILON.sqrt()).abs() < std::f64::EPSILON);
    }

    #[test]
    fn test_tolerance() {
        let tol1: f64 = 1e-4;
        let GaussNewton { tol: t, .. } = GaussNewton::new().with_tol(tol1).unwrap();
        assert!((t - tol1).abs() < std::f64::EPSILON);
    }

    #[test]
    fn test_tolerance_must_be_positive() {
        assert!(GaussNewton::<f64>::new().with_tol(0.0).is_err());
        assert!(GaussNewton::<f64>::new().with_tol(-1e-4).is_err());
    }

    #[test]
    fn test_gamma() {
        let gamma: f64 = 0.5;
        let GaussNewton { gamma: g, .. } = GaussNewton::new().with_gamma(gamma).unwrap();
        assert!((g - gamma).abs() < std::f64::EPSILON);
    }

    #[test]
    fn test_gamma_bounds() {
        // The admissible interval is (0, 1]: the upper bound is inclusive.
        assert!(GaussNewton::<f64>::new().with_gamma(1.0).is_ok());
        assert!(GaussNewton::<f64>::new().with_gamma(0.0).is_err());
        assert!(GaussNewton::<f64>::new().with_gamma(-0.5).is_err());
        assert!(GaussNewton::<f64>::new().with_gamma(1.5).is_err());
    }
}