1use crate::core::serialization::*;
11use crate::core::{
12 ArgminCheckpoint, ArgminIterData, ArgminKV, ArgminOp, ArgminResult, Error, IterState, Observe,
13 Observer, ObserverMode, OpWrapper, Solver, TerminationReason,
14};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use std::path::Path;
18use std::sync::atomic::{AtomicBool, Ordering};
19use std::sync::Arc;
20
/// Drives a solver `S` on an operator `O`.
///
/// Holds the solver, the wrapped operator, the iteration state, the registered observers
/// and the checkpoint configuration. The executor is (de)serializable so it can be written
/// to and restored from a checkpoint. The operator and the observers are skipped during
/// (de)serialization and must be supplied or attached again after loading.
///
/// NOTE(review): serde's derive (de)serializes fields in declaration order. Reordering the
/// fields below may break compatibility with existing checkpoints in order-sensitive formats.
#[derive(Clone, Serialize, Deserialize)]
pub struct Executor<O: ArgminOp, S> {
    /// The solver that performs the iterations.
    solver: S,
    /// The operator, wrapped in an `OpWrapper`. It is not stored in checkpoints; on
    /// deserialization it is filled with a default value and is expected to be replaced
    /// (as `Executor::from_checkpoint` does).
    #[serde(skip)]
    pub op: OpWrapper<O>,
    /// Iteration state: current/best parameters and costs, iteration counter,
    /// termination reason, etc.
    // The explicit bound replaces the bound serde would otherwise infer for this field.
    #[serde(bound = "IterState<O>: Serialize")]
    state: IterState<O>,
    /// Registered observers. Not stored in checkpoints.
    #[serde(skip)]
    observers: Observer<O>,
    /// Checkpoint configuration (directory, name, mode).
    checkpoint: ArgminCheckpoint,
    /// Whether `run` should try to install a Ctrl-C handler. This only has an effect
    /// when the `ctrlc` feature is enabled.
    ctrlc: bool,
}
40
41impl<O, S> Executor<O, S>
42where
43 O: ArgminOp,
44 S: Solver<O>,
45{
46 pub fn new(op: O, solver: S, init_param: O::Param) -> Self {
48 let state = IterState::new(init_param);
49 Executor {
50 solver,
51 op: OpWrapper::new(op),
52 state,
53 observers: Observer::new(),
54 checkpoint: ArgminCheckpoint::default(),
55 ctrlc: true,
56 }
57 }
58
59 pub fn from_checkpoint<P: AsRef<Path>>(path: P, op: O) -> Result<Self, Error>
61 where
62 Self: Sized + DeserializeOwned,
63 {
64 let mut executor: Self = load_checkpoint(path)?;
65 executor.op = OpWrapper::new(op);
66 Ok(executor)
67 }
69
70 fn update(&mut self, data: &ArgminIterData<O>) -> Result<(), Error> {
71 if let Some(cur_param) = data.get_param() {
72 self.state.param(cur_param);
73 }
74 if let Some(cur_cost) = data.get_cost() {
75 self.state.cost(cur_cost);
76 }
77 if self.state.get_cost() <= self.state.get_best_cost() {
79 let param = self.state.get_param().clone();
80 let cost = self.state.get_cost();
81 self.state.best_param(param).best_cost(cost);
82 self.state.new_best();
83 }
84
85 if let Some(grad) = data.get_grad() {
86 self.state.grad(grad);
87 }
88 if let Some(hessian) = data.get_hessian() {
89 self.state.hessian(hessian);
90 }
91 if let Some(jacobian) = data.get_jacobian() {
92 self.state.jacobian(jacobian);
93 }
94 if let Some(population) = data.get_population() {
95 self.state.population(population.clone());
96 }
97
98 if let Some(termination_reason) = data.get_termination_reason() {
99 self.state.termination_reason(termination_reason);
100 }
101 Ok(())
102 }
103
104 pub fn run(mut self) -> Result<ArgminResult<O>, Error> {
106 let total_time = std::time::Instant::now();
107
108 let running = Arc::new(AtomicBool::new(true));
109
110 if self.ctrlc {
111 #[cfg(feature = "ctrlc")]
112 {
113 let r = running.clone();
115 match ctrlc::set_handler(move || {
123 r.store(false, Ordering::SeqCst);
124 }) {
125 Err(ctrlc::Error::MultipleHandlers) => Ok(()),
126 r => r,
127 }?;
128 }
129 }
130
131 let init_data = self.solver.init(&mut self.op, &self.state)?;
133
134 let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););
135
136 if let Some(data) = init_data {
138 self.update(&data)?;
139 logs = logs.merge(&mut data.get_kv());
140 }
141
142 self.observers.observe_init(S::NAME, &logs)?;
144
145 self.state.set_func_counts(&self.op);
146
147 while running.load(Ordering::SeqCst) {
148 if !self.state.terminated() {
155 self.state
156 .termination_reason(self.solver.terminate_internal(&self.state));
157 }
158 if self.state.terminated() {
160 break;
161 }
162
163 let start = std::time::Instant::now();
165
166 let data = self.solver.next_iter(&mut self.op, &self.state)?;
167
168 self.state.set_func_counts(&self.op);
169
170 let duration = start.elapsed();
172
173 self.update(&data)?;
174
175 let log = data.get_kv().merge(&mut make_kv!(
176 "time" => duration.as_secs() as f64 + f64::from(duration.subsec_nanos()) * 1e-9;
177 ));
178
179 self.observers.observe_iter(&self.state, &log)?;
180
181 self.state.increment_iter();
183
184 self.checkpoint.store_cond(&self, self.state.get_iter())?;
185
186 self.state.time(total_time.elapsed());
187
188 if self.state.terminated() {
190 break;
191 }
192 }
193
194 if self.state.get_iter() < self.state.get_max_iters() && !self.state.terminated() {
197 self.state.termination_reason(TerminationReason::Aborted);
198 }
199
200 Ok(ArgminResult::new(self.op.get_op(), self.state))
201 }
202
203 pub fn add_observer<OBS: Observe<O> + 'static>(
205 mut self,
206 observer: OBS,
207 mode: ObserverMode,
208 ) -> Self {
209 self.observers.push(observer, mode);
210 self
211 }
212
213 pub fn max_iters(mut self, iters: u64) -> Self {
215 self.state.max_iters(iters);
216 self
217 }
218
219 pub fn target_cost(mut self, cost: O::Float) -> Self {
221 self.state.target_cost(cost);
222 self
223 }
224
225 pub fn cost(mut self, cost: O::Float) -> Self {
227 self.state.cost(cost);
228 self
229 }
230
231 pub fn grad(mut self, grad: O::Param) -> Self {
233 self.state.grad(grad);
234 self
235 }
236
237 pub fn hessian(mut self, hessian: O::Hessian) -> Self {
239 self.state.hessian(hessian);
240 self
241 }
242
243 pub fn jacobian(mut self, jacobian: O::Jacobian) -> Self {
245 self.state.jacobian(jacobian);
246 self
247 }
248
249 pub fn checkpoint_dir(mut self, dir: &str) -> Self {
251 self.checkpoint.set_dir(dir);
252 self
253 }
254
255 pub fn checkpoint_name(mut self, dir: &str) -> Self {
257 self.checkpoint.set_name(dir);
258 self
259 }
260
261 pub fn checkpoint_mode(mut self, mode: CheckpointMode) -> Self {
263 self.checkpoint.set_mode(mode);
264 self
265 }
266
267 pub fn ctrlc(mut self, ctrlc: bool) -> Self {
269 self.ctrlc = ctrlc;
270 self
271 }
272}