argmin/solver/simulatedannealing/
mod.rs1use crate::prelude::*;
19use rand::prelude::*;
20use rand_xorshift::XorShiftRng;
21use serde::{Deserialize, Serialize};
22
23#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
32pub enum SATempFunc<F> {
33 TemperatureFast,
35 Boltzmann,
37 Exponential(F),
39 }
43
44impl<F> std::default::Default for SATempFunc<F> {
45 fn default() -> Self {
46 SATempFunc::Boltzmann
47 }
48}
49
50#[derive(Clone, Serialize, Deserialize)]
62pub struct SimulatedAnnealing<F> {
63 init_temp: F,
65 temp_func: SATempFunc<F>,
67 temp_iter: u64,
70 stall_iter_accepted: u64,
72 stall_iter_accepted_limit: u64,
74 stall_iter_best: u64,
76 stall_iter_best_limit: u64,
78 reanneal_fixed: u64,
80 reanneal_iter_fixed: u64,
82 reanneal_accepted: u64,
84 reanneal_iter_accepted: u64,
86 reanneal_best: u64,
88 reanneal_iter_best: u64,
90 cur_temp: F,
92 rng: XorShiftRng,
94}
95
96impl<F> SimulatedAnnealing<F>
97where
98 F: ArgminFloat,
99{
100 pub fn new(init_temp: F) -> Result<Self, Error> {
106 if init_temp <= F::from_f64(0.0).unwrap() {
107 Err(ArgminError::InvalidParameter {
108 text: "Initial temperature must be > 0.".to_string(),
109 }
110 .into())
111 } else {
112 Ok(SimulatedAnnealing {
113 init_temp,
114 temp_func: SATempFunc::TemperatureFast,
115 temp_iter: 0,
116 stall_iter_accepted: 0,
117 stall_iter_accepted_limit: std::u64::MAX,
118 stall_iter_best: 0,
119 stall_iter_best_limit: std::u64::MAX,
120 reanneal_fixed: std::u64::MAX,
121 reanneal_iter_fixed: 0,
122 reanneal_accepted: std::u64::MAX,
123 reanneal_iter_accepted: 0,
124 reanneal_best: std::u64::MAX,
125 reanneal_iter_best: 0,
126 cur_temp: init_temp,
127 rng: XorShiftRng::from_entropy(),
128 })
129 }
130 }
131
132 pub fn temp_func(mut self, temperature_func: SATempFunc<F>) -> Self {
134 self.temp_func = temperature_func;
135 self
136 }
137
138 pub fn stall_accepted(mut self, iter: u64) -> Self {
140 self.stall_iter_accepted_limit = iter;
141 self
142 }
143
144 pub fn stall_best(mut self, iter: u64) -> Self {
146 self.stall_iter_best_limit = iter;
147 self
148 }
149
150 pub fn reannealing_fixed(mut self, iter: u64) -> Self {
152 self.reanneal_fixed = iter;
153 self
154 }
155
156 pub fn reannealing_accepted(mut self, iter: u64) -> Self {
158 self.reanneal_accepted = iter;
159 self
160 }
161
162 pub fn reannealing_best(mut self, iter: u64) -> Self {
164 self.reanneal_best = iter;
165 self
166 }
167
168 fn update_temperature(&mut self) {
172 self.cur_temp = match self.temp_func {
173 SATempFunc::TemperatureFast => {
174 self.init_temp / F::from_u64(self.temp_iter + 1).unwrap()
175 }
176 SATempFunc::Boltzmann => self.init_temp / F::from_u64(self.temp_iter + 1).unwrap().ln(),
177 SATempFunc::Exponential(x) => {
178 self.init_temp * x.powf(F::from_u64(self.temp_iter + 1).unwrap())
179 }
180 };
181 }
182
183 fn reanneal(&mut self) -> (bool, bool, bool) {
185 let out = (
186 self.reanneal_iter_fixed >= self.reanneal_fixed,
187 self.reanneal_iter_accepted >= self.reanneal_accepted,
188 self.reanneal_iter_best >= self.reanneal_best,
189 );
190 if out.0 || out.1 || out.2 {
191 self.reanneal_iter_fixed = 0;
192 self.reanneal_iter_accepted = 0;
193 self.reanneal_iter_best = 0;
194 self.cur_temp = self.init_temp;
195 self.temp_iter = 0;
196 }
197 out
198 }
199
200 fn update_stall_and_reanneal_iter(&mut self, accepted: bool, new_best: bool) {
202 self.stall_iter_accepted = if accepted {
203 0
204 } else {
205 self.stall_iter_accepted + 1
206 };
207
208 self.reanneal_iter_accepted = if accepted {
209 0
210 } else {
211 self.reanneal_iter_accepted + 1
212 };
213
214 self.stall_iter_best = if new_best {
215 0
216 } else {
217 self.stall_iter_best + 1
218 };
219
220 self.reanneal_iter_best = if new_best {
221 0
222 } else {
223 self.reanneal_iter_best + 1
224 };
225 }
226}
227
228impl<O, F> Solver<O> for SimulatedAnnealing<F>
229where
230 O: ArgminOp<Output = F, Float = F>,
231 F: ArgminFloat,
232{
233 const NAME: &'static str = "Simulated Annealing";
234 fn init(
235 &mut self,
236 _op: &mut OpWrapper<O>,
237 _state: &IterState<O>,
238 ) -> Result<Option<ArgminIterData<O>>, Error> {
239 Ok(Some(ArgminIterData::new().kv(make_kv!(
240 "initial_temperature" => self.init_temp;
241 "stall_iter_accepted_limit" => self.stall_iter_accepted_limit;
242 "stall_iter_best_limit" => self.stall_iter_best_limit;
243 "reanneal_fixed" => self.reanneal_fixed;
244 "reanneal_accepted" => self.reanneal_accepted;
245 "reanneal_best" => self.reanneal_best;
246 ))))
247 }
248
249 fn next_iter(
251 &mut self,
252 op: &mut OpWrapper<O>,
253 state: &IterState<O>,
254 ) -> Result<ArgminIterData<O>, Error> {
255 let prev_param = state.get_param();
260 let prev_cost = state.get_cost();
261
262 let new_param = op.modify(&prev_param, self.cur_temp)?;
264 let new_cost = op.apply(&new_param)?;
268
269 let prob: f64 = self.rng.gen();
281 let prob = F::from_f64(prob).unwrap();
282 let accepted = (new_cost < state.get_prev_cost())
283 || (F::from_f64(1.0).unwrap()
284 / (F::from_f64(1.0).unwrap()
285 + ((new_cost - state.get_prev_cost()) / self.cur_temp).exp())
286 > prob);
287
288 self.update_stall_and_reanneal_iter(accepted, new_cost <= state.get_best_cost());
290
291 let (r_fixed, r_accepted, r_best) = self.reanneal();
292
293 self.temp_iter += 1;
295 self.reanneal_iter_fixed += 1;
297
298 self.update_temperature();
299
300 Ok(if accepted {
301 ArgminIterData::new().param(new_param).cost(new_cost)
302 } else {
303 ArgminIterData::new().param(prev_param).cost(prev_cost)
304 }
305 .kv(make_kv!(
306 "t" => self.cur_temp;
307 "new_be" => new_cost <= state.get_best_cost();
308 "acc" => accepted;
309 "st_i_be" => self.stall_iter_best;
310 "st_i_ac" => self.stall_iter_accepted;
311 "ra_i_fi" => self.reanneal_iter_fixed;
312 "ra_i_be" => self.reanneal_iter_best;
313 "ra_i_ac" => self.reanneal_iter_accepted;
314 "ra_fi" => r_fixed;
315 "ra_be" => r_best;
316 "ra_ac" => r_accepted;
317 )))
318 }
319
320 fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
321 if self.stall_iter_accepted > self.stall_iter_accepted_limit {
322 return TerminationReason::AcceptedStallIterExceeded;
323 }
324 if self.stall_iter_best > self.stall_iter_best_limit {
325 return TerminationReason::BestStallIterExceeded;
326 }
327 TerminationReason::NotTerminated
328 }
329}
330
#[cfg(test)]
mod tests {
    use crate::test_trait_impl;
    use super::*;

    // Instantiates the crate's `test_trait_impl` macro for `SimulatedAnnealing<f64>`;
    // see the macro definition for the exact traits it checks.
    test_trait_impl!(sa, SimulatedAnnealing<f64>);
}