argmin/solver/neldermead/
mod.rs1use crate::prelude::*;
13use serde::{de::DeserializeOwned, Deserialize, Serialize};
14use std::default::Default;
15
/// Nelder-Mead solver.
///
/// Holds the simplex (parameter vectors paired with their cost) together with the
/// coefficients used by the reflection, expansion, contraction and shrink steps.
/// Construct it with `NelderMead::new()` and supply the initial simplex with
/// `with_initial_params`.
#[derive(Clone, Serialize, Deserialize)]
pub struct NelderMead<P, F> {
    /// Reflection coefficient; defaults to 1.0, the `alpha` setter requires > 0.
    alpha: F,
    /// Expansion coefficient; defaults to 2.0, the `gamma` setter requires > 1.
    gamma: F,
    /// Contraction coefficient; defaults to 0.5, the `rho` setter requires (0, 0.5].
    rho: F,
    /// Shrink coefficient; defaults to 0.5, the `sigma` setter requires (0, 1].
    sigma: F,
    /// Simplex vertices as `(parameter vector, cost)` pairs. Sorted by ascending cost
    /// after initialisation and after every iteration, so index 0 is the best vertex
    /// and the last element is the worst.
    params: Vec<(P, F)>,
    /// Termination threshold on the sample standard deviation of the vertex costs;
    /// defaults to `F::epsilon()`.
    sd_tolerance: F,
}
54
55impl<P, F> NelderMead<P, F>
56where
57 P: Clone + Default + ArgminAdd<P, P> + ArgminSub<P, P> + ArgminMul<F, P>,
58 F: ArgminFloat,
59{
60 pub fn new() -> Self {
62 NelderMead {
63 alpha: F::from_f64(1.0).unwrap(),
64 gamma: F::from_f64(2.0).unwrap(),
65 rho: F::from_f64(0.5).unwrap(),
66 sigma: F::from_f64(0.5).unwrap(),
67 params: vec![],
68 sd_tolerance: F::epsilon(),
69 }
70 }
71
72 pub fn with_initial_params(mut self, params: Vec<P>) -> Self {
74 self.params = params.into_iter().map(|p| (p, F::nan())).collect();
75 self
76 }
77
78 pub fn sd_tolerance(mut self, tol: F) -> Self {
80 self.sd_tolerance = tol;
81 self
82 }
83
84 pub fn alpha(mut self, alpha: F) -> Result<Self, Error> {
86 if alpha <= F::from_f64(0.0).unwrap() {
87 return Err(ArgminError::InvalidParameter {
88 text: "Nelder-Mead: must be > 0.".to_string(),
89 }
90 .into());
91 }
92 self.alpha = alpha;
93 Ok(self)
94 }
95
96 pub fn gamma(mut self, gamma: F) -> Result<Self, Error> {
98 if gamma <= F::from_f64(1.0).unwrap() {
99 return Err(ArgminError::InvalidParameter {
100 text: "Nelder-Mead: gamma must be > 1.".to_string(),
101 }
102 .into());
103 }
104 self.gamma = gamma;
105 Ok(self)
106 }
107
108 pub fn rho(mut self, rho: F) -> Result<Self, Error> {
110 if rho <= F::from_f64(0.0).unwrap() || rho > F::from_f64(0.5).unwrap() {
111 return Err(ArgminError::InvalidParameter {
112 text: "Nelder-Mead: rho must be in (0.0, 0.5].".to_string(),
113 }
114 .into());
115 }
116 self.rho = rho;
117 Ok(self)
118 }
119
120 pub fn sigma(mut self, sigma: F) -> Result<Self, Error> {
122 if sigma <= F::from_f64(0.0).unwrap() || sigma > F::from_f64(1.0).unwrap() {
123 return Err(ArgminError::InvalidParameter {
124 text: "Nelder-Mead: sigma must be in (0.0, 1.0].".to_string(),
125 }
126 .into());
127 }
128 self.sigma = sigma;
129 Ok(self)
130 }
131
132 fn sort_param_vecs(&mut self) {
134 self.params
135 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
136 }
137
138 fn calculate_centroid(&self) -> P {
140 let num_param = self.params.len() - 1;
141 let mut x0: P = self.params[0].0.clone();
142 for idx in 1..num_param {
143 x0 = x0.add(&self.params[idx].0)
144 }
145 x0.mul(&(F::from_f64(1.0).unwrap() / (F::from_usize(num_param).unwrap())))
146 }
147
148 fn reflect(&self, x0: &P, x: &P) -> P {
150 x0.add(&x0.sub(&x).mul(&self.alpha))
151 }
152
153 fn expand(&self, x0: &P, x: &P) -> P {
155 x0.add(&x.sub(&x0).mul(&self.gamma))
156 }
157
158 fn contract(&self, x0: &P, x: &P) -> P {
160 x0.add(&x.sub(&x0).mul(&self.rho))
161 }
162
163 fn shrink<S>(&mut self, mut cost: S) -> Result<(), Error>
165 where
166 S: FnMut(&P) -> Result<F, Error>,
167 {
168 let mut out = Vec::with_capacity(self.params.len());
169 out.push(self.params[0].clone());
170
171 for idx in 1..self.params.len() {
172 let xi = out[0]
173 .0
174 .add(&self.params[idx].0.sub(&out[0].0).mul(&self.sigma));
175 let ci = (cost)(&xi)?;
176 out.push((xi, ci));
177 }
178 self.params = out;
179 Ok(())
180 }
181}
182
183impl<P, F> Default for NelderMead<P, F>
184where
185 P: Clone + Default + ArgminAdd<P, P> + ArgminSub<P, P> + ArgminMul<F, P>,
186 F: ArgminFloat,
187{
188 fn default() -> NelderMead<P, F> {
189 NelderMead::new()
190 }
191}
192
193impl<O, P, F> Solver<O> for NelderMead<P, F>
194where
195 O: ArgminOp<Output = F, Param = P, Float = F>,
196 P: Default
197 + Clone
198 + Serialize
199 + DeserializeOwned
200 + ArgminScaledSub<O::Param, O::Float, O::Param>
201 + ArgminSub<O::Param, O::Param>
202 + ArgminAdd<O::Param, O::Param>
203 + ArgminMul<O::Float, O::Param>,
204 F: ArgminFloat + std::iter::Sum<O::Float>,
205{
206 const NAME: &'static str = "Nelder-Mead method";
207
208 fn init(
209 &mut self,
210 op: &mut OpWrapper<O>,
211 _state: &IterState<O>,
212 ) -> Result<Option<ArgminIterData<O>>, Error> {
213 self.params = self
214 .params
215 .iter()
216 .cloned()
217 .map(|(p, _)| {
218 let c = op.apply(&p).unwrap();
219 (p, c)
220 })
221 .collect();
222 self.sort_param_vecs();
223
224 Ok(Some(
225 ArgminIterData::new()
226 .param(self.params[0].0.clone())
227 .cost(self.params[0].1),
228 ))
229 }
230
231 fn next_iter(
232 &mut self,
233 op: &mut OpWrapper<O>,
234 _state: &IterState<O>,
235 ) -> Result<ArgminIterData<O>, Error> {
236 let num_param = self.params.len();
237
238 let x0 = self.calculate_centroid();
239
240 let xr = self.reflect(&x0, &self.params[num_param - 1].0);
241 let xr_cost = op.apply(&xr)?;
242
243 let action = if xr_cost < self.params[num_param - 2].1 && xr_cost >= self.params[0].1 {
244 self.params.last_mut().unwrap().0 = xr;
246 self.params.last_mut().unwrap().1 = xr_cost;
247 "reflection"
248 } else if xr_cost < self.params[0].1 {
249 let xe = self.expand(&x0, &xr);
251 let xe_cost = op.apply(&xe)?;
252 if xe_cost < xr_cost {
253 self.params.last_mut().unwrap().0 = xe;
254 self.params.last_mut().unwrap().1 = xe_cost;
255 } else {
256 self.params.last_mut().unwrap().0 = xr;
257 self.params.last_mut().unwrap().1 = xr_cost;
258 }
259 "expansion"
260 } else if xr_cost >= self.params[num_param - 2].1 {
261 let xc = self.contract(&x0, &self.params[num_param - 1].0);
263 let xc_cost = op.apply(&xc)?;
264 if xc_cost < self.params[num_param - 1].1 {
265 self.params.last_mut().unwrap().0 = xc;
266 self.params.last_mut().unwrap().1 = xc_cost;
267 }
268 "contraction"
269 } else {
270 self.shrink(|x| op.apply(x))?;
272 "shrink"
273 };
274
275 self.sort_param_vecs();
276
277 Ok(ArgminIterData::new()
278 .param(self.params[0].0.clone())
279 .cost(self.params[0].1)
280 .kv(make_kv!("action" => action;)))
281 }
282
283 fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
284 let n = F::from_usize(self.params.len()).unwrap();
285 let c0: F = self.params.iter().map(|(_, c)| *c).sum::<F>() / n;
286 let s: F = (F::from_f64(1.0).unwrap() / (n - F::from_f64(1.0).unwrap())
287 * self
288 .params
289 .iter()
290 .map(|(_, c)| (*c - c0).powi(2))
291 .sum::<F>())
292 .sqrt();
293 if s < self.sd_tolerance {
294 return TerminationReason::TargetToleranceReached;
295 }
296 TerminationReason::NotTerminated
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;
    type Operator = MinimalNoOperator;

    test_trait_impl!(nelder_mead, NelderMead<Operator, f64>);

    /// Solver over plain `Vec<f64>` parameters, used by the behavioral tests below.
    type Nm = NelderMead<Vec<f64>, f64>;

    #[test]
    fn test_alpha_validation() {
        assert!(Nm::new().alpha(0.0).is_err());
        assert!(Nm::new().alpha(-1.0).is_err());
        assert!(Nm::new().alpha(0.5).is_ok());
    }

    #[test]
    fn test_gamma_validation() {
        assert!(Nm::new().gamma(1.0).is_err());
        assert!(Nm::new().gamma(0.5).is_err());
        assert!(Nm::new().gamma(1.5).is_ok());
    }

    #[test]
    fn test_rho_validation() {
        assert!(Nm::new().rho(0.0).is_err());
        assert!(Nm::new().rho(0.6).is_err());
        assert!(Nm::new().rho(0.5).is_ok());
    }

    #[test]
    fn test_sigma_validation() {
        assert!(Nm::new().sigma(0.0).is_err());
        assert!(Nm::new().sigma(1.1).is_err());
        assert!(Nm::new().sigma(1.0).is_ok());
    }

    #[test]
    fn test_centroid_excludes_worst_vertex() {
        let nm = Nm::new().with_initial_params(vec![
            vec![0.0, 0.0],
            vec![2.0, 0.0],
            vec![0.0, 4.0],
        ]);
        assert_eq!(nm.calculate_centroid(), vec![1.0, 0.0]);
    }

    #[test]
    fn test_reflect_expand_contract() {
        let nm = Nm::new();
        let x0 = vec![1.0, 0.0];
        let worst = vec![0.0, 4.0];
        let xr = nm.reflect(&x0, &worst);
        assert_eq!(xr, vec![2.0, -4.0]);
        assert_eq!(nm.expand(&x0, &xr), vec![3.0, -8.0]);
        assert_eq!(nm.contract(&x0, &worst), vec![0.5, 2.0]);
    }

    #[test]
    fn test_shrink_moves_towards_best_vertex() {
        let mut nm = Nm::new().with_initial_params(vec![
            vec![0.0, 0.0],
            vec![2.0, 0.0],
            vec![0.0, 4.0],
        ]);
        nm.shrink(|x: &Vec<f64>| Ok(x.iter().map(|v| v * v).sum::<f64>()))
            .unwrap();
        let points: Vec<Vec<f64>> = nm.params.iter().map(|(p, _)| p.clone()).collect();
        assert_eq!(
            points,
            vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 2.0]]
        );
        assert_eq!(nm.params[1].1, 1.0);
        assert_eq!(nm.params[2].1, 4.0);
    }
}