argmin/solver/brent/
mod.rs1use crate::prelude::*;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25#[derive(Debug, Error)]
27pub enum BrentError {
28 #[error("Brent error: f(min) and f(max) must have different signs.")]
30 WrongSign,
31}
32
33#[derive(Clone, Serialize, Deserialize)]
42pub struct Brent<F> {
43 tol: F,
45 a: F,
47 b: F,
49 c: F,
51 d: F,
53 e: F,
55 fa: F,
57 fb: F,
59 fc: F,
61}
62
63impl<F: ArgminFloat> Brent<F> {
64 pub fn new(min: F, max: F, tol: F) -> Brent<F> {
68 Brent {
69 tol: tol,
70 a: min,
71 b: max,
72 c: max,
73 d: F::nan(),
74 e: F::nan(),
75 fa: F::nan(),
76 fb: F::nan(),
77 fc: F::nan(),
78 }
79 }
80}
81
82impl<O, F> Solver<O> for Brent<F>
83where
84 O: ArgminOp<Param = F, Output = F, Float = F>,
85 F: ArgminFloat,
86{
87 const NAME: &'static str = "Brent";
88
89 fn init(
90 &mut self,
91 op: &mut OpWrapper<O>,
92 _state: &IterState<O>,
94 ) -> Result<Option<ArgminIterData<O>>, Error> {
95 self.fa = op.apply(&self.a)?;
96 self.fb = op.apply(&self.b)?;
97 if self.fa * self.fb > F::from_f64(0.0).unwrap() {
98 return Err(BrentError::WrongSign.into());
99 }
100 self.fc = self.fb;
101 Ok(Some(
102 ArgminIterData::new().param(self.b).cost(self.fb.abs()),
103 ))
104 }
105
106 fn next_iter(
107 &mut self,
108 op: &mut OpWrapper<O>,
109 _state: &IterState<O>,
111 ) -> Result<ArgminIterData<O>, Error> {
112 if (self.fb > F::from_f64(0.0).unwrap() && self.fc > F::from_f64(0.0).unwrap())
113 || self.fb < F::from_f64(0.0).unwrap() && self.fc < F::from_f64(0.0).unwrap()
114 {
115 self.c = self.a;
116 self.fc = self.fa;
117 self.d = self.b - self.a;
118 self.e = self.d;
119 }
120 if self.fc.abs() < self.fb.abs() {
121 self.a = self.b;
122 self.b = self.c;
123 self.c = self.a;
124 self.fa = self.fb;
125 self.fb = self.fc;
126 self.fc = self.fa;
127 }
128 let eff_tol = F::from_f64(2.0).unwrap() * F::epsilon() * self.b.abs()
130 + F::from_f64(0.5).unwrap() * self.tol;
131 let mid = F::from_f64(0.5).unwrap() * (self.c - self.b);
132 if mid.abs() <= eff_tol || self.fb == F::from_f64(0.0).unwrap() {
133 return Ok(ArgminIterData::new()
134 .termination_reason(TerminationReason::TargetPrecisionReached)
135 .param(self.b)
136 .cost(self.fb.abs()));
137 }
138 if self.e.abs() >= eff_tol && self.fa.abs() > self.fb.abs() {
139 let s = self.fb / self.fa;
140 let (mut p, mut q) = if self.a == self.c {
141 (
142 F::from_f64(2.0).unwrap() * mid * s,
143 F::from_f64(1.0).unwrap() - s,
144 )
145 } else {
146 let q = self.fa / self.fc;
147 let r = self.fb / self.fc;
148 (
149 s * (F::from_f64(2.0).unwrap() * mid * q * (q - r)
150 - (self.b - self.a) * (r - F::from_f64(1.0).unwrap())),
151 (q - F::from_f64(1.0).unwrap())
152 * (r - F::from_f64(1.0).unwrap())
153 * (s - F::from_f64(1.0).unwrap()),
154 )
155 };
156 if p > F::from_f64(0.0).unwrap() {
157 q = -q;
158 }
159 p = p.abs();
160 let min1 = F::from_f64(3.0).unwrap() * mid * q - (eff_tol * q).abs();
161 let min2 = (self.e * q).abs();
162 if F::from_f64(2.0).unwrap() * p < if min1 < min2 { min1 } else { min2 } {
163 self.e = self.d;
164 self.d = p / q;
165 } else {
166 self.d = mid;
167 self.e = self.d;
168 };
169 } else {
170 self.d = mid;
171 self.e = self.d;
172 };
173 self.a = self.b;
174 self.fa = self.fb;
175 if self.d.abs() > eff_tol {
176 self.b = self.b + self.d;
177 } else {
178 self.b = self.b
179 + if mid >= F::from_f64(0.0).unwrap() {
180 eff_tol.abs()
181 } else {
182 -eff_tol.abs()
183 };
184 }
185
186 self.fb = op.apply(&self.b)?;
187 Ok(ArgminIterData::new().param(self.b).cost(self.fb.abs()))
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // Invokes the crate-level `test_trait_impl!` macro for `Brent<f64>` under
    // the name `brent`. NOTE(review): the macro is defined elsewhere in the
    // crate; presumably it generates trait-bound checks for the solver type —
    // confirm against the macro definition.
    test_trait_impl!(brent, Brent<f64>);
}