argmin/solver/linesearch/
morethuente.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! * [More-Thuente line search](struct.MoreThuenteLineSearch.html)
9//!
10//! TODO: Apparently it is missing stopping criteria!
11//!
12//! This implementation follows the excellent MATLAB implementation of Dianne P. O'Leary at
13//! http://www.cs.umd.edu/users/oleary/software/
14//!
15//! # Reference
16//!
17//! Jorge J. More and David J. Thuente. "Line search algorithms with guaranteed sufficient
18//! decrease." ACM Trans. Math. Softw. 20, 3 (September 1994), 286-307.
19//! DOI: https://doi.org/10.1145/192115.192132
20
21use crate::prelude::*;
22use serde::de::DeserializeOwned;
23use serde::{Deserialize, Serialize};
24use std::default::Default;
25
26/// The More-Thuente line search is a method to find a step length which obeys the strong Wolfe
27/// conditions.
28///
29/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/morethuente.rs)
30///
31/// # References
32///
33/// This implementation follows the excellent MATLAB implementation of Dianne P. O'Leary at
34/// http://www.cs.umd.edu/users/oleary/software/
35///
36/// [0] Jorge J. More and David J. Thuente. "Line search algorithms with guaranteed sufficient
37/// decrease." ACM Trans. Math. Softw. 20, 3 (September 1994), 286-307.
38/// DOI: https://doi.org/10.1145/192115.192132
39#[derive(Serialize, Deserialize, Clone)]
40pub struct MoreThuenteLineSearch<P, F> {
41    /// Search direction (builder)
42    search_direction_b: Option<P>,
43    /// initial parameter vector
44    init_param: P,
45    /// initial cost
46    finit: F,
47    /// initial gradient
48    init_grad: P,
49    /// Search direction
50    search_direction: P,
51    /// Search direction in 1D
52    dginit: F,
53    /// dgtest
54    dgtest: F,
55    /// c1
56    ftol: F,
57    /// c2
58    gtol: F,
59    /// xtrapf?
60    xtrapf: F,
61    /// width of interval
62    width: F,
63    /// width of what?
64    width1: F,
65    /// xtol
66    xtol: F,
67    /// alpha
68    alpha: F,
69    /// stpmin
70    stpmin: F,
71    /// stpmax
72    stpmax: F,
73    /// current step
74    stp: Step<F>,
75    /// stx
76    stx: Step<F>,
77    /// sty
78    sty: Step<F>,
79    /// f
80    f: F,
81    /// bracketed
82    brackt: bool,
83    /// stage1
84    stage1: bool,
85    /// infoc
86    infoc: usize,
87}
88
89#[derive(Clone, Serialize, Deserialize)]
90struct Step<F> {
91    pub x: F,
92    pub fx: F,
93    pub gx: F,
94}
95
96impl<F> Step<F> {
97    pub fn new(x: F, fx: F, gx: F) -> Self {
98        Step { x, fx, gx }
99    }
100}
101
102impl<F: ArgminFloat> Default for Step<F> {
103    fn default() -> Self {
104        Step {
105            x: F::from_f64(0.0).unwrap(),
106            fx: F::from_f64(0.0).unwrap(),
107            gx: F::from_f64(0.0).unwrap(),
108        }
109    }
110}
111
112impl<P: Default, F: ArgminFloat> MoreThuenteLineSearch<P, F> {
113    /// Constructor
114    pub fn new() -> Self {
115        MoreThuenteLineSearch {
116            search_direction_b: None,
117            init_param: P::default(),
118            finit: F::infinity(),
119            init_grad: P::default(),
120            search_direction: P::default(),
121            dginit: F::from_f64(0.0).unwrap(),
122            dgtest: F::from_f64(0.0).unwrap(),
123            ftol: F::from_f64(1e-4).unwrap(),
124            gtol: F::from_f64(0.9).unwrap(),
125            xtrapf: F::from_f64(4.0).unwrap(),
126            width: F::nan(),
127            width1: F::nan(),
128            xtol: F::from_f64(1e-10).unwrap(),
129            alpha: F::from_f64(1.0).unwrap(),
130            stpmin: F::epsilon().sqrt(),
131            stpmax: F::infinity(),
132            stp: Step::default(),
133            stx: Step::default(),
134            sty: Step::default(),
135            f: F::nan(),
136            brackt: false,
137            stage1: true,
138            infoc: 1,
139        }
140    }
141
142    /// Set c1 and c2 where 0 < c1 < c2 < 1.
143    pub fn c(mut self, c1: F, c2: F) -> Result<Self, Error> {
144        if c1 <= F::from_f64(0.0).unwrap() || c1 >= c2 {
145            return Err(ArgminError::InvalidParameter {
146                text: "MoreThuenteLineSearch: Parameter c1 must be in (0, c2).".to_string(),
147            }
148            .into());
149        }
150        if c2 <= c1 || c2 >= F::from_f64(1.0).unwrap() {
151            return Err(ArgminError::InvalidParameter {
152                text: "MoreThuenteLineSearch: Parameter c2 must be in (c1, 1).".to_string(),
153            }
154            .into());
155        }
156        self.ftol = c1;
157        self.gtol = c2;
158        Ok(self)
159    }
160
161    /// set alpha limits
162    pub fn alpha(mut self, alpha_min: F, alpha_max: F) -> Result<Self, Error> {
163        if alpha_min < F::from_f64(0.0).unwrap() {
164            return Err(ArgminError::InvalidParameter {
165                text: "MoreThuenteLineSearch: alpha_min must be >= 0.0.".to_string(),
166            }
167            .into());
168        }
169        if alpha_max <= alpha_min {
170            return Err(ArgminError::InvalidParameter {
171                text: "MoreThuenteLineSearch: alpha_min must be smaller than alpha_max."
172                    .to_string(),
173            }
174            .into());
175        }
176        self.stpmin = alpha_min;
177        self.stpmax = alpha_max;
178        Ok(self)
179    }
180}
181
182impl<P: Default, F: ArgminFloat> Default for MoreThuenteLineSearch<P, F> {
183    fn default() -> Self {
184        MoreThuenteLineSearch::new()
185    }
186}
187
188impl<P, F> ArgminLineSearch<P, F> for MoreThuenteLineSearch<P, F>
189where
190    P: Clone + Serialize + ArgminSub<P, P> + ArgminDot<P, F> + ArgminScaledAdd<P, F, P>,
191    F: ArgminFloat,
192{
193    /// Set search direction
194    fn set_search_direction(&mut self, search_direction: P) {
195        self.search_direction_b = Some(search_direction);
196    }
197
198    /// Set initial alpha value
199    fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
200        if alpha <= F::from_f64(0.0).unwrap() {
201            return Err(ArgminError::InvalidParameter {
202                text: "MoreThuenteLineSearch: Initial alpha must be > 0.".to_string(),
203            }
204            .into());
205        }
206        self.alpha = alpha;
207        Ok(())
208    }
209}
210
211impl<P, O, F> Solver<O> for MoreThuenteLineSearch<P, F>
212where
213    O: ArgminOp<Param = P, Output = F, Float = F>,
214    P: Clone
215        + Serialize
216        + DeserializeOwned
217        + ArgminSub<P, P>
218        + ArgminDot<P, F>
219        + ArgminScaledAdd<P, F, P>,
220    F: ArgminFloat,
221{
222    const NAME: &'static str = "More-Thuente Line search";
223
224    fn init(
225        &mut self,
226        op: &mut OpWrapper<O>,
227        state: &IterState<O>,
228    ) -> Result<Option<ArgminIterData<O>>, Error> {
229        self.search_direction = check_param!(
230            self.search_direction_b,
231            "MoreThuenteLineSearch: Search direction not initialized. Call `set_search_direction`."
232        );
233
234        self.init_param = state.get_param();
235
236        let cost = state.get_cost();
237        self.finit = if cost.is_infinite() {
238            op.apply(&self.init_param)?
239        } else {
240            cost
241        };
242
243        self.init_grad = state
244            .get_grad()
245            .unwrap_or_else(|| op.gradient(&self.init_param).unwrap());
246
247        self.dginit = self.init_grad.dot(&self.search_direction);
248
249        // compute search direction in 1D
250        if self.dginit >= F::from_f64(0.0).unwrap() {
251            return Err(ArgminError::ConditionViolated {
252                text: "MoreThuenteLineSearch: Search direction must be a descent direction."
253                    .to_string(),
254            }
255            .into());
256        }
257
258        self.stage1 = true;
259        self.brackt = false;
260
261        self.dgtest = self.ftol * self.dginit;
262        self.width = self.stpmax - self.stpmin;
263        self.width1 = F::from_f64(2.0).unwrap() * self.width;
264        self.f = self.finit;
265
266        self.stp = Step::new(self.alpha, F::nan(), F::nan());
267        self.stx = Step::new(F::from_f64(0.0).unwrap(), self.finit, self.dginit);
268        self.sty = Step::new(F::from_f64(0.0).unwrap(), self.finit, self.dginit);
269
270        Ok(None)
271    }
272
273    fn next_iter(
274        &mut self,
275        op: &mut OpWrapper<O>,
276        _state: &IterState<O>,
277    ) -> Result<ArgminIterData<O>, Error> {
278        // set the minimum and maximum steps to correspond to the present interval of uncertainty
279        let mut info = 0;
280        let (stmin, stmax) = if self.brackt {
281            (self.stx.x.min(self.sty.x), self.stx.x.max(self.sty.x))
282        } else {
283            (
284                self.stx.x,
285                self.stp.x + self.xtrapf * (self.stp.x - self.stx.x),
286            )
287        };
288
289        // alpha needs to be within bounds
290        self.stp.x = self.stp.x.max(self.stpmin);
291        self.stp.x = self.stp.x.min(self.stpmax);
292
293        // If an unusual termination is to occur then let alpha be the lowest point obtained so
294        // far.
295        if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax))
296            || (self.brackt && (stmax - stmin) <= self.xtol * stmax)
297            || self.infoc == 0
298        {
299            self.stp.x = self.stx.x;
300        }
301
302        // Evaluate the function and gradient at new stp.x and compute the directional derivative
303        let new_param = self
304            .init_param
305            .scaled_add(&self.stp.x, &self.search_direction);
306        self.f = op.apply(&new_param)?;
307        let new_grad = op.gradient(&new_param)?;
308        let cur_cost = self.f;
309        let cur_param = new_param;
310        let cur_grad = new_grad.clone();
311        // self.stx.fx = new_cost;
312        let dg = self.search_direction.dot(&new_grad);
313        let ftest1 = self.finit + self.stp.x * self.dgtest;
314        // self.stp.fx = new_cost;
315        // self.stp.gx = dg;
316
317        if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax)) || self.infoc == 0 {
318            info = 6;
319        }
320
321        if (self.stp.x - self.stpmax).abs() < F::epsilon() && self.f <= ftest1 && dg <= self.dgtest
322        {
323            info = 5;
324        }
325
326        if (self.stp.x - self.stpmin).abs() < F::epsilon() && (self.f > ftest1 || dg >= self.dgtest)
327        {
328            info = 4;
329        }
330
331        if self.brackt && stmax - stmin <= self.xtol * stmax {
332            info = 2;
333        }
334
335        if self.f <= ftest1 && dg.abs() <= self.gtol * (-self.dginit) {
336            info = 1;
337        }
338
339        if info != 0 {
340            return Ok(ArgminIterData::new()
341                .param(cur_param)
342                .cost(cur_cost)
343                .grad(cur_grad)
344                .termination_reason(TerminationReason::LineSearchConditionMet));
345        }
346
347        if self.stage1 && self.f <= ftest1 && dg >= self.ftol.min(self.gtol) * self.dginit {
348            self.stage1 = false;
349        }
350
351        if self.stage1 && self.f <= self.stp.fx && self.f > ftest1 {
352            let fm = self.f - self.stp.x * self.dgtest;
353            let fxm = self.stx.fx - self.stx.x * self.dgtest;
354            let fym = self.sty.fx - self.sty.x * self.dgtest;
355            let dgm = dg - self.dgtest;
356            let dgxm = self.stx.gx - self.dgtest;
357            let dgym = self.sty.gx - self.dgtest;
358
359            let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
360                Step::new(self.stx.x, fxm, dgxm),
361                Step::new(self.sty.x, fym, dgym),
362                Step::new(self.stp.x, fm, dgm),
363                self.brackt,
364                stmin,
365                stmax,
366            )?;
367
368            self.stx.x = stx1.x;
369            self.sty.x = sty1.x;
370            self.stp.x = stp1.x;
371            self.stx.fx = self.stx.fx + stx1.x * self.dgtest;
372            self.sty.fx = self.sty.fx + sty1.x * self.dgtest;
373            self.stx.gx = self.stx.gx + self.dgtest;
374            self.sty.gx = self.sty.gx + self.dgtest;
375            self.brackt = brackt1;
376            self.stp = stp1;
377            self.infoc = infoc;
378        } else {
379            let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
380                self.stx.clone(),
381                self.sty.clone(),
382                Step::new(self.stp.x, self.f, dg),
383                self.brackt,
384                stmin,
385                stmax,
386            )?;
387            self.stx = stx1;
388            self.sty = sty1;
389            self.stp = stp1;
390            self.f = self.stp.fx;
391            // dg = self.stp.gx;
392            self.brackt = brackt1;
393            self.infoc = infoc;
394        }
395
396        if self.brackt {
397            if (self.sty.x - self.stx.x).abs() >= F::from_f64(0.66).unwrap() * self.width1 {
398                self.stp.x = self.stx.x + F::from_f64(0.5).unwrap() * (self.sty.x - self.stx.x);
399            }
400            self.width1 = self.width;
401            self.width = (self.sty.x - self.stx.x).abs();
402        }
403
404        // let new_param = self
405        //     .init_param
406        //     .scaled_add(&self.stp.x, &self.search_direction);
407        // Ok(ArgminIterData::new().param(new_param))
408        Ok(ArgminIterData::new())
409    }
410}
411
412fn cstep<F: ArgminFloat>(
413    stx: Step<F>,
414    sty: Step<F>,
415    stp: Step<F>,
416    brackt: bool,
417    stpmin: F,
418    stpmax: F,
419) -> Result<(Step<F>, Step<F>, Step<F>, bool, F, F, usize), Error> {
420    let mut info: usize = 0;
421    let bound: bool;
422    let mut stpf: F;
423    let stpc: F;
424    let stpq: F;
425    let mut brackt = brackt;
426
427    // check inputs
428    if (brackt && (stp.x <= stx.x.min(sty.x) || stp.x >= stx.x.max(sty.x)))
429        || stx.gx * (stp.x - stx.x) >= F::from_f64(0.0).unwrap()
430        || stpmax < stpmin
431    {
432        return Ok((stx, sty, stp, brackt, stpmin, stpmax, info));
433    }
434
435    // determine if the derivatives have opposite sign
436    let sgnd = stp.gx * (stx.gx / stx.gx.abs());
437
438    if stp.fx > stx.fx {
439        // First case. A higher function value. The minimum is bracketed. If the cubic step is closer to
440        // stx.x than the quadratic step, the cubic step is taken, else the average of the cubic and
441        // the quadratic steps is taken.
442        info = 1;
443        bound = true;
444        let theta =
445            F::from_f64(3.0).unwrap() * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
446        let tmp = vec![theta, stx.gx, stp.gx];
447        // Check for a NaN or Inf in tmp before sorting
448        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
449            return Err(ArgminError::ConditionViolated {
450                text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration".to_string(),
451            }
452            .into());
453        }
454        let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
455        let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
456        if stp.x < stx.x {
457            gamma = -gamma;
458        }
459
460        let p = (gamma - stx.gx) + theta;
461        let q = ((gamma - stx.gx) + gamma) + stp.gx;
462        let r = p / q;
463        stpc = stx.x + r * (stp.x - stx.x);
464        stpq = stx.x
465            + ((stx.gx / ((stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx))
466                / F::from_f64(2.0).unwrap())
467                * (stp.x - stx.x);
468        if (stpc - stx.x).abs() < (stpq - stx.x).abs() {
469            stpf = stpc;
470        } else {
471            stpf = stpc + (stpq - stpc) / F::from_f64(2.0).unwrap();
472        }
473        brackt = true;
474    } else if sgnd < F::from_f64(0.0).unwrap() {
475        // Second case. A lower function value and derivatives of opposite sign. The minimum is
476        // bracketed. If the cubic step is closer to stx.x than the quadtratic (secant) step, the
477        // cubic step is taken, else the quadratic step is taken.
478        info = 2;
479        bound = false;
480        let theta =
481            F::from_f64(3.0).unwrap() * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
482        let tmp = vec![theta, stx.gx, stp.gx];
483        // Check for a NaN or Inf in tmp before sorting
484        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
485            return Err(ArgminError::ConditionViolated {
486                text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration".to_string(),
487            }
488            .into());
489        }
490        let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
491        let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
492        if stp.x > stx.x {
493            gamma = -gamma;
494        }
495        let p = (gamma - stp.gx) + theta;
496        let q = ((gamma - stp.gx) + gamma) + stx.gx;
497        let r = p / q;
498        stpc = stp.x + r * (stx.x - stp.x);
499        stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
500        if (stpc - stp.x).abs() > (stpq - stp.x).abs() {
501            stpf = stpc;
502        } else {
503            stpf = stpq;
504        }
505        brackt = true;
506    } else if stp.gx.abs() < stx.gx.abs() {
507        // Third case. A lower function value, derivatives of the same sign, and the magnitude of
508        // the derivative decreases. The cubic step is only used if the cubic tends to infinity in
509        // the direction of the step or if the minimum of the cubic is beyond stp.x. Otherwise the
510        // cubic step is defined to be either stpmin or stpmax. The quadtratic (secant) step is
511        // also computed and if the minimum is bracketed then the step closest to stx.x is taken,
512        // else the step farthest away is taken.
513        info = 3;
514        bound = true;
515        let theta =
516            F::from_f64(3.0).unwrap() * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
517        let tmp = vec![theta, stx.gx, stp.gx];
518        // Check for a NaN or Inf in tmp before sorting
519        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
520            return Err(ArgminError::ConditionViolated {
521                text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration".to_string(),
522            }
523            .into());
524        }
525        let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
526        // the case gamma == 0 only arises if the cubic does not tend to infinity in the direction
527        // of the step.
528
529        let mut gamma = *s
530            * F::from_f64(0.0)
531                .unwrap()
532                .max((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s))
533                .sqrt();
534        if stp.x > stx.x {
535            gamma = -gamma;
536        }
537
538        let p = (gamma - stp.gx) + theta;
539        let q = (gamma + (stx.gx - stp.gx)) + gamma;
540        let r = p / q;
541        if r < F::from_f64(0.0).unwrap() && gamma != F::from_f64(0.0).unwrap() {
542            stpc = stp.x + r * (stx.x - stp.x);
543        } else if stp.x > stx.x {
544            stpc = stpmax;
545        } else {
546            stpc = stpmin;
547        }
548        stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
549        if brackt {
550            if (stp.x - stpc).abs() < (stp.x - stpq).abs() {
551                stpf = stpc;
552            } else {
553                stpf = stpq;
554            }
555        } else if (stp.x - stpc).abs() > (stp.x - stpq).abs() {
556            stpf = stpc;
557        } else {
558            stpf = stpq;
559        }
560    } else {
561        // Fourth case. A lower function value, derivatives of the same sign, and the magnitued of
562        // the derivative does not decrease. If the minimum is not bracketed, the step is either
563        // stpmin or stpmax, else the cubic step is taken.
564        info = 4;
565        bound = false;
566        if brackt {
567            let theta =
568                F::from_f64(3.0).unwrap() * (stp.fx - sty.fx) / (sty.x - stp.x) + sty.gx + stp.gx;
569            let tmp = vec![theta, sty.gx, stp.gx];
570            // Check for a NaN or Inf in tmp before sorting
571            if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
572                return Err(ArgminError::ConditionViolated {
573                    text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
574                        .to_string(),
575                }
576                .into());
577            }
578            let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
579            let mut gamma = *s * ((theta / *s).powi(2) - (sty.gx / *s) * (stp.gx / *s)).sqrt();
580            if stp.x > sty.x {
581                gamma = -gamma;
582            }
583            let p = (gamma - stp.gx) + theta;
584            let q = ((gamma - stp.gx) + gamma) + sty.gx;
585            let r = p / q;
586            stpc = stp.x + r * (sty.x - stp.x);
587            stpf = stpc;
588        } else if stp.x > stx.x {
589            stpf = stpmax;
590        } else {
591            stpf = stpmin;
592        }
593    }
594    // Update the interval of uncertainty. This update does not depend on the new step or the case
595    // analysis above.
596
597    let mut stx_o = stx.clone();
598    let mut sty_o = sty.clone();
599    let mut stp_o = stp.clone();
600    if stp_o.fx > stx_o.fx {
601        sty_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
602    } else {
603        if sgnd < F::from_f64(0.0).unwrap() {
604            sty_o = Step::new(stx_o.x, stx_o.fx, stx_o.gx);
605        }
606        stx_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
607    }
608
609    // compute the new step and safeguard it.
610
611    stpf = stpmax.min(stpf);
612    stpf = stpmin.max(stpf);
613
614    stp_o.x = stpf;
615    if brackt && bound {
616        if sty_o.x > stx_o.x {
617            stp_o.x = stp_o
618                .x
619                .min(stx_o.x + F::from_f64(0.66).unwrap() * (sty_o.x - stx_o.x));
620        } else {
621            stp_o.x = stp_o
622                .x
623                .max(stx_o.x + F::from_f64(0.66).unwrap() * (sty_o.x - stx_o.x));
624        }
625    }
626
627    Ok((stx_o, sty_o, stp_o, brackt, stpmin, stpmax, info))
628}
629
// Unit tests; only compiled for test builds.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::test_trait_impl;

    // Invokes the crate's `test_trait_impl!` macro for the solver instantiated with
    // `MinimalNoOperator` and `f64`.
    // NOTE(review): presumably this asserts the usual auto/marker traits -- confirm against the
    // macro definition.
    test_trait_impl!(morethuente, MoreThuenteLineSearch<MinimalNoOperator, f64>);
}