// argmin/core/opwrapper.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{ArgminOp, Error};
use serde::{Deserialize, Serialize};
use std::default::Default;

12/// This wraps an operator and keeps track of how often the cost, gradient and Hessian have been
13/// computed and how often the modify function has been called. Usually, this is an implementation
14/// detail unless a solver is needed within another solver (such as a line search within a gradient
15/// descent method), then it may be necessary to wrap the operator in an OpWrapper.
16#[derive(Serialize, Deserialize, Clone, Debug, Default)]
17pub struct OpWrapper<O: ArgminOp> {
18    /// Operator
19    pub op: Option<O>,
20    /// Number of cost function evaluations
21    pub cost_func_count: u64,
22    /// Number of gradient function evaluations
23    pub grad_func_count: u64,
24    /// Number of Hessian function evaluations
25    pub hessian_func_count: u64,
26    /// Number of Jacobian function evaluations
27    pub jacobian_func_count: u64,
28    /// Number of `modify` function evaluations
29    pub modify_func_count: u64,
30}
31
32impl<O: ArgminOp> OpWrapper<O> {
33    /// Constructor
34    pub fn new(op: O) -> Self {
35        OpWrapper {
36            op: Some(op),
37            cost_func_count: 0,
38            grad_func_count: 0,
39            hessian_func_count: 0,
40            jacobian_func_count: 0,
41            modify_func_count: 0,
42        }
43    }
44
45    /// Construct struct from other `OpWrapper`. Takes the operator from `op` (replaces it with
46    /// `None`) and crates a new `OpWrapper`
47    pub fn new_from_wrapper(op: &mut OpWrapper<O>) -> Self {
48        OpWrapper {
49            op: op.take_op(),
50            cost_func_count: 0,
51            grad_func_count: 0,
52            hessian_func_count: 0,
53            jacobian_func_count: 0,
54            modify_func_count: 0,
55        }
56    }
57
58    /// Calls the `apply` method of `op` and increments `cost_func_count`.
59    pub fn apply(&mut self, param: &O::Param) -> Result<O::Output, Error> {
60        self.cost_func_count += 1;
61        self.op.as_ref().unwrap().apply(param)
62    }
63
64    /// Calls the `gradient` method of `op` and increments `gradient_func_count`.
65    pub fn gradient(&mut self, param: &O::Param) -> Result<O::Param, Error> {
66        self.grad_func_count += 1;
67        self.op.as_ref().unwrap().gradient(param)
68    }
69
70    /// Calls the `hessian` method of `op` and increments `hessian_func_count`.
71    pub fn hessian(&mut self, param: &O::Param) -> Result<O::Hessian, Error> {
72        self.hessian_func_count += 1;
73        self.op.as_ref().unwrap().hessian(param)
74    }
75
76    /// Calls the `jacobian` method of `op` and increments `jacobian_func_count`.
77    pub fn jacobian(&mut self, param: &O::Param) -> Result<O::Jacobian, Error> {
78        self.jacobian_func_count += 1;
79        self.op.as_ref().unwrap().jacobian(param)
80    }
81
82    /// Calls the `modify` method of `op` and increments `modify_func_count`.
83    pub fn modify(&mut self, param: &O::Param, extent: O::Float) -> Result<O::Param, Error> {
84        self.modify_func_count += 1;
85        self.op.as_ref().unwrap().modify(param, extent)
86    }
87
88    /// Moves the operator out of the struct and replaces it with `None`
89    pub fn take_op(&mut self) -> Option<O> {
90        self.op.take()
91    }
92
93    /// Consumes an operator by increasing the function call counts of `self` by the ones in
94    /// `other`.
95    pub fn consume_op(&mut self, other: OpWrapper<O>) {
96        self.op = other.op;
97        self.cost_func_count += other.cost_func_count;
98        self.grad_func_count += other.grad_func_count;
99        self.hessian_func_count += other.hessian_func_count;
100        self.jacobian_func_count += other.jacobian_func_count;
101        self.modify_func_count += other.modify_func_count;
102    }
103
104    /// Adds function evaluation counts of another operator.
105    pub fn consume_func_counts<O2: ArgminOp>(&mut self, other: OpWrapper<O2>) {
106        self.cost_func_count += other.cost_func_count;
107        self.grad_func_count += other.grad_func_count;
108        self.hessian_func_count += other.hessian_func_count;
109        self.jacobian_func_count += other.jacobian_func_count;
110        self.modify_func_count += other.modify_func_count;
111    }
112
113    /// Reset the cost function counts to zero.
114    pub fn reset(mut self) -> Self {
115        self.cost_func_count = 0;
116        self.grad_func_count = 0;
117        self.hessian_func_count = 0;
118        self.jacobian_func_count = 0;
119        self.modify_func_count = 0;
120        self
121    }
122
123    /// Returns the operator `op` by taking ownership of `self`.
124    pub fn get_op(self) -> O {
125        self.op.unwrap()
126    }
127}
128
129/// The OpWrapper<O> should behave just like any other `ArgminOp`
130impl<O: ArgminOp> ArgminOp for OpWrapper<O> {
131    type Param = O::Param;
132    type Output = O::Output;
133    type Hessian = O::Hessian;
134    type Jacobian = O::Jacobian;
135    type Float = O::Float;
136
137    fn apply(&self, param: &Self::Param) -> Result<Self::Output, Error> {
138        self.op.as_ref().unwrap().apply(param)
139    }
140
141    fn gradient(&self, param: &Self::Param) -> Result<Self::Param, Error> {
142        self.op.as_ref().unwrap().gradient(param)
143    }
144
145    fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, Error> {
146        self.op.as_ref().unwrap().hessian(param)
147    }
148
149    fn jacobian(&self, param: &Self::Param) -> Result<Self::Jacobian, Error> {
150        self.op.as_ref().unwrap().jacobian(param)
151    }
152
153    fn modify(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Param, Error> {
154        self.op.as_ref().unwrap().modify(param, extent)
155    }
156}