argmin/core/
iterstate.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use crate::core::{ArgminOp, OpWrapper, TerminationReason};
9use num::traits::float::Float;
10use paste::item;
11use serde::{Deserialize, Serialize};
12
/// Maintains the state from iteration to iteration of a solver.
///
/// Most quantities are stored as a current/previous pair (`param`/`prev_param`,
/// `cost`/`prev_cost`, `grad`/`prev_grad`, ...). The shifting setter methods of `IterState`
/// (`param`, `best_param`, `cost`, `best_cost`, `grad`, `hessian`, `jacobian`) move the
/// current value into the matching `prev_*` field before storing the new one.
///
/// NOTE: this struct derives `Serialize`/`Deserialize`; adding, removing or reordering fields
/// changes its serialized form.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IterState<O: ArgminOp> {
    /// Current parameter vector
    pub param: O::Param,
    /// Previous parameter vector
    pub prev_param: O::Param,
    /// Current best parameter vector
    pub best_param: O::Param,
    /// Previous best parameter vector
    pub prev_best_param: O::Param,
    /// Current cost function value (positive infinity after `IterState::new`)
    pub cost: O::Float,
    /// Previous cost function value
    pub prev_cost: O::Float,
    /// Current best cost function value (positive infinity after `IterState::new`)
    pub best_cost: O::Float,
    /// Previous best cost function value
    pub prev_best_cost: O::Float,
    /// Target cost function value (negative infinity after `IterState::new`)
    pub target_cost: O::Float,
    /// Current gradient; `None` until one is set
    pub grad: Option<O::Param>,
    /// Previous gradient; `None` until a gradient has been replaced
    pub prev_grad: Option<O::Param>,
    /// Current Hessian; `None` until one is set
    pub hessian: Option<O::Hessian>,
    /// Previous Hessian; `None` until a Hessian has been replaced
    pub prev_hessian: Option<O::Hessian>,
    /// Current Jacobian; `None` until one is set
    pub jacobian: Option<O::Jacobian>,
    /// Previous Jacobian; `None` until a Jacobian has been replaced
    pub prev_jacobian: Option<O::Jacobian>,
    /// All members for population-based algorithms as (param, cost) tuples; `None` until set
    pub population: Option<Vec<(O::Param, O::Float)>>,
    /// Current iteration
    pub iter: u64,
    /// Iteration number at which the last best cost was recorded (see `IterState::new_best`)
    pub last_best_iter: u64,
    /// Maximum number of iterations (`u64::MAX` after `IterState::new`)
    pub max_iters: u64,
    /// Number of cost function evaluations so far
    pub cost_func_count: u64,
    /// Number of gradient evaluations so far
    pub grad_func_count: u64,
    /// Number of Hessian evaluations so far
    pub hessian_func_count: u64,
    /// Number of Jacobian evaluations so far
    pub jacobian_func_count: u64,
    /// Number of modify evaluations so far
    pub modify_func_count: u64,
    /// Time required so far
    pub time: std::time::Duration,
    /// Reason of termination (`TerminationReason::NotTerminated` after `IterState::new`)
    pub termination_reason: TerminationReason,
}
69
// Generates a chainable setter named after a struct field:
// `pub fn <field>(&mut self, value: <ty>) -> &mut Self`.
// The setter overwrites the field of the same name and returns `self` so calls can be chained.
// `$doc` is a string literal that becomes the generated method's documentation.
macro_rules! setter {
    ($field:ident, $ty:ty, $doc:tt) => {
        #[doc = $doc]
        pub fn $field(&mut self, value: $ty) -> &mut Self {
            self.$field = value;
            self
        }
    };
}
79
// Generates `pub fn get_<field>(&self) -> Option<ty>`, which returns a clone of an
// `Option<ty>` field. `paste`'s `item!` is required to splice the `get_` prefix onto the
// field identifier; `$doc` becomes the generated method's documentation.
macro_rules! getter_option {
    ($field:ident, $ty:ty, $doc:tt) => {
        item! {
            #[doc = $doc]
            pub fn [<get_ $field>](&self) -> Option<$ty> {
                self.$field.as_ref().cloned()
            }
        }
    };
}
90
// Generates `pub fn get_<field>(&self) -> ty`, which returns a clone of the field.
// `paste`'s `item!` is required to splice the `get_` prefix onto the field identifier;
// `$doc` becomes the generated method's documentation.
macro_rules! getter {
    ($field:ident, $ty:ty, $doc:tt) => {
        item! {
            #[doc = $doc]
            pub fn [<get_ $field>](&self) -> $ty {
                Clone::clone(&self.$field)
            }
        }
    };
}
101
102impl<O: ArgminOp> std::default::Default for IterState<O>
103where
104    O::Param: Default,
105{
106    fn default() -> Self {
107        IterState::new(O::Param::default())
108    }
109}
110
111impl<O: ArgminOp> IterState<O> {
112    /// Create new IterState from `param`
113    pub fn new(param: O::Param) -> Self {
114        IterState {
115            param: param.clone(),
116            prev_param: param.clone(),
117            best_param: param.clone(),
118            prev_best_param: param,
119            cost: O::Float::infinity(),
120            prev_cost: O::Float::infinity(),
121            best_cost: O::Float::infinity(),
122            prev_best_cost: O::Float::infinity(),
123            target_cost: O::Float::neg_infinity(),
124            grad: None,
125            prev_grad: None,
126            hessian: None,
127            prev_hessian: None,
128            jacobian: None,
129            prev_jacobian: None,
130            population: None,
131            iter: 0,
132            last_best_iter: 0,
133            max_iters: std::u64::MAX,
134            cost_func_count: 0,
135            grad_func_count: 0,
136            hessian_func_count: 0,
137            jacobian_func_count: 0,
138            modify_func_count: 0,
139            time: std::time::Duration::new(0, 0),
140            termination_reason: TerminationReason::NotTerminated,
141        }
142    }
143
144    /// Set parameter vector. This shifts the stored parameter vector to the previous parameter
145    /// vector.
146    pub fn param(&mut self, param: O::Param) -> &mut Self {
147        std::mem::swap(&mut self.prev_param, &mut self.param);
148        self.param = param;
149        self
150    }
151
152    /// Set best paramater vector. This shifts the stored best parameter vector to the previous
153    /// best parameter vector.
154    pub fn best_param(&mut self, param: O::Param) -> &mut Self {
155        std::mem::swap(&mut self.prev_best_param, &mut self.best_param);
156        self.best_param = param;
157        self
158    }
159
160    /// Set the current cost function value. This shifts the stored cost function value to the
161    /// previous cost function value.
162    pub fn cost(&mut self, cost: O::Float) -> &mut Self {
163        std::mem::swap(&mut self.prev_cost, &mut self.cost);
164        self.cost = cost;
165        self
166    }
167
168    /// Set the current best cost function value. This shifts the stored best cost function value to
169    /// the previous cost function value.
170    pub fn best_cost(&mut self, cost: O::Float) -> &mut Self {
171        std::mem::swap(&mut self.prev_best_cost, &mut self.best_cost);
172        self.best_cost = cost;
173        self
174    }
175
176    /// Set gradient. This shifts the stored gradient to the previous gradient.
177    pub fn grad(&mut self, grad: O::Param) -> &mut Self {
178        std::mem::swap(&mut self.prev_grad, &mut self.grad);
179        self.grad = Some(grad);
180        self
181    }
182
183    /// Set Hessian. This shifts the stored Hessian to the previous Hessian.
184    pub fn hessian(&mut self, hessian: O::Hessian) -> &mut Self {
185        std::mem::swap(&mut self.prev_hessian, &mut self.hessian);
186        self.hessian = Some(hessian);
187        self
188    }
189
190    /// Set Jacobian. This shifts the stored Jacobian to the previous Jacobian.
191    pub fn jacobian(&mut self, jacobian: O::Jacobian) -> &mut Self {
192        std::mem::swap(&mut self.prev_jacobian, &mut self.jacobian);
193        self.jacobian = Some(jacobian);
194        self
195    }
196
197    /// Set population
198    pub fn population(&mut self, population: Vec<(O::Param, O::Float)>) -> &mut Self {
199        self.population = Some(population);
200        self
201    }
202
203    setter!(target_cost, O::Float, "Set target cost value");
204    setter!(max_iters, u64, "Set maximum number of iterations");
205    setter!(
206        last_best_iter,
207        u64,
208        "Set iteration number where the previous best parameter vector was found"
209    );
210    setter!(
211        termination_reason,
212        TerminationReason,
213        "Set termination_reason"
214    );
215    setter!(time, std::time::Duration, "Set time required so far");
216    getter!(param, O::Param, "Returns current parameter vector");
217    getter!(prev_param, O::Param, "Returns previous parameter vector");
218    getter!(best_param, O::Param, "Returns best parameter vector");
219    getter!(
220        prev_best_param,
221        O::Param,
222        "Returns previous best parameter vector"
223    );
224    getter!(cost, O::Float, "Returns current cost function value");
225    getter!(prev_cost, O::Float, "Returns previous cost function value");
226    getter!(
227        best_cost,
228        O::Float,
229        "Returns current best cost function value"
230    );
231    getter!(
232        prev_best_cost,
233        O::Float,
234        "Returns previous best cost function value"
235    );
236    getter!(target_cost, O::Float, "Returns target cost");
237    getter!(
238        cost_func_count,
239        u64,
240        "Returns current cost function evaluation count"
241    );
242    getter!(
243        grad_func_count,
244        u64,
245        "Returns current gradient function evaluation count"
246    );
247    getter!(
248        hessian_func_count,
249        u64,
250        "Returns current Hessian function evaluation count"
251    );
252    getter!(
253        jacobian_func_count,
254        u64,
255        "Returns current Jacobian function evaluation count"
256    );
257    getter!(
258        modify_func_count,
259        u64,
260        "Returns current Modify function evaluation count"
261    );
262    getter!(
263        last_best_iter,
264        u64,
265        "Returns iteration number where the last best parameter vector was found"
266    );
267    getter!(
268        termination_reason,
269        TerminationReason,
270        "Get termination_reason"
271    );
272    getter!(time, std::time::Duration, "Get time required so far");
273    getter_option!(grad, O::Param, "Returns gradient");
274    getter_option!(prev_grad, O::Param, "Returns previous gradient");
275    getter_option!(hessian, O::Hessian, "Returns current Hessian");
276    getter_option!(prev_hessian, O::Hessian, "Returns previous Hessian");
277    getter_option!(jacobian, O::Jacobian, "Returns current Jacobian");
278    getter_option!(prev_jacobian, O::Jacobian, "Returns previous Jacobian");
279    getter!(iter, u64, "Returns current number of iterations");
280    getter!(max_iters, u64, "Returns maximum number of iterations");
281
282    /// Returns population
283    pub fn get_population(&self) -> Option<&Vec<(O::Param, O::Float)>> {
284        match &self.population {
285            Some(population) => Some(&population),
286            None => None,
287        }
288    }
289
290    /// Increment the number of iterations by one
291    pub fn increment_iter(&mut self) {
292        self.iter += 1;
293    }
294
295    /// Increment all function evaluation counts by the evaluation counts of another operator
296    /// wrapped in `OpWrapper`.
297    pub fn increment_func_counts(&mut self, op: &OpWrapper<O>) {
298        self.cost_func_count += op.cost_func_count;
299        self.grad_func_count += op.grad_func_count;
300        self.hessian_func_count += op.hessian_func_count;
301        self.jacobian_func_count += op.jacobian_func_count;
302        self.modify_func_count += op.modify_func_count;
303    }
304
305    /// Set all function evaluation counts to the evaluation counts of another operator
306    /// wrapped in `OpWrapper`.
307    pub fn set_func_counts(&mut self, op: &OpWrapper<O>) {
308        self.cost_func_count = op.cost_func_count;
309        self.grad_func_count = op.grad_func_count;
310        self.hessian_func_count = op.hessian_func_count;
311        self.jacobian_func_count = op.jacobian_func_count;
312        self.modify_func_count = op.modify_func_count;
313    }
314
315    /// Increment cost function evaluation count by `num`
316    pub fn increment_cost_func_count(&mut self, num: u64) {
317        self.cost_func_count += num;
318    }
319
320    /// Increment gradient function evaluation count by `num`
321    pub fn increment_grad_func_count(&mut self, num: u64) {
322        self.grad_func_count += num;
323    }
324
325    /// Increment Hessian function evaluation count by `num`
326    pub fn increment_hessian_func_count(&mut self, num: u64) {
327        self.hessian_func_count += num;
328    }
329
330    /// Increment Jacobian function evaluation count by `num`
331    pub fn increment_jacobian_func_count(&mut self, num: u64) {
332        self.jacobian_func_count += num;
333    }
334
335    /// Increment modify function evaluation count by `num`
336    pub fn increment_modify_func_count(&mut self, num: u64) {
337        self.modify_func_count += num;
338    }
339
340    /// Indicate that a new best parameter vector was found
341    pub fn new_best(&mut self) {
342        self.last_best_iter = self.iter;
343    }
344
345    /// Returns whether the current parameter vector is also the best parameter vector found so
346    /// far.
347    pub fn is_best(&self) -> bool {
348        self.last_best_iter == self.iter
349    }
350
351    /// Return whether the algorithm has terminated or not
352    pub fn terminated(&self) -> bool {
353        self.termination_reason.terminated()
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;

    /// Exercises the defaults of `IterState::new`, the shifting setters and the counters.
    #[test]
    fn test_iterstate() {
        let param = vec![1.0f64, 2.0];
        let cost: f64 = 42.0;

        let mut state: IterState<MinimalNoOperator> = IterState::new(param.clone());

        // Defaults established by `IterState::new`.
        assert_eq!(state.get_param(), param);
        assert_eq!(state.get_prev_param(), param);
        assert_eq!(state.get_best_param(), param);
        assert_eq!(state.get_prev_best_param(), param);
        assert_eq!(state.get_cost(), std::f64::INFINITY);
        assert_eq!(state.get_prev_cost(), std::f64::INFINITY);
        assert_eq!(state.get_best_cost(), std::f64::INFINITY);
        assert_eq!(state.get_prev_best_cost(), std::f64::INFINITY);
        assert_eq!(state.get_target_cost(), std::f64::NEG_INFINITY);
        assert_eq!(state.get_grad(), None);
        assert_eq!(state.get_prev_grad(), None);
        assert_eq!(state.get_hessian(), None);
        assert_eq!(state.get_prev_hessian(), None);
        assert_eq!(state.get_jacobian(), None);
        assert_eq!(state.get_prev_jacobian(), None);
        assert_eq!(state.get_iter(), 0);
        assert!(state.is_best());
        assert_eq!(state.get_max_iters(), std::u64::MAX);
        assert_eq!(state.get_cost_func_count(), 0);
        assert_eq!(state.get_grad_func_count(), 0);
        assert_eq!(state.get_hessian_func_count(), 0);
        assert_eq!(state.get_jacobian_func_count(), 0);
        assert_eq!(state.get_modify_func_count(), 0);

        state.max_iters(42);

        assert_eq!(state.get_max_iters(), 42);

        // Shifting setters for cost values.
        state.cost(cost);

        assert_eq!(state.get_cost(), cost);
        assert_eq!(state.get_prev_cost(), std::f64::INFINITY);

        state.best_cost(cost);

        assert_eq!(state.get_best_cost(), cost);
        assert_eq!(state.get_prev_best_cost(), std::f64::INFINITY);

        // Shifting setters for parameter vectors.
        let new_param = vec![2.0, 1.0];

        state.param(new_param.clone());

        assert_eq!(state.get_param(), new_param);
        assert_eq!(state.get_prev_param(), param);

        state.best_param(new_param.clone());

        assert_eq!(state.get_best_param(), new_param);
        assert_eq!(state.get_prev_best_param(), param);

        let new_cost = 21.0;

        state.cost(new_cost);

        assert_eq!(state.get_cost(), new_cost);
        assert_eq!(state.get_prev_cost(), cost);

        state.best_cost(new_cost);

        assert_eq!(state.get_best_cost(), new_cost);
        assert_eq!(state.get_prev_best_cost(), cost);

        // Iteration counter and best-iteration bookkeeping.
        state.increment_iter();

        assert_eq!(state.get_iter(), 1);
        assert!(!state.is_best());

        state.new_best();

        assert!(state.is_best());

        // Optional derivative information.
        let grad = vec![1.0, 2.0];

        state.grad(grad.clone());
        assert_eq!(state.get_grad(), Some(grad.clone()));
        assert_eq!(state.get_prev_grad(), None);

        let new_grad = vec![2.0, 1.0];

        state.grad(new_grad.clone());

        assert_eq!(state.get_grad(), Some(new_grad.clone()));
        assert_eq!(state.get_prev_grad(), Some(grad.clone()));

        let hessian = vec![vec![1.0, 2.0], vec![2.0, 1.0]];

        state.hessian(hessian.clone());
        assert_eq!(state.get_hessian(), Some(hessian.clone()));
        assert_eq!(state.get_prev_hessian(), None);

        let new_hessian = vec![vec![2.0, 1.0], vec![1.0, 2.0]];

        state.hessian(new_hessian.clone());

        assert_eq!(state.get_hessian(), Some(new_hessian.clone()));
        assert_eq!(state.get_prev_hessian(), Some(hessian.clone()));

        let jacobian = vec![1.0, 2.0];

        state.jacobian(jacobian.clone());
        assert_eq!(state.get_jacobian(), Some(jacobian.clone()));
        assert_eq!(state.get_prev_jacobian(), None);

        let new_jacobian = vec![2.0, 1.0];

        state.jacobian(new_jacobian.clone());

        assert_eq!(state.get_jacobian(), Some(new_jacobian.clone()));
        assert_eq!(state.get_prev_jacobian(), Some(jacobian.clone()));

        state.increment_iter();

        assert_eq!(state.get_iter(), 2);
        assert_eq!(state.get_last_best_iter(), 1);
        assert!(!state.is_best());

        // Evaluation counters.
        state.increment_cost_func_count(42);
        assert_eq!(state.get_cost_func_count(), 42);
        state.increment_grad_func_count(43);
        assert_eq!(state.get_grad_func_count(), 43);
        state.increment_hessian_func_count(44);
        assert_eq!(state.get_hessian_func_count(), 44);
        state.increment_jacobian_func_count(46);
        assert_eq!(state.get_jacobian_func_count(), 46);
        state.increment_modify_func_count(45);
        assert_eq!(state.get_modify_func_count(), 45);

        // The counter updates above must not have disturbed any other field.
        assert_eq!(state.get_iter(), 2);
        assert_eq!(state.get_last_best_iter(), 1);
        assert_eq!(state.get_max_iters(), 42);
        assert!(!state.is_best());
        assert_eq!(state.get_cost(), new_cost);
        assert_eq!(state.get_prev_cost(), cost);
        assert_eq!(state.get_param(), new_param);
        assert_eq!(state.get_prev_param(), param);
        assert_eq!(state.get_best_cost(), new_cost);
        assert_eq!(state.get_prev_best_cost(), cost);
        assert_eq!(state.get_best_param(), new_param);
        assert_eq!(state.get_prev_best_param(), param);
        assert_eq!(state.get_grad(), Some(new_grad));
        assert_eq!(state.get_prev_grad(), Some(grad));
        assert_eq!(state.get_hessian(), Some(new_hessian));
        assert_eq!(state.get_prev_hessian(), Some(hessian));
        assert_eq!(state.get_jacobian(), Some(new_jacobian));
        assert_eq!(state.get_prev_jacobian(), Some(jacobian));
        assert_eq!(state.get_cost_func_count(), 42);
        assert_eq!(state.get_grad_func_count(), 43);
        assert_eq!(state.get_hessian_func_count(), 44);
        assert_eq!(state.get_jacobian_func_count(), 46);
        assert_eq!(state.get_modify_func_count(), 45);
    }

    /// Covers the population accessors and the plain (non-shifting) setters.
    #[test]
    fn test_iterstate_population_and_plain_setters() {
        let param = vec![1.0f64, 2.0];
        let mut state: IterState<MinimalNoOperator> = IterState::new(param.clone());

        assert_eq!(state.get_population(), None);

        let population = vec![(param.clone(), 3.0), (vec![0.0, 0.0], 1.0)];
        state.population(population.clone());
        assert_eq!(state.get_population(), Some(&population));

        // A new population replaces the old one.
        let other_population = vec![(vec![5.0, 5.0], 0.5)];
        state.population(other_population.clone());
        assert_eq!(state.get_population(), Some(&other_population));

        state.target_cost(1e-6);
        assert_eq!(state.get_target_cost(), 1e-6);

        state.last_best_iter(5);
        assert_eq!(state.get_last_best_iter(), 5);

        state.time(std::time::Duration::from_secs(3));
        assert_eq!(state.get_time(), std::time::Duration::from_secs(3));
    }
}