argmin/solver/brent/
mod.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Brent's method
9//!
10//! A root-finding algorithm combining the bisection method, the secant method
11//! and inverse quadratic interpolation. It has the reliability of bisection
12//! but it can be as quick as some of the less-reliable methods.
13//!
14//! # References:
15//!
//! <https://en.wikipedia.org/wiki/Brent%27s_method>
17//!
18
// Implementation of Brent's root-finding method,
// see <https://en.wikipedia.org/wiki/Brent%27s_method>
21use crate::prelude::*;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25/// Error to be thrown if Brent is initialized with improper parameters.
26#[derive(Debug, Error)]
27pub enum BrentError {
28    /// f(min) and f(max) must have different signs
29    #[error("Brent error: f(min) and f(max) must have different signs.")]
30    WrongSign,
31}
32
33/// Brent's method
34///
35/// A root-finding algorithm combining the bisection method, the secant method
36/// and inverse quadratic interpolation. It has the reliability of bisection
37/// but it can be as quick as some of the less-reliable methods.
38///
39/// # References:
40/// https://en.wikipedia.org/wiki/Brent%27s_method
41#[derive(Clone, Serialize, Deserialize)]
42pub struct Brent<F> {
43    /// required relative accuracy
44    tol: F,
45    /// left or right boundary of current interval
46    a: F,
47    /// currently proposed best guess
48    b: F,
49    /// left or right boundary of current interval
50    c: F,
51    /// helper variable
52    d: F,
53    /// another helper variable
54    e: F,
55    /// function value at `a`
56    fa: F,
57    /// function value at `b`
58    fb: F,
59    /// function value at `c`
60    fc: F,
61}
62
63impl<F: ArgminFloat> Brent<F> {
64    /// Constructor
65    /// The values `min` and `max` must bracketing the root of the function.
66    /// The parameter `tol` specifies the relative error to be targeted.
67    pub fn new(min: F, max: F, tol: F) -> Brent<F> {
68        Brent {
69            tol: tol,
70            a: min,
71            b: max,
72            c: max,
73            d: F::nan(),
74            e: F::nan(),
75            fa: F::nan(),
76            fb: F::nan(),
77            fc: F::nan(),
78        }
79    }
80}
81
82impl<O, F> Solver<O> for Brent<F>
83where
84    O: ArgminOp<Param = F, Output = F, Float = F>,
85    F: ArgminFloat,
86{
87    const NAME: &'static str = "Brent";
88
89    fn init(
90        &mut self,
91        op: &mut OpWrapper<O>,
92        // Brent maintains its own state
93        _state: &IterState<O>,
94    ) -> Result<Option<ArgminIterData<O>>, Error> {
95        self.fa = op.apply(&self.a)?;
96        self.fb = op.apply(&self.b)?;
97        if self.fa * self.fb > F::from_f64(0.0).unwrap() {
98            return Err(BrentError::WrongSign.into());
99        }
100        self.fc = self.fb;
101        Ok(Some(
102            ArgminIterData::new().param(self.b).cost(self.fb.abs()),
103        ))
104    }
105
106    fn next_iter(
107        &mut self,
108        op: &mut OpWrapper<O>,
109        // Brent maintains its own state
110        _state: &IterState<O>,
111    ) -> Result<ArgminIterData<O>, Error> {
112        if (self.fb > F::from_f64(0.0).unwrap() && self.fc > F::from_f64(0.0).unwrap())
113            || self.fb < F::from_f64(0.0).unwrap() && self.fc < F::from_f64(0.0).unwrap()
114        {
115            self.c = self.a;
116            self.fc = self.fa;
117            self.d = self.b - self.a;
118            self.e = self.d;
119        }
120        if self.fc.abs() < self.fb.abs() {
121            self.a = self.b;
122            self.b = self.c;
123            self.c = self.a;
124            self.fa = self.fb;
125            self.fb = self.fc;
126            self.fc = self.fa;
127        }
128        // effective tolerance is double machine precision plus half tolerance as given.
129        let eff_tol = F::from_f64(2.0).unwrap() * F::epsilon() * self.b.abs()
130            + F::from_f64(0.5).unwrap() * self.tol;
131        let mid = F::from_f64(0.5).unwrap() * (self.c - self.b);
132        if mid.abs() <= eff_tol || self.fb == F::from_f64(0.0).unwrap() {
133            return Ok(ArgminIterData::new()
134                .termination_reason(TerminationReason::TargetPrecisionReached)
135                .param(self.b)
136                .cost(self.fb.abs()));
137        }
138        if self.e.abs() >= eff_tol && self.fa.abs() > self.fb.abs() {
139            let s = self.fb / self.fa;
140            let (mut p, mut q) = if self.a == self.c {
141                (
142                    F::from_f64(2.0).unwrap() * mid * s,
143                    F::from_f64(1.0).unwrap() - s,
144                )
145            } else {
146                let q = self.fa / self.fc;
147                let r = self.fb / self.fc;
148                (
149                    s * (F::from_f64(2.0).unwrap() * mid * q * (q - r)
150                        - (self.b - self.a) * (r - F::from_f64(1.0).unwrap())),
151                    (q - F::from_f64(1.0).unwrap())
152                        * (r - F::from_f64(1.0).unwrap())
153                        * (s - F::from_f64(1.0).unwrap()),
154                )
155            };
156            if p > F::from_f64(0.0).unwrap() {
157                q = -q;
158            }
159            p = p.abs();
160            let min1 = F::from_f64(3.0).unwrap() * mid * q - (eff_tol * q).abs();
161            let min2 = (self.e * q).abs();
162            if F::from_f64(2.0).unwrap() * p < if min1 < min2 { min1 } else { min2 } {
163                self.e = self.d;
164                self.d = p / q;
165            } else {
166                self.d = mid;
167                self.e = self.d;
168            };
169        } else {
170            self.d = mid;
171            self.e = self.d;
172        };
173        self.a = self.b;
174        self.fa = self.fb;
175        if self.d.abs() > eff_tol {
176            self.b = self.b + self.d;
177        } else {
178            self.b = self.b
179                + if mid >= F::from_f64(0.0).unwrap() {
180                    eff_tol.abs()
181                } else {
182                    -eff_tol.abs()
183                };
184        }
185
186        self.fb = op.apply(&self.b)?;
187        Ok(ArgminIterData::new().param(self.b).cost(self.fb.abs()))
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // Instantiates the crate-wide `test_trait_impl!` checks for `Brent<f64>`.
    // NOTE(review): presumably the first argument names the generated test and the macro
    // asserts common trait impls (e.g. Send/Sync/Clone); confirm against the macro
    // definition. No functional root-finding test exists in this module.
    test_trait_impl!(brent, Brent<f64>);
}