argmin/core/
mod.rs

// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
//! Argmin Optimization toolbox core
9//!
10//! This crate contains the core functionality of argmin. If you just want to run an optimization
11//! method, this is *not* what you are looking for. However, if you want to implement your own
12//! solver based on the argmin architecture, you should find all necessary tools here.
13
14// I really do not like the a..=b syntax
15#![allow(clippy::range_plus_one)]
16
17/// Macros
18#[macro_use]
19pub mod macros;
20/// Error handling
21mod errors;
22/// Executor
23pub mod executor;
24/// iteration state
25mod iterstate;
26/// Key value datastructure
27mod kv;
28/// Math utilities
29mod math;
30/// Phony Operator
31// #[cfg(test)]
32mod nooperator;
/// Observers
34mod observers;
35/// Wrapper around operators which keeps track of function evaluation counts
36mod opwrapper;
37/// Definition of the return type of the solvers
38mod result;
39/// Serialization of `ArgminSolver`s
40mod serialization;
41/// Definition of termination reasons
42mod termination;
43
44pub use anyhow::Error;
45pub use errors::*;
46pub use executor::*;
47pub use iterstate::*;
48pub use kv::ArgminKV;
49pub use math::*;
50pub use nooperator::*;
51use num::traits::{Float, FloatConst, FromPrimitive, ToPrimitive};
52pub use observers::*;
53pub use opwrapper::*;
54pub use result::ArgminResult;
55use serde::de::DeserializeOwned;
56use serde::Serialize;
57pub use serialization::*;
58use std::fmt::{Debug, Display};
59pub use termination::TerminationReason;
60
/// Trait alias to simplify common trait bounds
///
/// This is automatically implemented for every type which satisfies all of the listed bounds
/// (see the blanket implementation below), e.g. `f32` and `f64`.
pub trait ArgminFloat:
    Float + FloatConst + FromPrimitive + ToPrimitive + Debug + Display + Serialize + DeserializeOwned
{
}
66impl<I> ArgminFloat for I where
67    I: Float
68        + FloatConst
69        + FromPrimitive
70        + ToPrimitive
71        + Debug
72        + Display
73        + Serialize
74        + DeserializeOwned
75{
76}
77
/// This trait needs to be implemented for every operator/cost function.
///
/// It is required to implement the `apply` method, all others are optional and provide a default
/// implementation which is essentially returning an error which indicates that the method has not
/// been implemented. Those methods (`gradient` and `modify`) only need to be implemented if the
/// used solver requires it.
pub trait ArgminOp {
    // TODO: Once associated type defaults are stable, it hopefully will be possible to define
    // default types for `Hessian` and `Jacobian`.
    /// Type of the parameter vector
    type Param: Clone + Serialize + DeserializeOwned;
    /// Output of the operator
    type Output: Clone + Serialize + DeserializeOwned;
    /// Type of Hessian
    type Hessian: Clone + Serialize + DeserializeOwned;
    /// Type of Jacobian
    type Jacobian: Clone + Serialize + DeserializeOwned;
    /// Precision of floats
    type Float: ArgminFloat;

    /// Applies the operator/cost function to parameters
    ///
    /// The default implementation returns an `ArgminError::NotImplemented` error.
    fn apply(&self, _param: &Self::Param) -> Result<Self::Output, Error> {
        Err(ArgminError::NotImplemented {
            text: "Method `apply` of ArgminOp trait not implemented!".to_string(),
        }
        .into())
    }

    /// Computes the gradient at the given parameters
    ///
    /// The default implementation returns an `ArgminError::NotImplemented` error.
    fn gradient(&self, _param: &Self::Param) -> Result<Self::Param, Error> {
        Err(ArgminError::NotImplemented {
            text: "Method `gradient` of ArgminOp trait not implemented!".to_string(),
        }
        .into())
    }

    /// Computes the Hessian at the given parameters
    ///
    /// The default implementation returns an `ArgminError::NotImplemented` error.
    fn hessian(&self, _param: &Self::Param) -> Result<Self::Hessian, Error> {
        Err(ArgminError::NotImplemented {
            text: "Method `hessian` of ArgminOp trait not implemented!".to_string(),
        }
        .into())
    }

    /// Computes the Jacobian at the given parameters
    ///
    /// The default implementation returns an `ArgminError::NotImplemented` error.
    fn jacobian(&self, _param: &Self::Param) -> Result<Self::Jacobian, Error> {
        Err(ArgminError::NotImplemented {
            text: "Method `jacobian` of ArgminOp trait not implemented!".to_string(),
        }
        .into())
    }

    /// Modifies a parameter vector. Comes with a variable that indicates the "degree" of the
    /// modification.
    ///
    /// The default implementation returns an `ArgminError::NotImplemented` error.
    fn modify(&self, _param: &Self::Param, _extent: Self::Float) -> Result<Self::Param, Error> {
        Err(ArgminError::NotImplemented {
            text: "Method `modify` of ArgminOp trait not implemented!".to_string(),
        }
        .into())
    }
}
139
/// Solver
///
/// Every solver needs to implement this trait.
pub trait Solver<O: ArgminOp>: Serialize {
    /// Name of the solver
    const NAME: &'static str = "UNDEFINED";

    /// Computes one iteration of the algorithm.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error>;

    /// Initializes the algorithm
    ///
    /// This is executed before any iterations are performed. It can be used to perform
    /// precomputations. The default implementation corresponds to doing nothing.
    fn init(
        &mut self,
        _op: &mut OpWrapper<O>,
        _state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        Ok(None)
    }

    /// Checks whether basic termination reasons apply.
    ///
    /// Terminate if
    ///
    /// 1) algorithm was terminated somewhere else in the Executor
    /// 2) iteration count exceeds maximum number of iterations
    /// 3) cost is lower than target cost
    ///
    /// This can be overwritten in a `Solver` implementation; however it is not advised.
    fn terminate_internal(&mut self, state: &IterState<O>) -> TerminationReason {
        // Solver-specific criteria (via `terminate`) take precedence over the
        // generic iteration-count and target-cost checks below.
        let solver_terminate = self.terminate(state);
        if solver_terminate.terminated() {
            return solver_terminate;
        }
        if state.get_iter() >= state.get_max_iters() {
            return TerminationReason::MaxItersReached;
        }
        if state.get_cost() <= state.get_target_cost() {
            return TerminationReason::TargetCostReached;
        }
        TerminationReason::NotTerminated
    }

    /// Checks whether the algorithm must be terminated
    ///
    /// Solver-specific termination criteria belong here; the default implementation
    /// never terminates.
    fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
        TerminationReason::NotTerminated
    }
}
194
/// The datastructure which is returned by the `next_iter` method of the `Solver` trait.
///
/// All fields are optional; solvers only set what they computed in the current iteration.
///
/// TODO: Rename to IterResult?
#[derive(Clone, Debug, Default)]
pub struct ArgminIterData<O: ArgminOp> {
    /// Current parameter vector
    param: Option<O::Param>,
    /// Current cost function value
    cost: Option<O::Float>,
    /// Current gradient
    grad: Option<O::Param>,
    /// Current Hessian
    hessian: Option<O::Hessian>,
    /// Current Jacobian
    jacobian: Option<O::Jacobian>,
    /// Current population (parameter vectors with their cost values, for population-based solvers)
    population: Option<Vec<(O::Param, O::Float)>>,
    /// Termination reason
    termination_reason: Option<TerminationReason>,
    /// Key value pairs which are used to provide additional information for the Observers
    kv: ArgminKV,
}
217
218// TODO: Many clones are necessary in the getters.. maybe a complete "deconstruct" method would be
219// better?
220impl<O: ArgminOp> ArgminIterData<O> {
221    /// Constructor
222    pub fn new() -> Self {
223        ArgminIterData {
224            param: None,
225            cost: None,
226            grad: None,
227            hessian: None,
228            jacobian: None,
229            termination_reason: None,
230            population: None,
231            kv: make_kv!(),
232        }
233    }
234
235    /// Set parameter vector
236    pub fn param(mut self, param: O::Param) -> Self {
237        self.param = Some(param);
238        self
239    }
240
241    /// Set cost function value
242    pub fn cost(mut self, cost: O::Float) -> Self {
243        self.cost = Some(cost);
244        self
245    }
246
247    /// Set gradient
248    pub fn grad(mut self, grad: O::Param) -> Self {
249        self.grad = Some(grad);
250        self
251    }
252
253    /// Set Hessian
254    pub fn hessian(mut self, hessian: O::Hessian) -> Self {
255        self.hessian = Some(hessian);
256        self
257    }
258
259    /// Set Jacobian
260    pub fn jacobian(mut self, jacobian: O::Jacobian) -> Self {
261        self.jacobian = Some(jacobian);
262        self
263    }
264
265    /// Set Population
266    pub fn population(mut self, population: Vec<(O::Param, O::Float)>) -> Self {
267        self.population = Some(population);
268        self
269    }
270
271    /// Adds an `ArgminKV`
272    pub fn kv(mut self, kv: ArgminKV) -> Self {
273        self.kv = kv;
274        self
275    }
276
277    /// Set termination reason
278    pub fn termination_reason(mut self, reason: TerminationReason) -> Self {
279        self.termination_reason = Some(reason);
280        self
281    }
282
283    /// Get parameter vector
284    pub fn get_param(&self) -> Option<O::Param> {
285        self.param.clone()
286    }
287
288    /// Get cost function value
289    pub fn get_cost(&self) -> Option<O::Float> {
290        self.cost
291    }
292
293    /// Get gradient
294    pub fn get_grad(&self) -> Option<O::Param> {
295        self.grad.clone()
296    }
297
298    /// Get Hessian
299    pub fn get_hessian(&self) -> Option<O::Hessian> {
300        self.hessian.clone()
301    }
302
303    /// Get Jacobian
304    pub fn get_jacobian(&self) -> Option<O::Jacobian> {
305        self.jacobian.clone()
306    }
307
308    /// Get reference to population
309    pub fn get_population(&self) -> Option<&Vec<(O::Param, O::Float)>> {
310        match &self.population {
311            Some(population) => Some(&population),
312            None => None,
313        }
314    }
315
316    /// Get termination reason
317    pub fn get_termination_reason(&self) -> Option<TerminationReason> {
318        self.termination_reason
319    }
320
321    /// Return KV
322    pub fn get_kv(&self) -> ArgminKV {
323        self.kv.clone()
324    }
325}
326
/// Defines a common interface for line search methods.
pub trait ArgminLineSearch<P, F>: Serialize {
    /// Set the search direction along which the line search is performed
    fn set_search_direction(&mut self, direction: P);

    /// Set the initial step length
    fn set_init_alpha(&mut self, step_length: F) -> Result<(), Error>;
}
335
/// Defines a common interface to methods which calculate approximate steps for trust region
/// methods.
pub trait ArgminTrustRegion<F>: Clone + Serialize {
    /// Set the radius of the trust region
    fn set_radius(&mut self, radius: F);
}
342//
/// Common interface for beta update methods (Nonlinear-CG)
pub trait ArgminNLCGBetaUpdate<T, F: ArgminFloat>: Serialize {
    /// Update beta
    ///
    /// # Arguments
    ///
    /// * `nabla_f_k`: gradient at the current iterate, \nabla f_k
    /// * `nabla_f_k_p_1`: gradient at the next iterate, \nabla f_{k+1}
    /// * `p_k`: current search direction, p_k
    fn update(&self, nabla_f_k: &T, nabla_f_k_p_1: &T, p_k: &T) -> F;
}