argmin/solver/linesearch/
condition.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::core::{ArgminDot, ArgminError, ArgminFloat, Error};
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15
16/// Needs to be implemented by everything that wants to be a LineSearchCondition
17pub trait LineSearchCondition<T, F>: Serialize {
18    /// Evaluate the condition
19    fn eval(
20        &self,
21        cur_cost: F,
22        cur_grad: T,
23        init_cost: F,
24        init_grad: T,
25        search_direction: T,
26        alpha: F,
27    ) -> bool;
28
29    /// Indicates whether this condition requires the computation of the gradient at the new point
30    fn requires_cur_grad(&self) -> bool;
31}
32
33/// Armijo Condition
34#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
35pub struct ArmijoCondition<F> {
36    c: F,
37}
38
39impl<F: ArgminFloat> ArmijoCondition<F> {
40    /// Constructor
41    pub fn new(c: F) -> Result<Self, Error> {
42        if c <= F::from_f64(0.0).unwrap() || c >= F::from_f64(1.0).unwrap() {
43            return Err(ArgminError::InvalidParameter {
44                text: "ArmijoCondition: Parameter c must be in (0, 1)".to_string(),
45            }
46            .into());
47        }
48        Ok(ArmijoCondition { c })
49    }
50}
51
52impl<T, F> LineSearchCondition<T, F> for ArmijoCondition<F>
53where
54    T: ArgminDot<T, F>,
55    F: ArgminFloat + Serialize + DeserializeOwned,
56{
57    fn eval(
58        &self,
59        cur_cost: F,
60        _cur_grad: T,
61        init_cost: F,
62        init_grad: T,
63        search_direction: T,
64        alpha: F,
65    ) -> bool {
66        cur_cost <= init_cost + self.c * alpha * init_grad.dot(&search_direction)
67    }
68
69    fn requires_cur_grad(&self) -> bool {
70        false
71    }
72}
73
74/// Wolfe Condition
75#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
76pub struct WolfeCondition<F> {
77    c1: F,
78    c2: F,
79}
80
81impl<F: ArgminFloat> WolfeCondition<F> {
82    /// Constructor
83    pub fn new(c1: F, c2: F) -> Result<Self, Error> {
84        if c1 <= F::from_f64(0.0).unwrap() || c1 >= F::from_f64(1.0).unwrap() {
85            return Err(ArgminError::InvalidParameter {
86                text: "WolfeCondition: Parameter c1 must be in (0, 1)".to_string(),
87            }
88            .into());
89        }
90        if c2 <= c1 || c2 >= F::from_f64(1.0).unwrap() {
91            return Err(ArgminError::InvalidParameter {
92                text: "WolfeCondition: Parameter c2 must be in (c1, 1)".to_string(),
93            }
94            .into());
95        }
96        Ok(WolfeCondition { c1, c2 })
97    }
98}
99
100impl<T, F> LineSearchCondition<T, F> for WolfeCondition<F>
101where
102    T: Clone + ArgminDot<T, F>,
103    F: ArgminFloat + DeserializeOwned + Serialize,
104{
105    fn eval(
106        &self,
107        cur_cost: F,
108        cur_grad: T,
109        init_cost: F,
110        init_grad: T,
111        search_direction: T,
112        alpha: F,
113    ) -> bool {
114        let tmp = init_grad.dot(&search_direction);
115        (cur_cost <= init_cost + self.c1 * alpha * tmp)
116            && cur_grad.dot(&search_direction) >= self.c2 * tmp
117    }
118
119    fn requires_cur_grad(&self) -> bool {
120        true
121    }
122}
123
124/// Strong Wolfe conditions
125#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
126pub struct StrongWolfeCondition<F> {
127    c1: F,
128    c2: F,
129}
130
131impl<F: ArgminFloat> StrongWolfeCondition<F> {
132    /// Constructor
133    pub fn new(c1: F, c2: F) -> Result<Self, Error> {
134        if c1 <= F::from_f64(0.0).unwrap() || c1 >= F::from_f64(1.0).unwrap() {
135            return Err(ArgminError::InvalidParameter {
136                text: "StrongWolfeCondition: Parameter c1 must be in (0, 1)".to_string(),
137            }
138            .into());
139        }
140        if c2 <= c1 || c2 >= F::from_f64(1.0).unwrap() {
141            return Err(ArgminError::InvalidParameter {
142                text: "StrongWolfeCondition: Parameter c2 must be in (c1, 1)".to_string(),
143            }
144            .into());
145        }
146        Ok(StrongWolfeCondition { c1, c2 })
147    }
148}
149
150impl<T, F> LineSearchCondition<T, F> for StrongWolfeCondition<F>
151where
152    T: Clone + ArgminDot<T, F>,
153    F: ArgminFloat + Serialize + DeserializeOwned,
154{
155    fn eval(
156        &self,
157        cur_cost: F,
158        cur_grad: T,
159        init_cost: F,
160        init_grad: T,
161        search_direction: T,
162        alpha: F,
163    ) -> bool {
164        let tmp = init_grad.dot(&search_direction);
165        (cur_cost <= init_cost + self.c1 * alpha * tmp)
166            && cur_grad.dot(&search_direction).abs() <= self.c2 * tmp.abs()
167    }
168
169    fn requires_cur_grad(&self) -> bool {
170        true
171    }
172}
173
174/// Goldstein conditions
175#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
176pub struct GoldsteinCondition<F> {
177    c: F,
178}
179
180impl<F: ArgminFloat> GoldsteinCondition<F> {
181    /// Constructor
182    pub fn new(c: F) -> Result<Self, Error> {
183        if c <= F::from_f64(0.0).unwrap() || c >= F::from_f64(0.5).unwrap() {
184            return Err(ArgminError::InvalidParameter {
185                text: "GoldsteinCondition: Parameter c must be in (0, 0.5)".to_string(),
186            }
187            .into());
188        }
189        Ok(GoldsteinCondition { c })
190    }
191}
192
193impl<T, F> LineSearchCondition<T, F> for GoldsteinCondition<F>
194where
195    T: ArgminDot<T, F>,
196    F: ArgminFloat + Serialize + DeserializeOwned,
197{
198    fn eval(
199        &self,
200        cur_cost: F,
201        _cur_grad: T,
202        init_cost: F,
203        init_grad: T,
204        search_direction: T,
205        alpha: F,
206    ) -> bool {
207        let tmp = alpha * init_grad.dot(&search_direction);
208        init_cost + (F::from_f64(1.0).unwrap() - self.c) * tmp <= cur_cost
209            && cur_cost <= init_cost + self.c * alpha * tmp
210    }
211
212    fn requires_cur_grad(&self) -> bool {
213        false
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // One `test_trait_impl!` invocation (macro defined elsewhere in the crate) per condition
    // type, listed in the order the types are declared in this module.
    test_trait_impl!(armijo, ArmijoCondition<f64>);
    test_trait_impl!(wolfe, WolfeCondition<f64>);
    test_trait_impl!(strongwolfe, StrongWolfeCondition<f64>);
    test_trait_impl!(goldstein, GoldsteinCondition<f64>);
}