1use crate::prelude::*;
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21use std::default::Default;
22
/// Three floats `(x, f, g)`. Throughout this file a triplet holds a step length `x` together with
/// the cost `f` and the directional derivative `g` evaluated at that step length.
type Triplet<F> = (F, F, F);
24
/// Configuration and working state of the Hager-Zhang line search.
///
/// `P` is the parameter (vector) type and `F` the floating point type. The search keeps a bracket
/// `[a, b]` of step lengths along `search_direction`, starting from `init_param`, plus a trial
/// point `c`. For each of them the step length (`*_x`), the cost (`*_f`) and the directional
/// derivative (`*_g`) are stored.
#[derive(Serialize, Deserialize, Clone)]
pub struct HagerZhangLineSearch<P, F> {
    /// Factor of the sufficient-decrease test in `terminate`; the setter enforces `0 < delta < 1`.
    delta: F,
    /// Factor of the curvature test in `terminate`; the setter enforces `delta <= sigma < 1`.
    sigma: F,
    /// Relative tolerance; multiplied with `|finit|` in `init` to obtain `epsilon_k`.
    epsilon: F,
    /// Absolute cost tolerance `epsilon * |finit|`, computed in `init`.
    epsilon_k: F,
    /// Weight used in `update` to place the interior trial point
    /// `(1 - theta) * a + theta * b`; the setter enforces `0 < theta < 1`.
    theta: F,
    /// If the bracket produced by `secant2` is wider than `gamma` times the previous bracket,
    /// `next_iter` adds a midpoint step; the setter enforces `0 < gamma < 1`.
    gamma: F,
    /// Must be `> 0`. It is stored but not read by any code in this file.
    /// NOTE(review): presumably consumed by a surrounding conjugate-gradient method -- confirm.
    eta: F,
    /// Initial lower end of the step length bracket (set via `alpha`).
    a_x_init: F,
    /// Current lower end of the bracket.
    a_x: F,
    /// Cost at `a_x`.
    a_f: F,
    /// Directional derivative at `a_x`.
    a_g: F,
    /// Initial upper end of the step length bracket (set via `alpha`).
    b_x_init: F,
    /// Current upper end of the bracket.
    b_x: F,
    /// Cost at `b_x`.
    b_f: F,
    /// Directional derivative at `b_x`.
    b_g: F,
    /// Initial trial step length (overridden by `set_init_alpha`).
    c_x_init: F,
    /// Trial step length; assigned in `init`.
    c_x: F,
    /// Cost at `c_x`; assigned in `init`.
    c_f: F,
    /// Directional derivative at `c_x`; assigned in `init`.
    c_g: F,
    /// Step length of the best point selected by `set_best`.
    best_x: F,
    /// Cost at `best_x`.
    best_f: F,
    /// Directional derivative at `best_x`.
    best_g: F,
    /// Search direction as supplied through `set_search_direction`; `None` until it is set.
    search_direction_b: Option<P>,
    /// Parameter vector the line search starts from (taken from the solver state in `init`).
    init_param: P,
    /// Cost at `init_param`.
    finit: F,
    /// Gradient at `init_param`.
    init_grad: P,
    /// Search direction used during the run (copied from `search_direction_b` in `init`).
    search_direction: P,
    /// Dot product of `init_grad` and `search_direction`.
    dginit: F,
}
95
96impl<P: Default, F: ArgminFloat> HagerZhangLineSearch<P, F> {
97 pub fn new() -> Self {
99 HagerZhangLineSearch {
100 delta: F::from_f64(0.1).unwrap(),
101 sigma: F::from_f64(0.9).unwrap(),
102 epsilon: F::from_f64(1e-6).unwrap(),
103 epsilon_k: F::nan(),
104 theta: F::from_f64(0.5).unwrap(),
105 gamma: F::from_f64(0.66).unwrap(),
106 eta: F::from_f64(0.01).unwrap(),
107 a_x_init: F::epsilon(),
108 a_x: F::nan(),
109 a_f: F::nan(),
110 a_g: F::nan(),
111 b_x_init: F::from_f64(100.0).unwrap(),
112 b_x: F::nan(),
113 b_f: F::nan(),
114 b_g: F::nan(),
115 c_x_init: F::from_f64(1.0).unwrap(),
116 c_x: F::nan(),
117 c_f: F::nan(),
118 c_g: F::nan(),
119 best_x: F::from_f64(0.0).unwrap(),
120 best_f: F::infinity(),
121 best_g: F::nan(),
122 search_direction_b: None,
123 init_param: P::default(),
124 init_grad: P::default(),
125 search_direction: P::default(),
126 dginit: F::nan(),
127 finit: F::infinity(),
128 }
129 }
130}
131
132impl<P, F> HagerZhangLineSearch<P, F>
133where
134 P: Clone + Default + Serialize + DeserializeOwned + ArgminScaledAdd<P, F, P> + ArgminDot<P, F>,
135 F: ArgminFloat,
136{
137 pub fn delta(mut self, delta: F) -> Result<Self, Error> {
139 if delta <= F::from_f64(0.0).unwrap() {
140 return Err(ArgminError::InvalidParameter {
141 text: "HagerZhangLineSearch: delta must be > 0.0.".to_string(),
142 }
143 .into());
144 }
145 if delta >= F::from_f64(1.0).unwrap() {
146 return Err(ArgminError::InvalidParameter {
147 text: "HagerZhangLineSearch: delta must be < 1.0.".to_string(),
148 }
149 .into());
150 }
151 self.delta = delta;
152 Ok(self)
153 }
154
155 pub fn sigma(mut self, sigma: F) -> Result<Self, Error> {
157 if sigma < self.delta {
158 return Err(ArgminError::InvalidParameter {
159 text: "HagerZhangLineSearch: sigma must be >= delta.".to_string(),
160 }
161 .into());
162 }
163 if sigma >= F::from_f64(1.0).unwrap() {
164 return Err(ArgminError::InvalidParameter {
165 text: "HagerZhangLineSearch: sigma must be < 1.0.".to_string(),
166 }
167 .into());
168 }
169 self.sigma = sigma;
170 Ok(self)
171 }
172
173 pub fn epsilon(mut self, epsilon: F) -> Result<Self, Error> {
175 if epsilon < F::from_f64(0.0).unwrap() {
176 return Err(ArgminError::InvalidParameter {
177 text: "HagerZhangLineSearch: epsilon must be >= 0.0.".to_string(),
178 }
179 .into());
180 }
181 self.epsilon = epsilon;
182 Ok(self)
183 }
184
185 pub fn theta(mut self, theta: F) -> Result<Self, Error> {
187 if theta <= F::from_f64(0.0).unwrap() {
188 return Err(ArgminError::InvalidParameter {
189 text: "HagerZhangLineSearch: theta must be > 0.0.".to_string(),
190 }
191 .into());
192 }
193 if theta >= F::from_f64(1.0).unwrap() {
194 return Err(ArgminError::InvalidParameter {
195 text: "HagerZhangLineSearch: theta must be < 1.0.".to_string(),
196 }
197 .into());
198 }
199 self.theta = theta;
200 Ok(self)
201 }
202
203 pub fn gamma(mut self, gamma: F) -> Result<Self, Error> {
205 if gamma <= F::from_f64(0.0).unwrap() {
206 return Err(ArgminError::InvalidParameter {
207 text: "HagerZhangLineSearch: gamma must be > 0.0.".to_string(),
208 }
209 .into());
210 }
211 if gamma >= F::from_f64(1.0).unwrap() {
212 return Err(ArgminError::InvalidParameter {
213 text: "HagerZhangLineSearch: gamma must be < 1.0.".to_string(),
214 }
215 .into());
216 }
217 self.gamma = gamma;
218 Ok(self)
219 }
220
221 pub fn eta(mut self, eta: F) -> Result<Self, Error> {
223 if eta <= F::from_f64(0.0).unwrap() {
224 return Err(ArgminError::InvalidParameter {
225 text: "HagerZhangLineSearch: eta must be > 0.0.".to_string(),
226 }
227 .into());
228 }
229 self.eta = eta;
230 Ok(self)
231 }
232
233 pub fn alpha(mut self, alpha_min: F, alpha_max: F) -> Result<Self, Error> {
235 if alpha_min < F::from_f64(0.0).unwrap() {
236 return Err(ArgminError::InvalidParameter {
237 text: "HagerZhangLineSearch: alpha_min must be >= 0.0.".to_string(),
238 }
239 .into());
240 }
241 if alpha_max <= alpha_min {
242 return Err(ArgminError::InvalidParameter {
243 text: "HagerZhangLineSearch: alpha_min must be smaller than alpha_max.".to_string(),
244 }
245 .into());
246 }
247 self.a_x_init = alpha_min;
248 self.b_x_init = alpha_max;
249 Ok(self)
250 }
251
252 fn update<O: ArgminOp<Param = P, Output = F>>(
253 &mut self,
254 op: &mut OpWrapper<O>,
255 (a_x, a_f, a_g): Triplet<F>,
256 (b_x, b_f, b_g): Triplet<F>,
257 (c_x, c_f, c_g): Triplet<F>,
258 ) -> Result<(Triplet<F>, Triplet<F>), Error> {
259 if c_x <= a_x || c_x >= b_x {
261 return Ok(((a_x, a_f, a_g), (b_x, b_f, b_g)));
263 }
264
265 if c_g >= F::from_f64(0.0).unwrap() {
267 return Ok(((a_x, a_f, a_g), (c_x, c_f, c_g)));
268 }
269
270 if c_g < F::from_f64(0.0).unwrap() && c_f <= self.finit + self.epsilon_k {
272 return Ok(((c_x, c_f, c_g), (b_x, b_f, b_g)));
273 }
274
275 if c_g < F::from_f64(0.0).unwrap() && c_f > self.finit + self.epsilon_k {
277 let mut ah_x = a_x;
278 let mut ah_f = a_f;
279 let mut ah_g = a_g;
280 let mut bh_x = c_x;
281 loop {
282 let d_x = (F::from_f64(1.0).unwrap() - self.theta) * ah_x + self.theta * bh_x;
283 let d_f = self.calc(op, d_x)?;
284 let d_g = self.calc_grad(op, d_x)?;
285 if d_g >= F::from_f64(0.0).unwrap() {
286 return Ok(((ah_x, ah_f, ah_g), (d_x, d_f, d_g)));
287 }
288 if d_g < F::from_f64(0.0).unwrap() && d_f <= self.finit + self.epsilon_k {
289 ah_x = d_x;
290 ah_f = d_f;
291 ah_g = d_g;
292 }
293 if d_g < F::from_f64(0.0).unwrap() && d_f > self.finit + self.epsilon_k {
294 bh_x = d_x;
295 }
296 }
297 }
298
299 Err(ArgminError::InvalidParameter {
301 text: "HagerZhangLineSearch: Reached unreachable point in `update` method.".to_string(),
302 }
303 .into())
304 }
305
306 fn secant(&self, a_x: F, a_g: F, b_x: F, b_g: F) -> F {
308 (a_x * b_g - b_x * a_g) / (b_g - a_g)
309 }
310
311 fn secant2<O: ArgminOp<Param = P, Output = F>>(
313 &mut self,
314 op: &mut OpWrapper<O>,
315 (a_x, a_f, a_g): Triplet<F>,
316 (b_x, b_f, b_g): Triplet<F>,
317 ) -> Result<(Triplet<F>, Triplet<F>), Error> {
318 let c_x = self.secant(a_x, a_g, b_x, b_g);
320 let c_f = self.calc(op, c_x)?;
321 let c_g = self.calc_grad(op, c_x)?;
322 let mut c_bar_x: F = F::from_f64(0.0).unwrap();
323
324 let ((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)) =
325 self.update(op, (a_x, a_f, a_g), (b_x, b_f, b_g), (c_x, c_f, c_g))?;
326
327 if (c_x - bb_x).abs() < F::epsilon() {
329 c_bar_x = self.secant(b_x, b_g, bb_x, bb_g);
330 }
331
332 if (c_x - aa_x).abs() < F::epsilon() {
334 c_bar_x = self.secant(a_x, a_g, aa_x, aa_g);
335 }
336
337 if (c_x - aa_x).abs() < F::epsilon() || (c_x - bb_x).abs() < F::epsilon() {
339 let c_bar_f = self.calc(op, c_bar_x)?;
340 let c_bar_g = self.calc_grad(op, c_bar_x)?;
341
342 let (a_bar, b_bar) = self.update(
343 op,
344 (aa_x, aa_f, aa_g),
345 (bb_x, bb_f, bb_g),
346 (c_bar_x, c_bar_f, c_bar_g),
347 )?;
348 Ok((a_bar, b_bar))
349 } else {
350 Ok(((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)))
351 }
352 }
353
354 fn calc<O: ArgminOp<Param = P, Output = F>>(
355 &mut self,
356 op: &mut OpWrapper<O>,
357 alpha: F,
358 ) -> Result<F, Error> {
359 let tmp = self.init_param.scaled_add(&alpha, &self.search_direction);
360 op.apply(&tmp)
361 }
362
363 fn calc_grad<O: ArgminOp<Param = P, Output = F>>(
364 &mut self,
365 op: &mut OpWrapper<O>,
366 alpha: F,
367 ) -> Result<F, Error> {
368 let tmp = self.init_param.scaled_add(&alpha, &self.search_direction);
369 let grad = op.gradient(&tmp)?;
370 Ok(self.search_direction.dot(&grad))
371 }
372
373 fn set_best(&mut self) {
374 if self.a_f < self.b_f && self.a_f < self.c_f {
375 self.best_x = self.a_x;
376 self.best_f = self.a_f;
377 self.best_g = self.a_g;
378 }
379
380 if self.b_f < self.a_f && self.b_f < self.c_f {
381 self.best_x = self.b_x;
382 self.best_f = self.b_f;
383 self.best_g = self.b_g;
384 }
385
386 if self.c_f < self.a_f && self.c_f < self.b_f {
387 self.best_x = self.c_x;
388 self.best_f = self.c_f;
389 self.best_g = self.c_g;
390 }
391 }
392}
393
394impl<P: Default, F: ArgminFloat> Default for HagerZhangLineSearch<P, F> {
395 fn default() -> Self {
396 HagerZhangLineSearch::new()
397 }
398}
399
400impl<P, F> ArgminLineSearch<P, F> for HagerZhangLineSearch<P, F>
401where
402 P: Clone
403 + Default
404 + Serialize
405 + ArgminSub<P, P>
406 + ArgminDot<P, f64>
407 + ArgminScaledAdd<P, f64, P>,
408 F: ArgminFloat,
409{
410 fn set_search_direction(&mut self, search_direction: P) {
412 self.search_direction_b = Some(search_direction);
413 }
414
415 fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
417 self.c_x_init = alpha;
418 Ok(())
419 }
420}
421
422impl<P, O, F> Solver<O> for HagerZhangLineSearch<P, F>
423where
424 O: ArgminOp<Param = P, Output = F, Float = F>,
425 P: Clone
426 + Default
427 + Serialize
428 + DeserializeOwned
429 + ArgminSub<P, P>
430 + ArgminDot<P, F>
431 + ArgminScaledAdd<P, F, P>,
432 F: ArgminFloat,
433{
434 const NAME: &'static str = "Hager-Zhang Line search";
435
436 fn init(
437 &mut self,
438 op: &mut OpWrapper<O>,
439 state: &IterState<O>,
440 ) -> Result<Option<ArgminIterData<O>>, Error> {
441 if self.sigma < self.delta {
442 return Err(ArgminError::InvalidParameter {
443 text: "HagerZhangLineSearch: sigma must be >= delta.".to_string(),
444 }
445 .into());
446 }
447
448 self.search_direction = check_param!(
449 self.search_direction_b,
450 "HagerZhangLineSearch: Search direction not initialized. Call `set_search_direction`."
451 );
452
453 self.init_param = state.get_param();
454
455 let cost = state.get_cost();
456 self.finit = if cost.is_infinite() {
457 op.apply(&self.init_param)?
458 } else {
459 cost
460 };
461
462 self.init_grad = state.get_grad().unwrap_or(op.gradient(&self.init_param)?);
463
464 self.a_x = self.a_x_init;
465 self.b_x = self.b_x_init;
466 self.c_x = self.c_x_init;
467
468 let at = self.a_x;
469 self.a_f = self.calc(op, at)?;
470 self.a_g = self.calc_grad(op, at)?;
471 let bt = self.b_x;
472 self.b_f = self.calc(op, bt)?;
473 self.b_g = self.calc_grad(op, bt)?;
474 let ct = self.c_x;
475 self.c_f = self.calc(op, ct)?;
476 self.c_g = self.calc_grad(op, ct)?;
477
478 self.epsilon_k = self.epsilon * self.finit.abs();
479
480 self.dginit = self.init_grad.dot(&self.search_direction);
481
482 self.set_best();
483 let new_param = self
484 .init_param
485 .scaled_add(&self.best_x, &self.search_direction);
486 let best_f = self.best_f;
487
488 Ok(Some(ArgminIterData::new().param(new_param).cost(best_f)))
489 }
490
491 fn next_iter(
492 &mut self,
493 op: &mut OpWrapper<O>,
494 _state: &IterState<O>,
495 ) -> Result<ArgminIterData<O>, Error> {
496 let aa = (self.a_x, self.a_f, self.a_g);
498 let bb = (self.b_x, self.b_f, self.b_g);
499 let ((mut at_x, mut at_f, mut at_g), (mut bt_x, mut bt_f, mut bt_g)) =
500 self.secant2(op, aa, bb)?;
501
502 if bt_x - at_x > self.gamma * (self.b_x - self.a_x) {
504 let c_x = (at_x + bt_x) / F::from_f64(2.0).unwrap();
505 let tmp = self.init_param.scaled_add(&c_x, &self.search_direction);
506 let c_f = op.apply(&tmp)?;
507 let grad = op.gradient(&tmp)?;
508 let c_g = self.search_direction.dot(&grad);
509 let ((an_x, an_f, an_g), (bn_x, bn_f, bn_g)) =
510 self.update(op, (at_x, at_f, at_g), (bt_x, bt_f, bt_g), (c_x, c_f, c_g))?;
511 at_x = an_x;
512 at_f = an_f;
513 at_g = an_g;
514 bt_x = bn_x;
515 bt_f = bn_f;
516 bt_g = bn_g;
517 }
518
519 self.a_x = at_x;
521 self.a_f = at_f;
522 self.a_g = at_g;
523 self.b_x = bt_x;
524 self.b_f = bt_f;
525 self.b_g = bt_g;
526
527 self.set_best();
528 let new_param = self
529 .init_param
530 .scaled_add(&self.best_x, &self.search_direction);
531 Ok(ArgminIterData::new().param(new_param).cost(self.best_f))
532 }
533
534 fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
535 if self.best_f - self.finit < self.delta * self.best_x * self.dginit {
536 return TerminationReason::LineSearchConditionMet;
537 }
538 if self.best_g > self.sigma * self.dginit {
539 return TerminationReason::LineSearchConditionMet;
540 }
541 if (F::from_f64(2.0).unwrap() * self.delta - F::from_f64(1.0).unwrap()) * self.dginit
542 >= self.best_g
543 && self.best_g >= self.sigma * self.dginit
544 && self.best_f <= self.finit + self.epsilon_k
545 {
546 return TerminationReason::LineSearchConditionMet;
547 }
548 TerminationReason::NotTerminated
549 }
550}
551
// Unit tests for the Hager-Zhang line search; only compiled for test builds.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::test_trait_impl;

    // Expands the crate's `test_trait_impl!` macro for
    // `HagerZhangLineSearch<MinimalNoOperator, f64>` under the name `hagerzhang`.
    // NOTE(review): the macro is defined elsewhere; presumably it asserts that the type
    // implements a set of required traits -- confirm against the macro definition.
    test_trait_impl!(hagerzhang, HagerZhangLineSearch<MinimalNoOperator, f64>);
}