argmin/solver/brent/mod.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! Brent's method
//!
//! A root-finding algorithm combining the bisection method, the secant method
//! and inverse quadratic interpolation. It has the reliability of bisection
//! but it can be as quick as some of the less-reliable methods.
//!
//! # References
//!
//! <https://en.wikipedia.org/wiki/Brent%27s_method>
//!
// Implementation of Brent's root-finding method,
// see https://en.wikipedia.org/wiki/Brent%27s_method
use crate::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Error to be thrown if Brent is initialized with improper parameters.
#[derive(Debug, Error)]
pub enum BrentError {
/// f(min) and f(max) must have different signs
#[error("Brent error: f(min) and f(max) must have different signs.")]
WrongSign,
}
/// Brent's method
///
/// A root-finding algorithm combining the bisection method, the secant method
/// and inverse quadratic interpolation. It has the reliability of bisection
/// but it can be as quick as some of the less-reliable methods.
///
/// The solver keeps all of its working state in the fields below; the current
/// root estimate reported after each iteration is always `b`.
///
/// # References
///
/// <https://en.wikipedia.org/wiki/Brent%27s_method>
#[derive(Clone, Serialize, Deserialize)]
pub struct Brent<F> {
    /// Absolute tolerance on the location of the root. Each iteration uses the
    /// effective tolerance `2 * eps * |b| + tol / 2`, where `eps` is machine
    /// epsilon, so `tol` is an absolute (not relative) accuracy target.
    tol: F,
    /// Previous best guess: the value `b` held before the most recent step.
    a: F,
    /// Current best guess of the root.
    b: F,
    /// Contrapoint: maintained by each iteration so that the root lies
    /// between `b` and `c` (their function values do not share a sign).
    c: F,
    /// Step most recently applied to `b` (unless it was smaller than the
    /// effective tolerance, in which case a tolerance-sized step is used).
    /// NaN until the first iteration sets it.
    d: F,
    /// Step length from the iteration before `d` (equal to `d` after a
    /// bisection step); used to decide whether interpolation is converging
    /// fast enough to be trusted. NaN until the first iteration sets it.
    e: F,
    /// Function value at `a`.
    fa: F,
    /// Function value at `b`.
    fb: F,
    /// Function value at `c`.
    fc: F,
}
impl<F: ArgminFloat> Brent<F> {
    /// Constructor
    ///
    /// The values `min` and `max` must bracket the root of the function, i.e.
    /// `f(min)` and `f(max)` must not have the same sign. This is not checked
    /// here; it is checked when the solver is initialized.
    ///
    /// The parameter `tol` is the absolute tolerance on the location of the
    /// root. During iteration it is combined with a machine-precision term
    /// proportional to the magnitude of the current estimate.
    pub fn new(min: F, max: F, tol: F) -> Brent<F> {
        Brent {
            tol,
            a: min,
            b: max,
            c: max,
            // Step history and function values are unknown until the function
            // has been evaluated; NaN marks them as not yet set.
            d: F::nan(),
            e: F::nan(),
            fa: F::nan(),
            fb: F::nan(),
            fc: F::nan(),
        }
    }
}
/// Returns `true` when `x` and `y` are both strictly positive or both strictly
/// negative. A zero or NaN in either argument yields `false`.
///
/// Comparing signs directly, instead of testing `x * y > 0`, avoids the product
/// underflowing to zero for very small values, which would hide a same-sign
/// pair.
fn same_sign<F: ArgminFloat>(x: F, y: F) -> bool {
    let zero = F::from_f64(0.0).unwrap();
    (x > zero && y > zero) || (x < zero && y < zero)
}

impl<O, F> Solver<O> for Brent<F>
where
    O: ArgminOp<Param = F, Output = F, Float = F>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Brent";

    /// Evaluates the function at both interval endpoints and checks that they
    /// bracket a root.
    ///
    /// Returns the initial estimate `b` (the `max` endpoint) with cost `|f(b)|`.
    ///
    /// # Errors
    ///
    /// Returns [`BrentError::WrongSign`] if `f(min)` and `f(max)` are both
    /// strictly positive or both strictly negative. Errors raised while
    /// evaluating the operator are propagated.
    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        // Brent maintains its own state
        _state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        self.fa = op.apply(&self.a)?;
        self.fb = op.apply(&self.b)?;
        if same_sign(self.fa, self.fb) {
            return Err(BrentError::WrongSign.into());
        }
        self.fc = self.fb;
        Ok(Some(
            ArgminIterData::new().param(self.b).cost(self.fb.abs()),
        ))
    }

    /// Performs one iteration of Brent's method.
    ///
    /// Returns the new estimate `b` with cost `|f(b)|`. Terminates with
    /// `TargetPrecisionReached` when half the bracket width `|c - b| / 2` is
    /// within the effective tolerance or `f(b)` is exactly zero.
    ///
    /// # Errors
    ///
    /// Errors raised while evaluating the operator are propagated.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        // Brent maintains its own state
        _state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let zero = F::from_f64(0.0).unwrap();
        let half = F::from_f64(0.5).unwrap();
        let one = F::from_f64(1.0).unwrap();
        let two = F::from_f64(2.0).unwrap();
        let three = F::from_f64(3.0).unwrap();

        // If `b` and `c` no longer bracket the root, use the previous iterate
        // `a` as the contrapoint and reset the step history.
        if same_sign(self.fb, self.fc) {
            self.c = self.a;
            self.fc = self.fa;
            self.d = self.b - self.a;
            self.e = self.d;
        }
        // Keep `b` as the endpoint with the smaller residual.
        if self.fc.abs() < self.fb.abs() {
            self.a = self.b;
            self.b = self.c;
            self.c = self.a;
            self.fa = self.fb;
            self.fb = self.fc;
            self.fc = self.fa;
        }
        // Effective tolerance: twice machine epsilon relative to |b| plus half
        // of the user-supplied absolute tolerance.
        let eff_tol = two * F::epsilon() * self.b.abs() + half * self.tol;
        let mid = half * (self.c - self.b);
        if mid.abs() <= eff_tol || self.fb == zero {
            return Ok(ArgminIterData::new()
                .termination_reason(TerminationReason::TargetPrecisionReached)
                .param(self.b)
                .cost(self.fb.abs()));
        }
        if self.e.abs() >= eff_tol && self.fa.abs() > self.fb.abs() {
            // Try interpolation: secant when only two distinct points are
            // available (`a == c`), inverse quadratic interpolation otherwise.
            let s = self.fb / self.fa;
            let (mut p, mut q) = if self.a == self.c {
                (two * mid * s, one - s)
            } else {
                let q = self.fa / self.fc;
                let r = self.fb / self.fc;
                (
                    s * (two * mid * q * (q - r) - (self.b - self.a) * (r - one)),
                    (q - one) * (r - one) * (s - one),
                )
            };
            // Normalize so that `p >= 0`; the proposed step is then `p / q`.
            if p > zero {
                q = -q;
            }
            p = p.abs();
            let min1 = three * mid * q - (eff_tol * q).abs();
            let min2 = (self.e * q).abs();
            let bound = if min1 < min2 { min1 } else { min2 };
            // Accept the interpolated step only if it stays well inside the
            // bracket and shrinks fast enough; otherwise fall back to bisection.
            if two * p < bound {
                self.e = self.d;
                self.d = p / q;
            } else {
                self.d = mid;
                self.e = self.d;
            }
        } else {
            // Previous steps were too small or the residual is not decreasing:
            // bisect.
            self.d = mid;
            self.e = self.d;
        }
        self.a = self.b;
        self.fa = self.fb;
        // Never move by less than the effective tolerance, stepping toward `c`.
        if self.d.abs() > eff_tol {
            self.b = self.b + self.d;
        } else {
            let min_step = if mid >= zero {
                eff_tol.abs()
            } else {
                -eff_tol.abs()
            };
            self.b = self.b + min_step;
        }
        self.fb = op.apply(&self.b)?;
        Ok(ArgminIterData::new().param(self.b).cost(self.fb.abs()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Expands the crate-wide `test_trait_impl!` macro for `Brent<f64>`; the
    // actual checks live in that macro's definition.
    crate::test_trait_impl!(brent, Brent<f64>);
}