argmin/solver/linesearch/
hagerzhang.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! * [Hager-Zhang line search](struct.HagerZhangLineSearch.html)
9//!
10//! TODO: Not all stopping criteria implemented
11//!
12//! # Reference
13//!
14//! William W. Hager and Hongchao Zhang. "A new conjugate gradient method with guaranteed descent
15//! and an efficient line search." SIAM J. Optim. 16(1), 2006, 170-192.
16//! DOI: https://doi.org/10.1137/030601880
17
18use crate::prelude::*;
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21use std::default::Default;
22
/// A point of the one-dimensional line-search function, used by `update` and `secant2` as
/// `(x, phi(x), phi'(x))`: step length, function value and slope at that step length.
type Triplet<F> = (F, F, F);
24
/// The Hager-Zhang line search is a method to find a step length which obeys the strong Wolfe
/// conditions.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/hagerzhang.rs)
///
/// # References
///
/// [0] William W. Hager and Hongchao Zhang. "A new conjugate gradient method with guaranteed
/// descent and an efficient line search." SIAM J. Optim. 16(1), 2006, 170-192.
/// DOI: https://doi.org/10.1137/030601880
#[derive(Serialize, Deserialize, Clone)]
pub struct HagerZhangLineSearch<P, F> {
    /// delta: (0, 0.5) according to the reference, used in the Wolfe conditions.
    /// Note that the `delta` setter only enforces the wider range (0, 1).
    delta: F,
    /// sigma: [delta, 1), used in the Wolfe conditions
    sigma: F,
    /// epsilon: [0, infinity), used in the approximate Wolfe termination
    epsilon: F,
    /// epsilon_k: computed in `init` as `epsilon * |finit|`
    epsilon_k: F,
    /// theta: (0, 1), used in the update rules when the potential intervals [a, c] or [c, b]
    /// violate the opposite slope condition
    theta: F,
    /// gamma: (0, 1), determines when a bisection step is performed
    gamma: F,
    /// eta: (0, infinity), used in the lower bound for beta_k^N
    eta: F,
    /// initial a (lower step length limit, configurable via `alpha`)
    a_x_init: F,
    /// a (lower end of the current bracket)
    a_x: F,
    /// phi(a)
    a_f: F,
    /// phi'(a)
    a_g: F,
    /// initial b (upper step length limit, configurable via `alpha`)
    b_x_init: F,
    /// b (upper end of the current bracket)
    b_x: F,
    /// phi(b)
    b_f: F,
    /// phi'(b)
    b_g: F,
    /// initial c (initial step length, configurable via `set_init_alpha`)
    c_x_init: F,
    /// c; assigned in `init` only, it is not updated by `next_iter`
    c_x: F,
    /// phi(c)
    c_f: F,
    /// phi'(c)
    c_g: F,
    /// best x (step length selected by `set_best`)
    best_x: F,
    /// best function value
    best_f: F,
    /// best slope
    best_g: F,
    /// Search direction (builder): stored by `set_search_direction` and copied into
    /// `search_direction` by `init`
    search_direction_b: Option<P>,
    /// initial parameter vector (taken from the solver state in `init`)
    init_param: P,
    /// initial cost, phi(0)
    finit: F,
    /// initial gradient (taken from the solver state or computed in `init`)
    init_grad: P,
    /// Search direction actually used during the line search
    search_direction: P,
    /// Initial slope phi'(0): dot product of the initial gradient and the search direction
    dginit: F,
}
95
96impl<P: Default, F: ArgminFloat> HagerZhangLineSearch<P, F> {
97    /// Constructor
98    pub fn new() -> Self {
99        HagerZhangLineSearch {
100            delta: F::from_f64(0.1).unwrap(),
101            sigma: F::from_f64(0.9).unwrap(),
102            epsilon: F::from_f64(1e-6).unwrap(),
103            epsilon_k: F::nan(),
104            theta: F::from_f64(0.5).unwrap(),
105            gamma: F::from_f64(0.66).unwrap(),
106            eta: F::from_f64(0.01).unwrap(),
107            a_x_init: F::epsilon(),
108            a_x: F::nan(),
109            a_f: F::nan(),
110            a_g: F::nan(),
111            b_x_init: F::from_f64(100.0).unwrap(),
112            b_x: F::nan(),
113            b_f: F::nan(),
114            b_g: F::nan(),
115            c_x_init: F::from_f64(1.0).unwrap(),
116            c_x: F::nan(),
117            c_f: F::nan(),
118            c_g: F::nan(),
119            best_x: F::from_f64(0.0).unwrap(),
120            best_f: F::infinity(),
121            best_g: F::nan(),
122            search_direction_b: None,
123            init_param: P::default(),
124            init_grad: P::default(),
125            search_direction: P::default(),
126            dginit: F::nan(),
127            finit: F::infinity(),
128        }
129    }
130}
131
132impl<P, F> HagerZhangLineSearch<P, F>
133where
134    P: Clone + Default + Serialize + DeserializeOwned + ArgminScaledAdd<P, F, P> + ArgminDot<P, F>,
135    F: ArgminFloat,
136{
137    /// set delta
138    pub fn delta(mut self, delta: F) -> Result<Self, Error> {
139        if delta <= F::from_f64(0.0).unwrap() {
140            return Err(ArgminError::InvalidParameter {
141                text: "HagerZhangLineSearch: delta must be > 0.0.".to_string(),
142            }
143            .into());
144        }
145        if delta >= F::from_f64(1.0).unwrap() {
146            return Err(ArgminError::InvalidParameter {
147                text: "HagerZhangLineSearch: delta must be < 1.0.".to_string(),
148            }
149            .into());
150        }
151        self.delta = delta;
152        Ok(self)
153    }
154
155    /// set sigma
156    pub fn sigma(mut self, sigma: F) -> Result<Self, Error> {
157        if sigma < self.delta {
158            return Err(ArgminError::InvalidParameter {
159                text: "HagerZhangLineSearch: sigma must be >= delta.".to_string(),
160            }
161            .into());
162        }
163        if sigma >= F::from_f64(1.0).unwrap() {
164            return Err(ArgminError::InvalidParameter {
165                text: "HagerZhangLineSearch: sigma must be < 1.0.".to_string(),
166            }
167            .into());
168        }
169        self.sigma = sigma;
170        Ok(self)
171    }
172
173    /// set epsilon
174    pub fn epsilon(mut self, epsilon: F) -> Result<Self, Error> {
175        if epsilon < F::from_f64(0.0).unwrap() {
176            return Err(ArgminError::InvalidParameter {
177                text: "HagerZhangLineSearch: epsilon must be >= 0.0.".to_string(),
178            }
179            .into());
180        }
181        self.epsilon = epsilon;
182        Ok(self)
183    }
184
185    /// set theta
186    pub fn theta(mut self, theta: F) -> Result<Self, Error> {
187        if theta <= F::from_f64(0.0).unwrap() {
188            return Err(ArgminError::InvalidParameter {
189                text: "HagerZhangLineSearch: theta must be > 0.0.".to_string(),
190            }
191            .into());
192        }
193        if theta >= F::from_f64(1.0).unwrap() {
194            return Err(ArgminError::InvalidParameter {
195                text: "HagerZhangLineSearch: theta must be < 1.0.".to_string(),
196            }
197            .into());
198        }
199        self.theta = theta;
200        Ok(self)
201    }
202
203    /// set gamma
204    pub fn gamma(mut self, gamma: F) -> Result<Self, Error> {
205        if gamma <= F::from_f64(0.0).unwrap() {
206            return Err(ArgminError::InvalidParameter {
207                text: "HagerZhangLineSearch: gamma must be > 0.0.".to_string(),
208            }
209            .into());
210        }
211        if gamma >= F::from_f64(1.0).unwrap() {
212            return Err(ArgminError::InvalidParameter {
213                text: "HagerZhangLineSearch: gamma must be < 1.0.".to_string(),
214            }
215            .into());
216        }
217        self.gamma = gamma;
218        Ok(self)
219    }
220
221    /// set eta
222    pub fn eta(mut self, eta: F) -> Result<Self, Error> {
223        if eta <= F::from_f64(0.0).unwrap() {
224            return Err(ArgminError::InvalidParameter {
225                text: "HagerZhangLineSearch: eta must be > 0.0.".to_string(),
226            }
227            .into());
228        }
229        self.eta = eta;
230        Ok(self)
231    }
232
233    /// set alpha limits
234    pub fn alpha(mut self, alpha_min: F, alpha_max: F) -> Result<Self, Error> {
235        if alpha_min < F::from_f64(0.0).unwrap() {
236            return Err(ArgminError::InvalidParameter {
237                text: "HagerZhangLineSearch: alpha_min must be >= 0.0.".to_string(),
238            }
239            .into());
240        }
241        if alpha_max <= alpha_min {
242            return Err(ArgminError::InvalidParameter {
243                text: "HagerZhangLineSearch: alpha_min must be smaller than alpha_max.".to_string(),
244            }
245            .into());
246        }
247        self.a_x_init = alpha_min;
248        self.b_x_init = alpha_max;
249        Ok(self)
250    }
251
252    fn update<O: ArgminOp<Param = P, Output = F>>(
253        &mut self,
254        op: &mut OpWrapper<O>,
255        (a_x, a_f, a_g): Triplet<F>,
256        (b_x, b_f, b_g): Triplet<F>,
257        (c_x, c_f, c_g): Triplet<F>,
258    ) -> Result<(Triplet<F>, Triplet<F>), Error> {
259        // U0
260        if c_x <= a_x || c_x >= b_x {
261            // nothing changes.
262            return Ok(((a_x, a_f, a_g), (b_x, b_f, b_g)));
263        }
264
265        // U1
266        if c_g >= F::from_f64(0.0).unwrap() {
267            return Ok(((a_x, a_f, a_g), (c_x, c_f, c_g)));
268        }
269
270        // U2
271        if c_g < F::from_f64(0.0).unwrap() && c_f <= self.finit + self.epsilon_k {
272            return Ok(((c_x, c_f, c_g), (b_x, b_f, b_g)));
273        }
274
275        // U3
276        if c_g < F::from_f64(0.0).unwrap() && c_f > self.finit + self.epsilon_k {
277            let mut ah_x = a_x;
278            let mut ah_f = a_f;
279            let mut ah_g = a_g;
280            let mut bh_x = c_x;
281            loop {
282                let d_x = (F::from_f64(1.0).unwrap() - self.theta) * ah_x + self.theta * bh_x;
283                let d_f = self.calc(op, d_x)?;
284                let d_g = self.calc_grad(op, d_x)?;
285                if d_g >= F::from_f64(0.0).unwrap() {
286                    return Ok(((ah_x, ah_f, ah_g), (d_x, d_f, d_g)));
287                }
288                if d_g < F::from_f64(0.0).unwrap() && d_f <= self.finit + self.epsilon_k {
289                    ah_x = d_x;
290                    ah_f = d_f;
291                    ah_g = d_g;
292                }
293                if d_g < F::from_f64(0.0).unwrap() && d_f > self.finit + self.epsilon_k {
294                    bh_x = d_x;
295                }
296            }
297        }
298
299        // return Ok(((a_x, a_f, a_g), (b_x, b_f, b_g)));
300        Err(ArgminError::InvalidParameter {
301            text: "HagerZhangLineSearch: Reached unreachable point in `update` method.".to_string(),
302        }
303        .into())
304    }
305
306    /// secant step
307    fn secant(&self, a_x: F, a_g: F, b_x: F, b_g: F) -> F {
308        (a_x * b_g - b_x * a_g) / (b_g - a_g)
309    }
310
311    /// double secant step
312    fn secant2<O: ArgminOp<Param = P, Output = F>>(
313        &mut self,
314        op: &mut OpWrapper<O>,
315        (a_x, a_f, a_g): Triplet<F>,
316        (b_x, b_f, b_g): Triplet<F>,
317    ) -> Result<(Triplet<F>, Triplet<F>), Error> {
318        // S1
319        let c_x = self.secant(a_x, a_g, b_x, b_g);
320        let c_f = self.calc(op, c_x)?;
321        let c_g = self.calc_grad(op, c_x)?;
322        let mut c_bar_x: F = F::from_f64(0.0).unwrap();
323
324        let ((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)) =
325            self.update(op, (a_x, a_f, a_g), (b_x, b_f, b_g), (c_x, c_f, c_g))?;
326
327        // S2
328        if (c_x - bb_x).abs() < F::epsilon() {
329            c_bar_x = self.secant(b_x, b_g, bb_x, bb_g);
330        }
331
332        // S3
333        if (c_x - aa_x).abs() < F::epsilon() {
334            c_bar_x = self.secant(a_x, a_g, aa_x, aa_g);
335        }
336
337        // S4
338        if (c_x - aa_x).abs() < F::epsilon() || (c_x - bb_x).abs() < F::epsilon() {
339            let c_bar_f = self.calc(op, c_bar_x)?;
340            let c_bar_g = self.calc_grad(op, c_bar_x)?;
341
342            let (a_bar, b_bar) = self.update(
343                op,
344                (aa_x, aa_f, aa_g),
345                (bb_x, bb_f, bb_g),
346                (c_bar_x, c_bar_f, c_bar_g),
347            )?;
348            Ok((a_bar, b_bar))
349        } else {
350            Ok(((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)))
351        }
352    }
353
354    fn calc<O: ArgminOp<Param = P, Output = F>>(
355        &mut self,
356        op: &mut OpWrapper<O>,
357        alpha: F,
358    ) -> Result<F, Error> {
359        let tmp = self.init_param.scaled_add(&alpha, &self.search_direction);
360        op.apply(&tmp)
361    }
362
363    fn calc_grad<O: ArgminOp<Param = P, Output = F>>(
364        &mut self,
365        op: &mut OpWrapper<O>,
366        alpha: F,
367    ) -> Result<F, Error> {
368        let tmp = self.init_param.scaled_add(&alpha, &self.search_direction);
369        let grad = op.gradient(&tmp)?;
370        Ok(self.search_direction.dot(&grad))
371    }
372
373    fn set_best(&mut self) {
374        if self.a_f < self.b_f && self.a_f < self.c_f {
375            self.best_x = self.a_x;
376            self.best_f = self.a_f;
377            self.best_g = self.a_g;
378        }
379
380        if self.b_f < self.a_f && self.b_f < self.c_f {
381            self.best_x = self.b_x;
382            self.best_f = self.b_f;
383            self.best_g = self.b_g;
384        }
385
386        if self.c_f < self.a_f && self.c_f < self.b_f {
387            self.best_x = self.c_x;
388            self.best_f = self.c_f;
389            self.best_g = self.c_g;
390        }
391    }
392}
393
394impl<P: Default, F: ArgminFloat> Default for HagerZhangLineSearch<P, F> {
395    fn default() -> Self {
396        HagerZhangLineSearch::new()
397    }
398}
399
400impl<P, F> ArgminLineSearch<P, F> for HagerZhangLineSearch<P, F>
401where
402    P: Clone
403        + Default
404        + Serialize
405        + ArgminSub<P, P>
406        + ArgminDot<P, f64>
407        + ArgminScaledAdd<P, f64, P>,
408    F: ArgminFloat,
409{
410    /// Set search direction
411    fn set_search_direction(&mut self, search_direction: P) {
412        self.search_direction_b = Some(search_direction);
413    }
414
415    /// Set initial alpha value
416    fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
417        self.c_x_init = alpha;
418        Ok(())
419    }
420}
421
422impl<P, O, F> Solver<O> for HagerZhangLineSearch<P, F>
423where
424    O: ArgminOp<Param = P, Output = F, Float = F>,
425    P: Clone
426        + Default
427        + Serialize
428        + DeserializeOwned
429        + ArgminSub<P, P>
430        + ArgminDot<P, F>
431        + ArgminScaledAdd<P, F, P>,
432    F: ArgminFloat,
433{
434    const NAME: &'static str = "Hager-Zhang Line search";
435
436    fn init(
437        &mut self,
438        op: &mut OpWrapper<O>,
439        state: &IterState<O>,
440    ) -> Result<Option<ArgminIterData<O>>, Error> {
441        if self.sigma < self.delta {
442            return Err(ArgminError::InvalidParameter {
443                text: "HagerZhangLineSearch: sigma must be >= delta.".to_string(),
444            }
445            .into());
446        }
447
448        self.search_direction = check_param!(
449            self.search_direction_b,
450            "HagerZhangLineSearch: Search direction not initialized. Call `set_search_direction`."
451        );
452
453        self.init_param = state.get_param();
454
455        let cost = state.get_cost();
456        self.finit = if cost.is_infinite() {
457            op.apply(&self.init_param)?
458        } else {
459            cost
460        };
461
462        self.init_grad = state.get_grad().unwrap_or(op.gradient(&self.init_param)?);
463
464        self.a_x = self.a_x_init;
465        self.b_x = self.b_x_init;
466        self.c_x = self.c_x_init;
467
468        let at = self.a_x;
469        self.a_f = self.calc(op, at)?;
470        self.a_g = self.calc_grad(op, at)?;
471        let bt = self.b_x;
472        self.b_f = self.calc(op, bt)?;
473        self.b_g = self.calc_grad(op, bt)?;
474        let ct = self.c_x;
475        self.c_f = self.calc(op, ct)?;
476        self.c_g = self.calc_grad(op, ct)?;
477
478        self.epsilon_k = self.epsilon * self.finit.abs();
479
480        self.dginit = self.init_grad.dot(&self.search_direction);
481
482        self.set_best();
483        let new_param = self
484            .init_param
485            .scaled_add(&self.best_x, &self.search_direction);
486        let best_f = self.best_f;
487
488        Ok(Some(ArgminIterData::new().param(new_param).cost(best_f)))
489    }
490
491    fn next_iter(
492        &mut self,
493        op: &mut OpWrapper<O>,
494        _state: &IterState<O>,
495    ) -> Result<ArgminIterData<O>, Error> {
496        // L1
497        let aa = (self.a_x, self.a_f, self.a_g);
498        let bb = (self.b_x, self.b_f, self.b_g);
499        let ((mut at_x, mut at_f, mut at_g), (mut bt_x, mut bt_f, mut bt_g)) =
500            self.secant2(op, aa, bb)?;
501
502        // L2
503        if bt_x - at_x > self.gamma * (self.b_x - self.a_x) {
504            let c_x = (at_x + bt_x) / F::from_f64(2.0).unwrap();
505            let tmp = self.init_param.scaled_add(&c_x, &self.search_direction);
506            let c_f = op.apply(&tmp)?;
507            let grad = op.gradient(&tmp)?;
508            let c_g = self.search_direction.dot(&grad);
509            let ((an_x, an_f, an_g), (bn_x, bn_f, bn_g)) =
510                self.update(op, (at_x, at_f, at_g), (bt_x, bt_f, bt_g), (c_x, c_f, c_g))?;
511            at_x = an_x;
512            at_f = an_f;
513            at_g = an_g;
514            bt_x = bn_x;
515            bt_f = bn_f;
516            bt_g = bn_g;
517        }
518
519        // L3
520        self.a_x = at_x;
521        self.a_f = at_f;
522        self.a_g = at_g;
523        self.b_x = bt_x;
524        self.b_f = bt_f;
525        self.b_g = bt_g;
526
527        self.set_best();
528        let new_param = self
529            .init_param
530            .scaled_add(&self.best_x, &self.search_direction);
531        Ok(ArgminIterData::new().param(new_param).cost(self.best_f))
532    }
533
534    fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
535        if self.best_f - self.finit < self.delta * self.best_x * self.dginit {
536            return TerminationReason::LineSearchConditionMet;
537        }
538        if self.best_g > self.sigma * self.dginit {
539            return TerminationReason::LineSearchConditionMet;
540        }
541        if (F::from_f64(2.0).unwrap() * self.delta - F::from_f64(1.0).unwrap()) * self.dginit
542            >= self.best_g
543            && self.best_g >= self.sigma * self.dginit
544            && self.best_f <= self.finit + self.epsilon_k
545        {
546            return TerminationReason::LineSearchConditionMet;
547        }
548        TerminationReason::NotTerminated
549    }
550}
551
// Unit tests. The only test expands the crate's `test_trait_impl!` macro for
// `HagerZhangLineSearch<MinimalNoOperator, f64>`. What it asserts is defined by that macro
// (presumably that the solver implements the traits expected of it; confirm at the macro
// definition).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::test_trait_impl;

    test_trait_impl!(hagerzhang, HagerZhangLineSearch<MinimalNoOperator, f64>);
}