1use crate::core::{ArgminOp, Error};
9use serde::{Deserialize, Serialize};
10use std::default::Default;
11
12#[derive(Serialize, Deserialize, Clone, Debug, Default)]
17pub struct OpWrapper<O: ArgminOp> {
18 pub op: Option<O>,
20 pub cost_func_count: u64,
22 pub grad_func_count: u64,
24 pub hessian_func_count: u64,
26 pub jacobian_func_count: u64,
28 pub modify_func_count: u64,
30}
31
32impl<O: ArgminOp> OpWrapper<O> {
33 pub fn new(op: O) -> Self {
35 OpWrapper {
36 op: Some(op),
37 cost_func_count: 0,
38 grad_func_count: 0,
39 hessian_func_count: 0,
40 jacobian_func_count: 0,
41 modify_func_count: 0,
42 }
43 }
44
45 pub fn new_from_wrapper(op: &mut OpWrapper<O>) -> Self {
48 OpWrapper {
49 op: op.take_op(),
50 cost_func_count: 0,
51 grad_func_count: 0,
52 hessian_func_count: 0,
53 jacobian_func_count: 0,
54 modify_func_count: 0,
55 }
56 }
57
58 pub fn apply(&mut self, param: &O::Param) -> Result<O::Output, Error> {
60 self.cost_func_count += 1;
61 self.op.as_ref().unwrap().apply(param)
62 }
63
64 pub fn gradient(&mut self, param: &O::Param) -> Result<O::Param, Error> {
66 self.grad_func_count += 1;
67 self.op.as_ref().unwrap().gradient(param)
68 }
69
70 pub fn hessian(&mut self, param: &O::Param) -> Result<O::Hessian, Error> {
72 self.hessian_func_count += 1;
73 self.op.as_ref().unwrap().hessian(param)
74 }
75
76 pub fn jacobian(&mut self, param: &O::Param) -> Result<O::Jacobian, Error> {
78 self.jacobian_func_count += 1;
79 self.op.as_ref().unwrap().jacobian(param)
80 }
81
82 pub fn modify(&mut self, param: &O::Param, extent: O::Float) -> Result<O::Param, Error> {
84 self.modify_func_count += 1;
85 self.op.as_ref().unwrap().modify(param, extent)
86 }
87
88 pub fn take_op(&mut self) -> Option<O> {
90 self.op.take()
91 }
92
93 pub fn consume_op(&mut self, other: OpWrapper<O>) {
96 self.op = other.op;
97 self.cost_func_count += other.cost_func_count;
98 self.grad_func_count += other.grad_func_count;
99 self.hessian_func_count += other.hessian_func_count;
100 self.jacobian_func_count += other.jacobian_func_count;
101 self.modify_func_count += other.modify_func_count;
102 }
103
104 pub fn consume_func_counts<O2: ArgminOp>(&mut self, other: OpWrapper<O2>) {
106 self.cost_func_count += other.cost_func_count;
107 self.grad_func_count += other.grad_func_count;
108 self.hessian_func_count += other.hessian_func_count;
109 self.jacobian_func_count += other.jacobian_func_count;
110 self.modify_func_count += other.modify_func_count;
111 }
112
113 pub fn reset(mut self) -> Self {
115 self.cost_func_count = 0;
116 self.grad_func_count = 0;
117 self.hessian_func_count = 0;
118 self.jacobian_func_count = 0;
119 self.modify_func_count = 0;
120 self
121 }
122
123 pub fn get_op(self) -> O {
125 self.op.unwrap()
126 }
127}
128
129impl<O: ArgminOp> ArgminOp for OpWrapper<O> {
131 type Param = O::Param;
132 type Output = O::Output;
133 type Hessian = O::Hessian;
134 type Jacobian = O::Jacobian;
135 type Float = O::Float;
136
137 fn apply(&self, param: &Self::Param) -> Result<Self::Output, Error> {
138 self.op.as_ref().unwrap().apply(param)
139 }
140
141 fn gradient(&self, param: &Self::Param) -> Result<Self::Param, Error> {
142 self.op.as_ref().unwrap().gradient(param)
143 }
144
145 fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, Error> {
146 self.op.as_ref().unwrap().hessian(param)
147 }
148
149 fn jacobian(&self, param: &Self::Param) -> Result<Self::Jacobian, Error> {
150 self.op.as_ref().unwrap().jacobian(param)
151 }
152
153 fn modify(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Param, Error> {
154 self.op.as_ref().unwrap().modify(param, extent)
155 }
156}