argmin/solver/newton/
newton_method.rs1use crate::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::default::Default;
16
17#[derive(Clone, Serialize, Deserialize)]
27pub struct Newton<F> {
28 gamma: F,
30}
31
32impl<F: ArgminFloat> Newton<F> {
33 pub fn new() -> Self {
35 Newton {
36 gamma: F::from_f64(1.0).unwrap(),
37 }
38 }
39
40 pub fn set_gamma(mut self, gamma: F) -> Result<Self, Error> {
42 if gamma <= F::from_f64(0.0).unwrap() || gamma > F::from_f64(1.0).unwrap() {
43 return Err(ArgminError::InvalidParameter {
44 text: "Newton: gamma must be in (0, 1].".to_string(),
45 }
46 .into());
47 }
48 self.gamma = gamma;
49 Ok(self)
50 }
51}
52
53impl<F: ArgminFloat> Default for Newton<F> {
54 fn default() -> Newton<F> {
55 Newton::new()
56 }
57}
58
59impl<O, F> Solver<O> for Newton<F>
60where
61 O: ArgminOp<Float = F>,
62 O::Param: ArgminScaledSub<O::Param, O::Float, O::Param>,
63 O::Hessian: ArgminInv<O::Hessian> + ArgminDot<O::Param, O::Param>,
64 F: ArgminFloat,
65{
66 const NAME: &'static str = "Newton method";
67
68 fn next_iter(
69 &mut self,
70 op: &mut OpWrapper<O>,
71 state: &IterState<O>,
72 ) -> Result<ArgminIterData<O>, Error> {
73 let param = state.get_param();
74 let grad = op.gradient(¶m)?;
75 let hessian = op.hessian(¶m)?;
76 let new_param = param.scaled_sub(&self.gamma, &hessian.inv()?.dot(&grad));
77 Ok(ArgminIterData::new().param(new_param))
78 }
79}
80
#[cfg(test)]
mod tests {
    use crate::test_trait_impl;
    use super::*;

    // Instantiates the crate's shared trait-implementation checks for
    // `Newton<f64>`. NOTE(review): see the `test_trait_impl!` definition for
    // exactly which traits are asserted.
    test_trait_impl!(newton_method, Newton<f64>);
}