// argmin/solver/gradientdescent/steepestdescent.rs
// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! Steepest Descent method
//!
//! [SteepestDescent](struct.SteepestDescent.html)
//!
//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.
use crate::prelude::*;
use serde::{Deserialize, Serialize};
/// Steepest descent solver.
///
/// Every iteration takes a step in the direction of the strongest negative gradient; the
/// length of that step is obtained from a line search.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/steepestdescent.rs)
///
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct SteepestDescent<L> {
    /// Line search method used to obtain the step length in each iteration
    linesearch: L,
}
impl<L> SteepestDescent<L> {
/// Constructor
pub fn new(linesearch: L) -> Self {
SteepestDescent { linesearch }
}
}
/// `Solver` implementation: every iteration runs one complete line search along the negated
/// gradient, starting from the current parameter vector.
impl<O, L, F> Solver<O> for SteepestDescent<L>
where
    O: ArgminOp<Output = F, Float = F>,
    O::Param: Clone
        + Default
        + Serialize
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminScaledAdd<O::Param, O::Float, O::Param>
        + ArgminMul<O::Float, O::Param>
        + ArgminNorm<O::Float>,
    O::Hessian: Default,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Steepest Descent";

    /// Performs a single steepest descent iteration.
    ///
    /// The cost and the gradient are evaluated at the current parameter vector taken from
    /// `state`. The gradient multiplied by `-1` is handed to the line search as its search
    /// direction, and a clone of the line search is then run in a nested `Executor` that
    /// starts from the current parameter vector.
    ///
    /// Returns iteration data carrying the parameter vector and cost found by the line
    /// search.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cost evaluation, the gradient evaluation, or
    /// the nested line search run.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let param_new = state.get_param();
        let new_cost = op.apply(&param_new)?;
        let new_grad = op.gradient(&param_new)?;

        // The search direction is the negated gradient. `-1.0` is expected to be
        // representable in every floating point type, hence the `unwrap`.
        self.linesearch
            .set_search_direction(new_grad.mul(&(O::Float::from_f64(-1.0).unwrap())));

        // Run the line search as a nested solver, seeded with the cost and gradient which
        // were already computed above. Ctrl-C handling is switched off for the nested run.
        let ArgminResult {
            operator: line_op,
            state:
                IterState {
                    param: next_param,
                    cost: next_cost,
                    ..
                },
        } = Executor::new(
            OpWrapper::new_from_wrapper(op),
            self.linesearch.clone(),
            param_new,
        )
        .grad(new_grad)
        .cost(new_cost)
        .ctrlc(false)
        .run()?;

        // Get back operator and function evaluation counts
        op.consume_op(line_op);

        Ok(ArgminIterData::new().param(next_param).cost(next_cost))
    }
}
// Unit tests for the steepest descent solver; only compiled for test builds.
#[cfg(test)]
mod tests {
use super::*;
use crate::solver::linesearch::MoreThuenteLineSearch;
use crate::test_trait_impl;
// Instantiates the crate's `test_trait_impl!` macro for `SteepestDescent` paired with a
// More-Thuente line search over `Vec<f64>` parameters and `f64` floats.
// NOTE(review): presumably the macro generates a test named `steepest_descent` which checks
// that the type implements a set of expected traits -- confirm against the macro definition.
test_trait_impl!(
steepest_descent,
SteepestDescent<MoreThuenteLineSearch<Vec<f64>, f64>>
);
}