// argmin/solver/neldermead/mod.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
//! # References:
//!
//! [Wikipedia](https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method)
11
12use crate::prelude::*;
13use serde::{de::DeserializeOwned, Deserialize, Serialize};
14use std::default::Default;
15
16/// Nelder-Mead method
17///
18/// The Nelder-Mead method a heuristic search method for nonlinear optimization problems which does
19/// not require derivatives.
20///
21/// The method is based on simplices which consist of n+1 vertices for an optimization problem with
22/// n dimensions.
23/// The function to be optimized is evaluated at all vertices. Based on these cost function values
24/// the behaviour of the cost function is extrapolated in order to find the next point to be
25/// evaluated.
26///
27/// The following actions are possible:
28///
29/// 1) Reflection: (Parameter `alpha`, default `1`)
30/// 2) Expansion: (Parameter `gamma`, default `2`)
31/// 3) Contraction: (Parameter `rho`, default `0.5`)
32/// 4) Shrink: (Parameter `sigma`, default `0.5`)
33///
34/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/neldermead.rs)
35///
36/// # References:
37///
38/// [Wikipedia](https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method)
39#[derive(Clone, Serialize, Deserialize)]
40pub struct NelderMead<P, F> {
41    /// alpha
42    alpha: F,
43    /// gamma
44    gamma: F,
45    /// rho
46    rho: F,
47    /// sigma
48    sigma: F,
49    /// parameters
50    params: Vec<(P, F)>,
51    /// Sample standard deviation tolerance
52    sd_tolerance: F,
53}
54
55impl<P, F> NelderMead<P, F>
56where
57    P: Clone + Default + ArgminAdd<P, P> + ArgminSub<P, P> + ArgminMul<F, P>,
58    F: ArgminFloat,
59{
60    /// Constructor
61    pub fn new() -> Self {
62        NelderMead {
63            alpha: F::from_f64(1.0).unwrap(),
64            gamma: F::from_f64(2.0).unwrap(),
65            rho: F::from_f64(0.5).unwrap(),
66            sigma: F::from_f64(0.5).unwrap(),
67            params: vec![],
68            sd_tolerance: F::epsilon(),
69        }
70    }
71
72    /// Add initial parameters
73    pub fn with_initial_params(mut self, params: Vec<P>) -> Self {
74        self.params = params.into_iter().map(|p| (p, F::nan())).collect();
75        self
76    }
77
78    /// Set Sample standard deviation tolerance
79    pub fn sd_tolerance(mut self, tol: F) -> Self {
80        self.sd_tolerance = tol;
81        self
82    }
83
84    /// set alpha
85    pub fn alpha(mut self, alpha: F) -> Result<Self, Error> {
86        if alpha <= F::from_f64(0.0).unwrap() {
87            return Err(ArgminError::InvalidParameter {
88                text: "Nelder-Mead:  must be > 0.".to_string(),
89            }
90            .into());
91        }
92        self.alpha = alpha;
93        Ok(self)
94    }
95
96    /// set gamma
97    pub fn gamma(mut self, gamma: F) -> Result<Self, Error> {
98        if gamma <= F::from_f64(1.0).unwrap() {
99            return Err(ArgminError::InvalidParameter {
100                text: "Nelder-Mead: gamma must be > 1.".to_string(),
101            }
102            .into());
103        }
104        self.gamma = gamma;
105        Ok(self)
106    }
107
108    /// set rho
109    pub fn rho(mut self, rho: F) -> Result<Self, Error> {
110        if rho <= F::from_f64(0.0).unwrap() || rho > F::from_f64(0.5).unwrap() {
111            return Err(ArgminError::InvalidParameter {
112                text: "Nelder-Mead: rho must be in  (0.0, 0.5].".to_string(),
113            }
114            .into());
115        }
116        self.rho = rho;
117        Ok(self)
118    }
119
120    /// set sigma
121    pub fn sigma(mut self, sigma: F) -> Result<Self, Error> {
122        if sigma <= F::from_f64(0.0).unwrap() || sigma > F::from_f64(1.0).unwrap() {
123            return Err(ArgminError::InvalidParameter {
124                text: "Nelder-Mead: sigma must be in  (0.0, 1.0].".to_string(),
125            }
126            .into());
127        }
128        self.sigma = sigma;
129        Ok(self)
130    }
131
132    /// Sort parameters vectors based on their cost function values
133    fn sort_param_vecs(&mut self) {
134        self.params
135            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
136    }
137
138    /// Calculate centroid of all but the worst vectors
139    fn calculate_centroid(&self) -> P {
140        let num_param = self.params.len() - 1;
141        let mut x0: P = self.params[0].0.clone();
142        for idx in 1..num_param {
143            x0 = x0.add(&self.params[idx].0)
144        }
145        x0.mul(&(F::from_f64(1.0).unwrap() / (F::from_usize(num_param).unwrap())))
146    }
147
148    /// Reflect
149    fn reflect(&self, x0: &P, x: &P) -> P {
150        x0.add(&x0.sub(&x).mul(&self.alpha))
151    }
152
153    /// Expand
154    fn expand(&self, x0: &P, x: &P) -> P {
155        x0.add(&x.sub(&x0).mul(&self.gamma))
156    }
157
158    /// Contract
159    fn contract(&self, x0: &P, x: &P) -> P {
160        x0.add(&x.sub(&x0).mul(&self.rho))
161    }
162
163    /// Shrink
164    fn shrink<S>(&mut self, mut cost: S) -> Result<(), Error>
165    where
166        S: FnMut(&P) -> Result<F, Error>,
167    {
168        let mut out = Vec::with_capacity(self.params.len());
169        out.push(self.params[0].clone());
170
171        for idx in 1..self.params.len() {
172            let xi = out[0]
173                .0
174                .add(&self.params[idx].0.sub(&out[0].0).mul(&self.sigma));
175            let ci = (cost)(&xi)?;
176            out.push((xi, ci));
177        }
178        self.params = out;
179        Ok(())
180    }
181}
182
183impl<P, F> Default for NelderMead<P, F>
184where
185    P: Clone + Default + ArgminAdd<P, P> + ArgminSub<P, P> + ArgminMul<F, P>,
186    F: ArgminFloat,
187{
188    fn default() -> NelderMead<P, F> {
189        NelderMead::new()
190    }
191}
192
193impl<O, P, F> Solver<O> for NelderMead<P, F>
194where
195    O: ArgminOp<Output = F, Param = P, Float = F>,
196    P: Default
197        + Clone
198        + Serialize
199        + DeserializeOwned
200        + ArgminScaledSub<O::Param, O::Float, O::Param>
201        + ArgminSub<O::Param, O::Param>
202        + ArgminAdd<O::Param, O::Param>
203        + ArgminMul<O::Float, O::Param>,
204    F: ArgminFloat + std::iter::Sum<O::Float>,
205{
206    const NAME: &'static str = "Nelder-Mead method";
207
208    fn init(
209        &mut self,
210        op: &mut OpWrapper<O>,
211        _state: &IterState<O>,
212    ) -> Result<Option<ArgminIterData<O>>, Error> {
213        self.params = self
214            .params
215            .iter()
216            .cloned()
217            .map(|(p, _)| {
218                let c = op.apply(&p).unwrap();
219                (p, c)
220            })
221            .collect();
222        self.sort_param_vecs();
223
224        Ok(Some(
225            ArgminIterData::new()
226                .param(self.params[0].0.clone())
227                .cost(self.params[0].1),
228        ))
229    }
230
231    fn next_iter(
232        &mut self,
233        op: &mut OpWrapper<O>,
234        _state: &IterState<O>,
235    ) -> Result<ArgminIterData<O>, Error> {
236        let num_param = self.params.len();
237
238        let x0 = self.calculate_centroid();
239
240        let xr = self.reflect(&x0, &self.params[num_param - 1].0);
241        let xr_cost = op.apply(&xr)?;
242
243        let action = if xr_cost < self.params[num_param - 2].1 && xr_cost >= self.params[0].1 {
244            // reflection
245            self.params.last_mut().unwrap().0 = xr;
246            self.params.last_mut().unwrap().1 = xr_cost;
247            "reflection"
248        } else if xr_cost < self.params[0].1 {
249            // expansion
250            let xe = self.expand(&x0, &xr);
251            let xe_cost = op.apply(&xe)?;
252            if xe_cost < xr_cost {
253                self.params.last_mut().unwrap().0 = xe;
254                self.params.last_mut().unwrap().1 = xe_cost;
255            } else {
256                self.params.last_mut().unwrap().0 = xr;
257                self.params.last_mut().unwrap().1 = xr_cost;
258            }
259            "expansion"
260        } else if xr_cost >= self.params[num_param - 2].1 {
261            // contraction
262            let xc = self.contract(&x0, &self.params[num_param - 1].0);
263            let xc_cost = op.apply(&xc)?;
264            if xc_cost < self.params[num_param - 1].1 {
265                self.params.last_mut().unwrap().0 = xc;
266                self.params.last_mut().unwrap().1 = xc_cost;
267            }
268            "contraction"
269        } else {
270            // shrink
271            self.shrink(|x| op.apply(x))?;
272            "shrink"
273        };
274
275        self.sort_param_vecs();
276
277        Ok(ArgminIterData::new()
278            .param(self.params[0].0.clone())
279            .cost(self.params[0].1)
280            .kv(make_kv!("action" => action;)))
281    }
282
283    fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
284        let n = F::from_usize(self.params.len()).unwrap();
285        let c0: F = self.params.iter().map(|(_, c)| *c).sum::<F>() / n;
286        let s: F = (F::from_f64(1.0).unwrap() / (n - F::from_f64(1.0).unwrap())
287            * self
288                .params
289                .iter()
290                .map(|(_, c)| (*c - c0).powi(2))
291                .sum::<F>())
292        .sqrt();
293        if s < self.sd_tolerance {
294            return TerminationReason::TargetToleranceReached;
295        }
296        TerminationReason::NotTerminated
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    // Checks the standard trait implementations of the solver type (the operator type merely
    // fills the parameter-type slot).
    test_trait_impl!(nelder_mead, NelderMead<MinimalNoOperator, f64>);
}