argmin/solver/gradientdescent/steepestdescent.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Steepest Descent method
//!
//! [SteepestDescent](struct.SteepestDescent.html)
//!
//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::{Deserialize, Serialize};
19
/// Steepest descent iteratively takes steps in the direction of the strongest negative gradient.
/// In each iteration, a line search is employed to obtain an appropriate step length.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/steepestdescent.rs)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct SteepestDescent<L> {
    /// Line search method used in each iteration to determine the step
    /// length along the negative-gradient search direction.
    linesearch: L,
}
34
35impl<L> SteepestDescent<L> {
36    /// Constructor
37    pub fn new(linesearch: L) -> Self {
38        SteepestDescent { linesearch }
39    }
40}
41
42impl<O, L, F> Solver<O> for SteepestDescent<L>
43where
44    O: ArgminOp<Output = F, Float = F>,
45    O::Param: Clone
46        + Default
47        + Serialize
48        + ArgminSub<O::Param, O::Param>
49        + ArgminDot<O::Param, O::Float>
50        + ArgminScaledAdd<O::Param, O::Float, O::Param>
51        + ArgminMul<O::Float, O::Param>
52        + ArgminSub<O::Param, O::Param>
53        + ArgminNorm<O::Float>,
54    O::Hessian: Default,
55    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
56    F: ArgminFloat,
57{
58    const NAME: &'static str = "Steepest Descent";
59
60    fn next_iter(
61        &mut self,
62        op: &mut OpWrapper<O>,
63        state: &IterState<O>,
64    ) -> Result<ArgminIterData<O>, Error> {
65        let param_new = state.get_param();
66        let new_cost = op.apply(&param_new)?;
67        let new_grad = op.gradient(&param_new)?;
68
69        self.linesearch
70            .set_search_direction(new_grad.mul(&(O::Float::from_f64(-1.0).unwrap())));
71
72        // Run solver
73        let ArgminResult {
74            operator: line_op,
75            state:
76                IterState {
77                    param: next_param,
78                    cost: next_cost,
79                    ..
80                },
81        } = Executor::new(
82            OpWrapper::new_from_wrapper(op),
83            self.linesearch.clone(),
84            param_new,
85        )
86        .grad(new_grad)
87        .cost(new_cost)
88        .ctrlc(false)
89        .run()?;
90
91        // Get back operator and function evaluation counts
92        op.consume_op(line_op);
93
94        Ok(ArgminIterData::new().param(next_param).cost(next_cost))
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    // NOTE(review): `test_trait_impl` is a crate-local macro not visible in
    // this file; presumably it generates tests asserting standard trait
    // implementations (e.g. Clone/Send/Sync/Serialize) for the concrete
    // `SteepestDescent<MoreThuenteLineSearch<...>>` type — confirm against
    // the macro definition.
    test_trait_impl!(
        steepest_descent,
        SteepestDescent<MoreThuenteLineSearch<Vec<f64>, f64>>
    );
}