// argmin/solver/gradientdescent/steepestdescent.rs

use crate::prelude::*;
use serde::{Deserialize, Serialize};
/// Steepest descent (gradient descent) solver.
///
/// Each iteration steps in the direction of the negative gradient of the
/// cost function; the step length along that direction is determined by
/// the supplied line search `L`.
#[derive(Clone, Serialize, Deserialize)]
pub struct SteepestDescent<L> {
    // Line search used to find an acceptable step length along -gradient.
    linesearch: L,
}
34
35impl<L> SteepestDescent<L> {
36 pub fn new(linesearch: L) -> Self {
38 SteepestDescent { linesearch }
39 }
40}
41
42impl<O, L, F> Solver<O> for SteepestDescent<L>
43where
44 O: ArgminOp<Output = F, Float = F>,
45 O::Param: Clone
46 + Default
47 + Serialize
48 + ArgminSub<O::Param, O::Param>
49 + ArgminDot<O::Param, O::Float>
50 + ArgminScaledAdd<O::Param, O::Float, O::Param>
51 + ArgminMul<O::Float, O::Param>
52 + ArgminSub<O::Param, O::Param>
53 + ArgminNorm<O::Float>,
54 O::Hessian: Default,
55 L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
56 F: ArgminFloat,
57{
58 const NAME: &'static str = "Steepest Descent";
59
60 fn next_iter(
61 &mut self,
62 op: &mut OpWrapper<O>,
63 state: &IterState<O>,
64 ) -> Result<ArgminIterData<O>, Error> {
65 let param_new = state.get_param();
66 let new_cost = op.apply(¶m_new)?;
67 let new_grad = op.gradient(¶m_new)?;
68
69 self.linesearch
70 .set_search_direction(new_grad.mul(&(O::Float::from_f64(-1.0).unwrap())));
71
72 let ArgminResult {
74 operator: line_op,
75 state:
76 IterState {
77 param: next_param,
78 cost: next_cost,
79 ..
80 },
81 } = Executor::new(
82 OpWrapper::new_from_wrapper(op),
83 self.linesearch.clone(),
84 param_new,
85 )
86 .grad(new_grad)
87 .cost(new_cost)
88 .ctrlc(false)
89 .run()?;
90
91 op.consume_op(line_op);
93
94 Ok(ArgminIterData::new().param(next_param).cost(next_cost))
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    // Compile-time check that the fully concrete solver type implements the
    // traits required by the framework (presumably Clone/Send/Sync/Serialize
    // — confirm against the `test_trait_impl!` macro definition).
    test_trait_impl!(
        steepest_descent,
        SteepestDescent<MoreThuenteLineSearch<Vec<f64>, f64>>
    );
}
108}