1use crate::core::{ArgminOp, OpWrapper, TerminationReason};
9use num::traits::float::Float;
10use paste::item;
11use serde::{Deserialize, Serialize};
12
/// State carried between iterations of an optimization run.
///
/// Holds the current, previous, best and previous-best parameter vectors and cost values,
/// optional derivative information (gradient, Hessian, Jacobian) together with their previous
/// values, an optional population, iteration and function-evaluation counters, the elapsed
/// time and the termination reason.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IterState<O: ArgminOp> {
    /// Current parameter vector.
    pub param: O::Param,
    /// Previous parameter vector (the value `param` held before it was last set via the
    /// `param` setter method).
    pub prev_param: O::Param,
    /// Best parameter vector.
    pub best_param: O::Param,
    /// Previous best parameter vector.
    pub prev_best_param: O::Param,
    /// Current cost function value (positive infinity after `IterState::new`).
    pub cost: O::Float,
    /// Previous cost function value.
    pub prev_cost: O::Float,
    /// Current best cost function value (positive infinity after `IterState::new`).
    pub best_cost: O::Float,
    /// Previous best cost function value.
    pub prev_best_cost: O::Float,
    /// Target cost value (negative infinity after `IterState::new`).
    /// NOTE(review): presumably consulted elsewhere to stop early once reached — not visible here.
    pub target_cost: O::Float,
    /// Current gradient, `None` until one has been set.
    pub grad: Option<O::Param>,
    /// Previous gradient.
    pub prev_grad: Option<O::Param>,
    /// Current Hessian, `None` until one has been set.
    pub hessian: Option<O::Hessian>,
    /// Previous Hessian.
    pub prev_hessian: Option<O::Hessian>,
    /// Current Jacobian, `None` until one has been set.
    pub jacobian: Option<O::Jacobian>,
    /// Previous Jacobian.
    pub prev_jacobian: Option<O::Jacobian>,
    /// Optional population of `(parameter vector, value)` pairs.
    /// NOTE(review): the float is presumably the cost of each member — confirm with callers.
    pub population: Option<Vec<(O::Param, O::Float)>>,
    /// Current number of iterations.
    pub iter: u64,
    /// Iteration number where the last best parameter vector was found.
    pub last_best_iter: u64,
    /// Maximum number of iterations (`u64::MAX` after `IterState::new`).
    pub max_iters: u64,
    /// Number of cost function evaluations so far.
    pub cost_func_count: u64,
    /// Number of gradient evaluations so far.
    pub grad_func_count: u64,
    /// Number of Hessian evaluations so far.
    pub hessian_func_count: u64,
    /// Number of Jacobian evaluations so far.
    pub jacobian_func_count: u64,
    /// Number of Modify function evaluations so far.
    pub modify_func_count: u64,
    /// Time required so far.
    pub time: std::time::Duration,
    /// Reason for termination (`TerminationReason::NotTerminated` after `IterState::new`).
    pub termination_reason: TerminationReason,
}
69
/// Expands to a chainable setter `fn <field>(&mut self, value: <type>) -> &mut Self`.
///
/// The generated method overwrites the field of the same name and hands `self` back so
/// calls can be chained. The third argument becomes the method's doc string.
macro_rules! setter {
    ($field:ident, $ty:ty, $doc:tt) => {
        #[doc=$doc]
        pub fn $field(&mut self, value: $ty) -> &mut Self {
            self.$field = value;
            self
        }
    };
}
79
/// Expands to `fn get_<field>(&self) -> Option<<type>>` for an `Option`-typed field.
///
/// The generated method returns an owned copy of the field's contents (`None` stays
/// `None`) and leaves the field itself untouched. The third argument becomes the
/// method's doc string.
macro_rules! getter_option {
    ($field:ident, $ty:ty, $doc:tt) => {
        item! {
            #[doc=$doc]
            pub fn [<get_ $field>](&self) -> Option<$ty> {
                self.$field.as_ref().cloned()
            }
        }
    };
}
90
/// Expands to `fn get_<field>(&self) -> <type>`, returning a clone of the field.
///
/// The identifier `get_<field>` is assembled with `paste`'s `[< >]` concatenation.
/// The third argument becomes the method's doc string.
macro_rules! getter {
    ($field:ident, $ty:ty, $doc:tt) => {
        item! {
            #[doc=$doc]
            pub fn [<get_ $field>](&self) -> $ty {
                self.$field.clone()
            }
        }
    };
}
101
102impl<O: ArgminOp> std::default::Default for IterState<O>
103where
104 O::Param: Default,
105{
106 fn default() -> Self {
107 IterState::new(O::Param::default())
108 }
109}
110
111impl<O: ArgminOp> IterState<O> {
112 pub fn new(param: O::Param) -> Self {
114 IterState {
115 param: param.clone(),
116 prev_param: param.clone(),
117 best_param: param.clone(),
118 prev_best_param: param,
119 cost: O::Float::infinity(),
120 prev_cost: O::Float::infinity(),
121 best_cost: O::Float::infinity(),
122 prev_best_cost: O::Float::infinity(),
123 target_cost: O::Float::neg_infinity(),
124 grad: None,
125 prev_grad: None,
126 hessian: None,
127 prev_hessian: None,
128 jacobian: None,
129 prev_jacobian: None,
130 population: None,
131 iter: 0,
132 last_best_iter: 0,
133 max_iters: std::u64::MAX,
134 cost_func_count: 0,
135 grad_func_count: 0,
136 hessian_func_count: 0,
137 jacobian_func_count: 0,
138 modify_func_count: 0,
139 time: std::time::Duration::new(0, 0),
140 termination_reason: TerminationReason::NotTerminated,
141 }
142 }
143
144 pub fn param(&mut self, param: O::Param) -> &mut Self {
147 std::mem::swap(&mut self.prev_param, &mut self.param);
148 self.param = param;
149 self
150 }
151
152 pub fn best_param(&mut self, param: O::Param) -> &mut Self {
155 std::mem::swap(&mut self.prev_best_param, &mut self.best_param);
156 self.best_param = param;
157 self
158 }
159
160 pub fn cost(&mut self, cost: O::Float) -> &mut Self {
163 std::mem::swap(&mut self.prev_cost, &mut self.cost);
164 self.cost = cost;
165 self
166 }
167
168 pub fn best_cost(&mut self, cost: O::Float) -> &mut Self {
171 std::mem::swap(&mut self.prev_best_cost, &mut self.best_cost);
172 self.best_cost = cost;
173 self
174 }
175
176 pub fn grad(&mut self, grad: O::Param) -> &mut Self {
178 std::mem::swap(&mut self.prev_grad, &mut self.grad);
179 self.grad = Some(grad);
180 self
181 }
182
183 pub fn hessian(&mut self, hessian: O::Hessian) -> &mut Self {
185 std::mem::swap(&mut self.prev_hessian, &mut self.hessian);
186 self.hessian = Some(hessian);
187 self
188 }
189
190 pub fn jacobian(&mut self, jacobian: O::Jacobian) -> &mut Self {
192 std::mem::swap(&mut self.prev_jacobian, &mut self.jacobian);
193 self.jacobian = Some(jacobian);
194 self
195 }
196
197 pub fn population(&mut self, population: Vec<(O::Param, O::Float)>) -> &mut Self {
199 self.population = Some(population);
200 self
201 }
202
203 setter!(target_cost, O::Float, "Set target cost value");
204 setter!(max_iters, u64, "Set maximum number of iterations");
205 setter!(
206 last_best_iter,
207 u64,
208 "Set iteration number where the previous best parameter vector was found"
209 );
210 setter!(
211 termination_reason,
212 TerminationReason,
213 "Set termination_reason"
214 );
215 setter!(time, std::time::Duration, "Set time required so far");
216 getter!(param, O::Param, "Returns current parameter vector");
217 getter!(prev_param, O::Param, "Returns previous parameter vector");
218 getter!(best_param, O::Param, "Returns best parameter vector");
219 getter!(
220 prev_best_param,
221 O::Param,
222 "Returns previous best parameter vector"
223 );
224 getter!(cost, O::Float, "Returns current cost function value");
225 getter!(prev_cost, O::Float, "Returns previous cost function value");
226 getter!(
227 best_cost,
228 O::Float,
229 "Returns current best cost function value"
230 );
231 getter!(
232 prev_best_cost,
233 O::Float,
234 "Returns previous best cost function value"
235 );
236 getter!(target_cost, O::Float, "Returns target cost");
237 getter!(
238 cost_func_count,
239 u64,
240 "Returns current cost function evaluation count"
241 );
242 getter!(
243 grad_func_count,
244 u64,
245 "Returns current gradient function evaluation count"
246 );
247 getter!(
248 hessian_func_count,
249 u64,
250 "Returns current Hessian function evaluation count"
251 );
252 getter!(
253 jacobian_func_count,
254 u64,
255 "Returns current Jacobian function evaluation count"
256 );
257 getter!(
258 modify_func_count,
259 u64,
260 "Returns current Modify function evaluation count"
261 );
262 getter!(
263 last_best_iter,
264 u64,
265 "Returns iteration number where the last best parameter vector was found"
266 );
267 getter!(
268 termination_reason,
269 TerminationReason,
270 "Get termination_reason"
271 );
272 getter!(time, std::time::Duration, "Get time required so far");
273 getter_option!(grad, O::Param, "Returns gradient");
274 getter_option!(prev_grad, O::Param, "Returns previous gradient");
275 getter_option!(hessian, O::Hessian, "Returns current Hessian");
276 getter_option!(prev_hessian, O::Hessian, "Returns previous Hessian");
277 getter_option!(jacobian, O::Jacobian, "Returns current Jacobian");
278 getter_option!(prev_jacobian, O::Jacobian, "Returns previous Jacobian");
279 getter!(iter, u64, "Returns current number of iterations");
280 getter!(max_iters, u64, "Returns maximum number of iterations");
281
282 pub fn get_population(&self) -> Option<&Vec<(O::Param, O::Float)>> {
284 match &self.population {
285 Some(population) => Some(&population),
286 None => None,
287 }
288 }
289
290 pub fn increment_iter(&mut self) {
292 self.iter += 1;
293 }
294
295 pub fn increment_func_counts(&mut self, op: &OpWrapper<O>) {
298 self.cost_func_count += op.cost_func_count;
299 self.grad_func_count += op.grad_func_count;
300 self.hessian_func_count += op.hessian_func_count;
301 self.jacobian_func_count += op.jacobian_func_count;
302 self.modify_func_count += op.modify_func_count;
303 }
304
305 pub fn set_func_counts(&mut self, op: &OpWrapper<O>) {
308 self.cost_func_count = op.cost_func_count;
309 self.grad_func_count = op.grad_func_count;
310 self.hessian_func_count = op.hessian_func_count;
311 self.jacobian_func_count = op.jacobian_func_count;
312 self.modify_func_count = op.modify_func_count;
313 }
314
315 pub fn increment_cost_func_count(&mut self, num: u64) {
317 self.cost_func_count += num;
318 }
319
320 pub fn increment_grad_func_count(&mut self, num: u64) {
322 self.grad_func_count += num;
323 }
324
325 pub fn increment_hessian_func_count(&mut self, num: u64) {
327 self.hessian_func_count += num;
328 }
329
330 pub fn increment_jacobian_func_count(&mut self, num: u64) {
332 self.jacobian_func_count += num;
333 }
334
335 pub fn increment_modify_func_count(&mut self, num: u64) {
337 self.modify_func_count += num;
338 }
339
340 pub fn new_best(&mut self) {
342 self.last_best_iter = self.iter;
343 }
344
345 pub fn is_best(&self) -> bool {
348 self.last_best_iter == self.iter
349 }
350
351 pub fn terminated(&self) -> bool {
353 self.termination_reason.terminated()
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;

    /// Walks a state through parameter/cost/derivative updates and checks that every
    /// "previous" slot tracks the value that was replaced.
    #[test]
    fn test_iterstate() {
        let param = vec![1.0f64, 2.0];
        let cost: f64 = 42.0;

        let mut state: IterState<MinimalNoOperator> = IterState::new(param.clone());

        // Initial values set by `IterState::new`.
        assert_eq!(state.get_param(), param);
        assert_eq!(state.get_prev_param(), param);
        assert_eq!(state.get_best_param(), param);
        assert_eq!(state.get_prev_best_param(), param);
        assert_eq!(state.get_cost(), std::f64::INFINITY);
        assert_eq!(state.get_prev_cost(), std::f64::INFINITY);
        assert_eq!(state.get_best_cost(), std::f64::INFINITY);
        assert_eq!(state.get_prev_best_cost(), std::f64::INFINITY);
        assert_eq!(state.get_target_cost(), std::f64::NEG_INFINITY);
        assert_eq!(state.get_grad(), None);
        assert_eq!(state.get_prev_grad(), None);
        assert_eq!(state.get_hessian(), None);
        assert_eq!(state.get_prev_hessian(), None);
        assert_eq!(state.get_jacobian(), None);
        assert_eq!(state.get_prev_jacobian(), None);
        assert_eq!(state.get_iter(), 0);
        assert!(state.is_best());
        assert_eq!(state.get_max_iters(), std::u64::MAX);
        assert_eq!(state.get_cost_func_count(), 0);
        assert_eq!(state.get_grad_func_count(), 0);
        assert_eq!(state.get_hessian_func_count(), 0);
        assert_eq!(state.get_jacobian_func_count(), 0);
        assert_eq!(state.get_modify_func_count(), 0);

        state.max_iters(42);

        assert_eq!(state.get_max_iters(), 42);

        state.cost(cost);

        assert_eq!(state.get_cost(), cost);
        assert_eq!(state.get_prev_cost(), std::f64::INFINITY);

        state.best_cost(cost);

        assert_eq!(state.get_best_cost(), cost);
        assert_eq!(state.get_prev_best_cost(), std::f64::INFINITY);

        let new_param = vec![2.0, 1.0];

        state.param(new_param.clone());

        assert_eq!(state.get_param(), new_param);
        assert_eq!(state.get_prev_param(), param);

        state.best_param(new_param.clone());

        assert_eq!(state.get_best_param(), new_param);
        assert_eq!(state.get_prev_best_param(), param);

        let new_cost = 21.0;

        state.cost(new_cost);

        assert_eq!(state.get_cost(), new_cost);
        assert_eq!(state.get_prev_cost(), cost);

        state.best_cost(new_cost);

        assert_eq!(state.get_best_cost(), new_cost);
        assert_eq!(state.get_prev_best_cost(), cost);

        state.increment_iter();

        assert_eq!(state.get_iter(), 1);

        assert!(!state.is_best());

        state.new_best();

        assert!(state.is_best());

        let grad = vec![1.0, 2.0];

        state.grad(grad.clone());
        assert_eq!(state.get_grad(), Some(grad.clone()));
        assert_eq!(state.get_prev_grad(), None);

        let new_grad = vec![2.0, 1.0];

        state.grad(new_grad.clone());

        assert_eq!(state.get_grad(), Some(new_grad.clone()));
        assert_eq!(state.get_prev_grad(), Some(grad.clone()));

        let hessian = vec![vec![1.0, 2.0], vec![2.0, 1.0]];

        state.hessian(hessian.clone());
        assert_eq!(state.get_hessian(), Some(hessian.clone()));
        assert_eq!(state.get_prev_hessian(), None);

        let new_hessian = vec![vec![2.0, 1.0], vec![1.0, 2.0]];

        state.hessian(new_hessian.clone());

        assert_eq!(state.get_hessian(), Some(new_hessian.clone()));
        assert_eq!(state.get_prev_hessian(), Some(hessian.clone()));

        let jacobian = vec![1.0, 2.0];

        state.jacobian(jacobian.clone());
        assert_eq!(state.get_jacobian(), Some(jacobian.clone()));
        assert_eq!(state.get_prev_jacobian(), None);

        let new_jacobian = vec![2.0, 1.0];

        state.jacobian(new_jacobian.clone());

        assert_eq!(state.get_jacobian(), Some(new_jacobian.clone()));
        assert_eq!(state.get_prev_jacobian(), Some(jacobian.clone()));

        state.increment_iter();

        assert_eq!(state.get_iter(), 2);
        assert_eq!(state.get_last_best_iter(), 1);
        assert!(!state.is_best());

        state.increment_cost_func_count(42);
        assert_eq!(state.get_cost_func_count(), 42);
        state.increment_grad_func_count(43);
        assert_eq!(state.get_grad_func_count(), 43);
        state.increment_hessian_func_count(44);
        assert_eq!(state.get_hessian_func_count(), 44);
        state.increment_jacobian_func_count(46);
        assert_eq!(state.get_jacobian_func_count(), 46);
        state.increment_modify_func_count(45);
        assert_eq!(state.get_modify_func_count(), 45);

        // Final consistency check: nothing above should have disturbed earlier updates.
        assert_eq!(state.get_iter(), 2);
        assert_eq!(state.get_last_best_iter(), 1);
        assert_eq!(state.get_max_iters(), 42);
        assert!(!state.is_best());
        assert_eq!(state.get_cost(), new_cost);
        assert_eq!(state.get_prev_cost(), cost);
        assert_eq!(state.get_param(), new_param);
        assert_eq!(state.get_prev_param(), param);
        assert_eq!(state.get_best_cost(), new_cost);
        assert_eq!(state.get_prev_best_cost(), cost);
        assert_eq!(state.get_best_param(), new_param);
        assert_eq!(state.get_prev_best_param(), param);
        assert_eq!(state.get_grad(), Some(new_grad));
        assert_eq!(state.get_prev_grad(), Some(grad));
        assert_eq!(state.get_hessian(), Some(new_hessian));
        assert_eq!(state.get_prev_hessian(), Some(hessian));
        assert_eq!(state.get_jacobian(), Some(new_jacobian));
        assert_eq!(state.get_prev_jacobian(), Some(jacobian));
        assert_eq!(state.get_cost_func_count(), 42);
        assert_eq!(state.get_grad_func_count(), 43);
        assert_eq!(state.get_hessian_func_count(), 44);
        assert_eq!(state.get_jacobian_func_count(), 46);
        assert_eq!(state.get_modify_func_count(), 45);
    }

    /// Covers the population, target-cost, time and last-best-iteration setters, which
    /// `test_iterstate` does not exercise.
    #[test]
    fn test_iterstate_population_target_cost_time() {
        let mut state: IterState<MinimalNoOperator> = IterState::new(vec![0.0f64, 0.0]);

        assert!(state.get_population().is_none());

        let population = vec![(vec![1.0f64, 2.0], 3.0f64), (vec![4.0, 5.0], 6.0)];
        state.population(population.clone());
        assert_eq!(state.get_population(), Some(&population));

        state.target_cost(0.5);
        assert_eq!(state.get_target_cost(), 0.5);

        let elapsed = std::time::Duration::from_millis(1500);
        state.time(elapsed);
        assert_eq!(state.get_time(), elapsed);

        state.increment_iter();
        state.increment_iter();
        assert!(!state.is_best());
        state.last_best_iter(2);
        assert_eq!(state.get_last_best_iter(), 2);
        assert!(state.is_best());
    }
}