argmin/solver/linesearch/backtracking.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! * [Backtracking line search](struct.BacktrackingLineSearch.html)

use crate::prelude::*;
use crate::solver::linesearch::condition::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

/// The Backtracking line search is a simple method to find a step length which obeys the Armijo
/// (sufficient decrease) condition.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/backtracking.rs)
///
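/// Starting from an initial step length `alpha`, the step length is repeatedly contracted via
/// `alpha <- rho * alpha` (with `rho` in `(0, 1)`) until the chosen line search condition is
/// satisfied. For the Armijo (sufficient decrease) condition with parameter `c`, this means
/// accepting the first `alpha` for which
///
/// `f(x + alpha * d) <= f(x) + c * alpha * d^T grad f(x)`
///
/// holds, where `d` denotes the search direction (see [0]).
///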
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
///
/// [1] Wikipedia: https://en.wikipedia.org/wiki/Backtracking_line_search
#[derive(Serialize, Deserialize, Clone)]
pub struct BacktrackingLineSearch<P, L, F> {
    /// Initial parameter vector
    init_param: P,
    /// Initial cost
    init_cost: F,
    /// Initial gradient
    init_grad: P,
    /// Search direction
    search_direction: Option<P>,
    /// Contraction factor rho
    rho: F,
    /// Line search condition which decides when a step length is accepted
    condition: Box<L>,
    /// Current step length alpha
    alpha: F,
}

impl<P: Default, L, F: ArgminFloat> BacktrackingLineSearch<P, L, F> {
    /// Constructor
    pub fn new(condition: L) -> Self {
        BacktrackingLineSearch {
            init_param: P::default(),
            init_cost: F::infinity(),
            init_grad: P::default(),
            search_direction: None,
            rho: F::from_f64(0.9).unwrap(),
            condition: Box::new(condition),
            alpha: F::from_f64(1.0).unwrap(),
        }
    }

    /// Set the contraction factor rho; it must lie in (0, 1)
    pub fn rho(mut self, rho: F) -> Result<Self, Error> {
        if rho <= F::from_f64(0.0).unwrap() || rho >= F::from_f64(1.0).unwrap() {
            return Err(ArgminError::InvalidParameter {
                text: "BacktrackingLineSearch: Contraction factor rho must be in (0, 1)."
                    .to_string(),
            }
            .into());
        }
        self.rho = rho;
        Ok(self)
    }
}

impl<P, L, F> ArgminLineSearch<P, F> for BacktrackingLineSearch<P, L, F>
where
    P: Clone + Serialize + ArgminSub<P, P> + ArgminDot<P, F> + ArgminScaledAdd<P, F, P>,
    L: LineSearchCondition<P, F>,
    F: ArgminFloat + Serialize + DeserializeOwned,
{
    /// Set search direction
    fn set_search_direction(&mut self, search_direction: P) {
        self.search_direction = Some(search_direction);
    }

    /// Set initial alpha value
    fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
        if alpha <= F::from_f64(0.0).unwrap() {
            return Err(ArgminError::InvalidParameter {
                text: "LineSearch: Initial alpha must be > 0.".to_string(),
            }
            .into());
        }
        self.alpha = alpha;
        Ok(())
    }
}

impl<O, P, L, F> Solver<O> for BacktrackingLineSearch<P, L, F>
where
    P: Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<P, P>
        + ArgminDot<P, F>
        + ArgminScaledAdd<P, F, P>,
    O: ArgminOp<Param = P, Output = F, Float = F>,
    L: LineSearchCondition<P, F>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Backtracking Line search";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        self.init_param = state.get_param();
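        // Reuse the cost stored in the state if it was already computed; otherwise evaluate
        // the cost function once at the initial parameter vector.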
        let cost = state.get_cost();
        self.init_cost = if cost == F::infinity() {
            op.apply(&self.init_param)?
        } else {
            cost
        };

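        // Reuse the gradient stored in the state when it is available; otherwise compute it
        // at the initial parameter vector.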
        self.init_grad = match state.get_grad() {
            Some(grad) => grad,
            None => op.gradient(&self.init_param)?,
        };

        if self.search_direction.is_none() {
            return Err(ArgminError::NotInitialized {
                text: "BacktrackingLineSearch: search_direction must be set.".to_string(),
            }
            .into());
        }

        Ok(None)
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        _state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
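        // Take a trial step from the initial parameter vector along the search direction:
        // new_param = init_param + alpha * search_direction.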
        let new_param = self
            .init_param
            .scaled_add(&self.alpha, self.search_direction.as_ref().unwrap());

        let cur_cost = op.apply(&new_param)?;

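        // Contract the step length for the next trial step.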
        self.alpha = self.alpha * self.rho;

        let mut out = ArgminIterData::new()
            .param(new_param.clone())
            .cost(cur_cost);

        if self.condition.requires_cur_grad() {
            out = out.grad(op.gradient(&new_param)?);
        }

        Ok(out)
    }

    fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
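        // The line search is finished as soon as the configured condition (e.g. Armijo) is
        // satisfied at the most recent trial point.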
        if self.condition.eval(
            state.get_cost(),
            state.get_grad().unwrap_or_default(),
            self.init_cost,
            self.init_grad.clone(),
            self.search_direction.clone().unwrap(),
            self.alpha,
        ) {
            TerminationReason::LineSearchConditionMet
        } else {
            TerminationReason::NotTerminated
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::test_trait_impl;

    test_trait_impl!(backtrackinglinesearch,
                    BacktrackingLineSearch<MinimalNoOperator, ArmijoCondition<f64>, f64>);
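
    // Sketch of a basic construction test; it assumes `Vec<f64>` is a supported parameter
    // type and exercises only the constructor, `rho`, and `set_init_alpha` checks above.
    #[test]
    fn parameter_validation() {
        let condition = ArmijoCondition::new(0.5f64).unwrap();
        let ls: BacktrackingLineSearch<Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(condition);

        // The contraction factor rho must lie strictly between 0 and 1.
        assert!(ls.clone().rho(0.5).is_ok());
        assert!(ls.clone().rho(0.0).is_err());
        assert!(ls.clone().rho(1.0).is_err());

        // The initial step length alpha must be positive.
        let mut ls = ls;
        assert!(ls.set_init_alpha(1.0).is_ok());
        assert!(ls.set_init_alpha(0.0).is_err());
    }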
}