argmin/core/result.rs
1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # `ArgminResult`
9//!
10//! Returned by a solver and consists of the used operator and the last `IterState` of the solver.
11//! Both can be accessed by the methods `operator()` and `state()`.
12//!
//! Through the reference returned by `state()` one can, for instance, access the final
//! parameter vector or the final cost function value.
15//!
16//! ## Examples:
17//!
18//! ```
19//! # #![allow(unused_imports)]
20//! # extern crate argmin;
21//! # extern crate argmin_testfunctions;
22//! # use argmin::prelude::*;
23//! # use argmin::solver::gradientdescent::SteepestDescent;
24//! # use argmin::solver::linesearch::MoreThuenteLineSearch;
25//! # use argmin_testfunctions::{rosenbrock_2d, rosenbrock_2d_derivative};
26//! # use serde::{Deserialize, Serialize};
27//! #
28//! # #[derive(Clone, Default, Serialize, Deserialize)]
29//! # struct Rosenbrock {
30//! # a: f64,
31//! # b: f64,
32//! # }
33//! #
34//! # impl ArgminOp for Rosenbrock {
35//! # type Param = Vec<f64>;
36//! # type Output = f64;
37//! # type Hessian = ();
38//! # type Jacobian = ();
39//! # type Float = f64;
40//! #
41//! # fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
42//! # Ok(rosenbrock_2d(p, self.a, self.b))
43//! # }
44//! #
45//! # fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
46//! # Ok(rosenbrock_2d_derivative(p, self.a, self.b))
47//! # }
48//! # }
49//! #
50//! # fn run() -> Result<(), Error> {
//! # // Define cost function (must implement `ArgminOp`)
52//! # let cost = Rosenbrock { a: 1.0, b: 100.0 };
53//! # // Define initial parameter vector
54//! # let init_param: Vec<f64> = vec![-1.2, 1.0];
55//! # // Set up line search
56//! # let linesearch = MoreThuenteLineSearch::new();
57//! # // Set up solver
58//! # let solver = SteepestDescent::new(linesearch);
59//! # // Run solver
60//! # let result = Executor::new(cost, solver, init_param)
//! # // Set maximum iterations to 1
62//! # .max_iters(1)
63//! # // run the solver on the defined problem
64//! # .run()?;
65//! // Get best parameter vector
66//! let best_parameter = result.state().get_best_param();
67//!
68//! // Get best cost function value
69//! let best_cost = result.state().get_best_cost();
70//!
71//! // Get the number of iterations
72//! let num_iters = result.state().get_iter();
73//! # Ok(())
74//! # }
75//! #
76//! # fn main() {
77//! # if let Err(ref e) = run() {
78//! # println!("{}", e);
79//! # std::process::exit(1);
80//! # }
81//! # }
82//! ```
83//!
84//! More details can be found in the `IterState` documentation.
85
86use crate::prelude::*;
87use std::cmp::Ordering;
88
89/// Final struct returned by the `run` method of `Executor`.
90#[derive(Clone)]
91pub struct ArgminResult<O: ArgminOp> {
92 /// operator
93 pub operator: O,
94 /// iteration state
95 pub state: IterState<O>,
96}
97
98impl<O: ArgminOp> ArgminResult<O> {
99 /// Constructor
100 pub fn new(operator: O, state: IterState<O>) -> Self {
101 ArgminResult { operator, state }
102 }
103
104 /// Return handle to operator
105 pub fn operator(&self) -> &O {
106 &self.operator
107 }
108
109 /// Return handle to state
110 pub fn state(&self) -> &IterState<O> {
111 &self.state
112 }
113}
114
115impl<O> std::fmt::Display for ArgminResult<O>
116where
117 O: ArgminOp,
118 O::Param: std::fmt::Debug,
119{
120 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
121 writeln!(f, "ArgminResult:")?;
122 writeln!(f, " param (best): {:?}", self.state.get_best_param())?;
123 writeln!(f, " cost (best): {}", self.state.get_best_cost())?;
124 writeln!(f, " iters (best): {}", self.state.get_last_best_iter())?;
125 writeln!(f, " iters (total): {}", self.state.get_iter())?;
126 writeln!(
127 f,
128 " termination: {}",
129 self.state.get_termination_reason()
130 )?;
131 writeln!(f, " time: {:?}", self.state.get_time())?;
132 Ok(())
133 }
134}
135
136impl<O: ArgminOp> PartialEq for ArgminResult<O> {
137 fn eq(&self, other: &ArgminResult<O>) -> bool {
138 (self.state.get_cost() - other.state.get_cost()).abs() < O::Float::epsilon()
139 }
140}
141
142impl<O: ArgminOp> Eq for ArgminResult<O> {}
143
144impl<O: ArgminOp> Ord for ArgminResult<O> {
145 fn cmp(&self, other: &ArgminResult<O>) -> Ordering {
146 let t = self.state.get_cost() - other.state.get_cost();
147 if t.abs() < O::Float::epsilon() {
148 Ordering::Equal
149 } else if t > O::Float::from_f64(0.0).unwrap() {
150 Ordering::Greater
151 } else {
152 Ordering::Less
153 }
154 }
155}
156
157impl<O: ArgminOp> PartialOrd for ArgminResult<O> {
158 fn partial_cmp(&self, other: &ArgminResult<O>) -> Option<Ordering> {
159 Some(self.cmp(other))
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;

    // Instantiates the crate's `send_sync_test!` macro for `ArgminResult` with the trivial
    // `MinimalNoOperator`. Going by its name, this presumably checks that the type is `Send`
    // and `Sync`. NOTE(review): confirm against the macro definition.
    send_sync_test!(argmin_result, ArgminResult<MinimalNoOperator>);
}