// argmin/core/executor.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: Logging of "initial info"

10use crate::core::serialization::*;
11use crate::core::{
12    ArgminCheckpoint, ArgminIterData, ArgminKV, ArgminOp, ArgminResult, Error, IterState, Observe,
13    Observer, ObserverMode, OpWrapper, Solver, TerminationReason,
14};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use std::path::Path;
18use std::sync::atomic::{AtomicBool, Ordering};
19use std::sync::Arc;
20
21/// Executes a solver
22#[derive(Clone, Serialize, Deserialize)]
23pub struct Executor<O: ArgminOp, S> {
24    /// solver
25    solver: S,
26    /// operator
27    #[serde(skip)]
28    pub op: OpWrapper<O>,
29    /// State
30    #[serde(bound = "IterState<O>: Serialize")]
31    state: IterState<O>,
32    /// Storage for observers
33    #[serde(skip)]
34    observers: Observer<O>,
35    /// Checkpoint
36    checkpoint: ArgminCheckpoint,
37    /// Indicates whether Ctrl-C functionality should be active or not
38    ctrlc: bool,
39}
40
41impl<O, S> Executor<O, S>
42where
43    O: ArgminOp,
44    S: Solver<O>,
45{
46    /// Create a new executor with a `solver` and an initial parameter `init_param`
47    pub fn new(op: O, solver: S, init_param: O::Param) -> Self {
48        let state = IterState::new(init_param);
49        Executor {
50            solver,
51            op: OpWrapper::new(op),
52            state,
53            observers: Observer::new(),
54            checkpoint: ArgminCheckpoint::default(),
55            ctrlc: true,
56        }
57    }
58
59    /// Create a new executor from a checkpoint
60    pub fn from_checkpoint<P: AsRef<Path>>(path: P, op: O) -> Result<Self, Error>
61    where
62        Self: Sized + DeserializeOwned,
63    {
64        let mut executor: Self = load_checkpoint(path)?;
65        executor.op = OpWrapper::new(op);
66        Ok(executor)
67        // load_checkpoint(path)
68    }
69
70    fn update(&mut self, data: &ArgminIterData<O>) -> Result<(), Error> {
71        if let Some(cur_param) = data.get_param() {
72            self.state.param(cur_param);
73        }
74        if let Some(cur_cost) = data.get_cost() {
75            self.state.cost(cur_cost);
76        }
77        // check if parameters are the best so far
78        if self.state.get_cost() <= self.state.get_best_cost() {
79            let param = self.state.get_param().clone();
80            let cost = self.state.get_cost();
81            self.state.best_param(param).best_cost(cost);
82            self.state.new_best();
83        }
84
85        if let Some(grad) = data.get_grad() {
86            self.state.grad(grad);
87        }
88        if let Some(hessian) = data.get_hessian() {
89            self.state.hessian(hessian);
90        }
91        if let Some(jacobian) = data.get_jacobian() {
92            self.state.jacobian(jacobian);
93        }
94        if let Some(population) = data.get_population() {
95            self.state.population(population.clone());
96        }
97
98        if let Some(termination_reason) = data.get_termination_reason() {
99            self.state.termination_reason(termination_reason);
100        }
101        Ok(())
102    }
103
104    /// Run the executor
105    pub fn run(mut self) -> Result<ArgminResult<O>, Error> {
106        let total_time = std::time::Instant::now();
107
108        let running = Arc::new(AtomicBool::new(true));
109
110        if self.ctrlc {
111            #[cfg(feature = "ctrlc")]
112            {
113                // Set up the Ctrl-C handler
114                let r = running.clone();
115                // This is currently a hack to allow checkpoints to be run again within the
116                // same program (usually not really a usecase anyway). Unfortunately, this
117                // means that any subsequent run started afterwards will have not Ctrl-C
118                // handling available... This should also be a problem in case one tries to run
119                // two consecutive optimizations. There is ongoing work in the ctrlc crate
120                // (channels and such) which may solve this problem. So far, we have to live
121                // with this.
122                match ctrlc::set_handler(move || {
123                    r.store(false, Ordering::SeqCst);
124                }) {
125                    Err(ctrlc::Error::MultipleHandlers) => Ok(()),
126                    r => r,
127                }?;
128            }
129        }
130
131        // let mut op_wrapper = OpWrapper::new(&self.op);
132        let init_data = self.solver.init(&mut self.op, &self.state)?;
133
134        let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););
135
136        // If init() returned something, deal with it
137        if let Some(data) = init_data {
138            self.update(&data)?;
139            logs = logs.merge(&mut data.get_kv());
140        }
141
142        // Observe after init
143        self.observers.observe_init(S::NAME, &logs)?;
144
145        self.state.set_func_counts(&self.op);
146
147        while running.load(Ordering::SeqCst) {
148            // check first if it has already terminated
149            // This should probably be solved better.
150            // First, check if it isn't already terminated. If it isn't, evaluate the
151            // stopping criteria. If `self.terminate()` is called without the checking
152            // whether it has terminated already, then it may overwrite a termination set
153            // within `next_iter()`!
154            if !self.state.terminated() {
155                self.state
156                    .termination_reason(self.solver.terminate_internal(&self.state));
157            }
158            // Now check once more if the algorithm has terminated. If yes, then break.
159            if self.state.terminated() {
160                break;
161            }
162
163            // Start time measurement
164            let start = std::time::Instant::now();
165
166            let data = self.solver.next_iter(&mut self.op, &self.state)?;
167
168            self.state.set_func_counts(&self.op);
169
170            // End time measurement
171            let duration = start.elapsed();
172
173            self.update(&data)?;
174
175            let log = data.get_kv().merge(&mut make_kv!(
176                "time" => duration.as_secs() as f64 + f64::from(duration.subsec_nanos()) * 1e-9;
177            ));
178
179            self.observers.observe_iter(&self.state, &log)?;
180
181            // increment iteration number
182            self.state.increment_iter();
183
184            self.checkpoint.store_cond(&self, self.state.get_iter())?;
185
186            self.state.time(total_time.elapsed());
187
188            // Check if termination occured inside next_iter()
189            if self.state.terminated() {
190                break;
191            }
192        }
193
194        // in case it stopped prematurely and `termination_reason` is still `NotTerminated`,
195        // someone must have pulled the handbrake
196        if self.state.get_iter() < self.state.get_max_iters() && !self.state.terminated() {
197            self.state.termination_reason(TerminationReason::Aborted);
198        }
199
200        Ok(ArgminResult::new(self.op.get_op(), self.state))
201    }
202
203    /// Attaches a observer which implements `ArgminLog` to the solver.
204    pub fn add_observer<OBS: Observe<O> + 'static>(
205        mut self,
206        observer: OBS,
207        mode: ObserverMode,
208    ) -> Self {
209        self.observers.push(observer, mode);
210        self
211    }
212
213    /// Set maximum number of iterations
214    pub fn max_iters(mut self, iters: u64) -> Self {
215        self.state.max_iters(iters);
216        self
217    }
218
219    /// Set target cost value
220    pub fn target_cost(mut self, cost: O::Float) -> Self {
221        self.state.target_cost(cost);
222        self
223    }
224
225    /// Set cost value
226    pub fn cost(mut self, cost: O::Float) -> Self {
227        self.state.cost(cost);
228        self
229    }
230
231    /// Set Gradient
232    pub fn grad(mut self, grad: O::Param) -> Self {
233        self.state.grad(grad);
234        self
235    }
236
237    /// Set Hessian
238    pub fn hessian(mut self, hessian: O::Hessian) -> Self {
239        self.state.hessian(hessian);
240        self
241    }
242
243    /// Set Jacobian
244    pub fn jacobian(mut self, jacobian: O::Jacobian) -> Self {
245        self.state.jacobian(jacobian);
246        self
247    }
248
249    /// Set checkpoint directory
250    pub fn checkpoint_dir(mut self, dir: &str) -> Self {
251        self.checkpoint.set_dir(dir);
252        self
253    }
254
255    /// Set checkpoint name
256    pub fn checkpoint_name(mut self, dir: &str) -> Self {
257        self.checkpoint.set_name(dir);
258        self
259    }
260
261    /// Set the checkpoint mode
262    pub fn checkpoint_mode(mut self, mode: CheckpointMode) -> Self {
263        self.checkpoint.set_mode(mode);
264        self
265    }
266
267    /// Turn Ctrl-C handling on or off (default: on)
268    pub fn ctrlc(mut self, ctrlc: bool) -> Self {
269        self.ctrlc = ctrlc;
270        self
271    }
272}