argmin/solver/linesearch/
condition.rs1use crate::core::{ArgminDot, ArgminError, ArgminFloat, Error};
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15
16pub trait LineSearchCondition<T, F>: Serialize {
18 fn eval(
20 &self,
21 cur_cost: F,
22 cur_grad: T,
23 init_cost: F,
24 init_grad: T,
25 search_direction: T,
26 alpha: F,
27 ) -> bool;
28
29 fn requires_cur_grad(&self) -> bool;
31}
32
33#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
35pub struct ArmijoCondition<F> {
36 c: F,
37}
38
39impl<F: ArgminFloat> ArmijoCondition<F> {
40 pub fn new(c: F) -> Result<Self, Error> {
42 if c <= F::from_f64(0.0).unwrap() || c >= F::from_f64(1.0).unwrap() {
43 return Err(ArgminError::InvalidParameter {
44 text: "ArmijoCondition: Parameter c must be in (0, 1)".to_string(),
45 }
46 .into());
47 }
48 Ok(ArmijoCondition { c })
49 }
50}
51
52impl<T, F> LineSearchCondition<T, F> for ArmijoCondition<F>
53where
54 T: ArgminDot<T, F>,
55 F: ArgminFloat + Serialize + DeserializeOwned,
56{
57 fn eval(
58 &self,
59 cur_cost: F,
60 _cur_grad: T,
61 init_cost: F,
62 init_grad: T,
63 search_direction: T,
64 alpha: F,
65 ) -> bool {
66 cur_cost <= init_cost + self.c * alpha * init_grad.dot(&search_direction)
67 }
68
69 fn requires_cur_grad(&self) -> bool {
70 false
71 }
72}
73
74#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
76pub struct WolfeCondition<F> {
77 c1: F,
78 c2: F,
79}
80
81impl<F: ArgminFloat> WolfeCondition<F> {
82 pub fn new(c1: F, c2: F) -> Result<Self, Error> {
84 if c1 <= F::from_f64(0.0).unwrap() || c1 >= F::from_f64(1.0).unwrap() {
85 return Err(ArgminError::InvalidParameter {
86 text: "WolfeCondition: Parameter c1 must be in (0, 1)".to_string(),
87 }
88 .into());
89 }
90 if c2 <= c1 || c2 >= F::from_f64(1.0).unwrap() {
91 return Err(ArgminError::InvalidParameter {
92 text: "WolfeCondition: Parameter c2 must be in (c1, 1)".to_string(),
93 }
94 .into());
95 }
96 Ok(WolfeCondition { c1, c2 })
97 }
98}
99
100impl<T, F> LineSearchCondition<T, F> for WolfeCondition<F>
101where
102 T: Clone + ArgminDot<T, F>,
103 F: ArgminFloat + DeserializeOwned + Serialize,
104{
105 fn eval(
106 &self,
107 cur_cost: F,
108 cur_grad: T,
109 init_cost: F,
110 init_grad: T,
111 search_direction: T,
112 alpha: F,
113 ) -> bool {
114 let tmp = init_grad.dot(&search_direction);
115 (cur_cost <= init_cost + self.c1 * alpha * tmp)
116 && cur_grad.dot(&search_direction) >= self.c2 * tmp
117 }
118
119 fn requires_cur_grad(&self) -> bool {
120 true
121 }
122}
123
124#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
126pub struct StrongWolfeCondition<F> {
127 c1: F,
128 c2: F,
129}
130
131impl<F: ArgminFloat> StrongWolfeCondition<F> {
132 pub fn new(c1: F, c2: F) -> Result<Self, Error> {
134 if c1 <= F::from_f64(0.0).unwrap() || c1 >= F::from_f64(1.0).unwrap() {
135 return Err(ArgminError::InvalidParameter {
136 text: "StrongWolfeCondition: Parameter c1 must be in (0, 1)".to_string(),
137 }
138 .into());
139 }
140 if c2 <= c1 || c2 >= F::from_f64(1.0).unwrap() {
141 return Err(ArgminError::InvalidParameter {
142 text: "StrongWolfeCondition: Parameter c2 must be in (c1, 1)".to_string(),
143 }
144 .into());
145 }
146 Ok(StrongWolfeCondition { c1, c2 })
147 }
148}
149
150impl<T, F> LineSearchCondition<T, F> for StrongWolfeCondition<F>
151where
152 T: Clone + ArgminDot<T, F>,
153 F: ArgminFloat + Serialize + DeserializeOwned,
154{
155 fn eval(
156 &self,
157 cur_cost: F,
158 cur_grad: T,
159 init_cost: F,
160 init_grad: T,
161 search_direction: T,
162 alpha: F,
163 ) -> bool {
164 let tmp = init_grad.dot(&search_direction);
165 (cur_cost <= init_cost + self.c1 * alpha * tmp)
166 && cur_grad.dot(&search_direction).abs() <= self.c2 * tmp.abs()
167 }
168
169 fn requires_cur_grad(&self) -> bool {
170 true
171 }
172}
173
174#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
176pub struct GoldsteinCondition<F> {
177 c: F,
178}
179
180impl<F: ArgminFloat> GoldsteinCondition<F> {
181 pub fn new(c: F) -> Result<Self, Error> {
183 if c <= F::from_f64(0.0).unwrap() || c >= F::from_f64(0.5).unwrap() {
184 return Err(ArgminError::InvalidParameter {
185 text: "GoldsteinCondition: Parameter c must be in (0, 0.5)".to_string(),
186 }
187 .into());
188 }
189 Ok(GoldsteinCondition { c })
190 }
191}
192
193impl<T, F> LineSearchCondition<T, F> for GoldsteinCondition<F>
194where
195 T: ArgminDot<T, F>,
196 F: ArgminFloat + Serialize + DeserializeOwned,
197{
198 fn eval(
199 &self,
200 cur_cost: F,
201 _cur_grad: T,
202 init_cost: F,
203 init_grad: T,
204 search_direction: T,
205 alpha: F,
206 ) -> bool {
207 let tmp = alpha * init_grad.dot(&search_direction);
208 init_cost + (F::from_f64(1.0).unwrap() - self.c) * tmp <= cur_cost
209 && cur_cost <= init_cost + self.c * alpha * tmp
210 }
211
212 fn requires_cur_grad(&self) -> bool {
213 false
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // One `test_trait_impl!` instantiation per condition type, listed in the order the types
    // are declared in this file. The first argument is the name passed to the macro, the
    // second the concrete type under test.
    test_trait_impl!(armijo, ArmijoCondition<f64>);
    test_trait_impl!(wolfe, WolfeCondition<f64>);
    test_trait_impl!(strongwolfe, StrongWolfeCondition<f64>);
    test_trait_impl!(goldstein, GoldsteinCondition<f64>);
}