// argmin/core/serialization.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
use crate::core::{ArgminError, Error};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fmt::Display;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
16
17/// Defines at which intervals a checkpoint is saved.
18#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd, Debug, Hash, Copy)]
19pub enum CheckpointMode {
20    /// Never create checkpoint
21    Never,
22    /// Create checkpoint every N iterations
23    Every(u64),
24    /// Create checkpoint in every iteration
25    Always,
26}
27
28impl Default for CheckpointMode {
29    fn default() -> CheckpointMode {
30        CheckpointMode::Never
31    }
32}
33
34impl Display for CheckpointMode {
35    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
36        match *self {
37            CheckpointMode::Never => write!(f, "Never"),
38            CheckpointMode::Every(i) => write!(f, "Every({})", i),
39            CheckpointMode::Always => write!(f, "Always"),
40        }
41    }
42}
43
44/// Checkpoint
45///
46/// Defines how often and where a checkpoint is saved.
47#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd, Debug, Hash)]
48pub struct ArgminCheckpoint {
49    mode: CheckpointMode,
50    directory: String,
51    name: String,
52}
53
54impl Default for ArgminCheckpoint {
55    fn default() -> ArgminCheckpoint {
56        ArgminCheckpoint {
57            mode: CheckpointMode::Never,
58            directory: ".checkpoints".to_string(),
59            name: "default".to_string(),
60        }
61    }
62}
63
64impl ArgminCheckpoint {
65    /// Define a new checkpoint
66    pub fn new(directory: &str, mode: CheckpointMode) -> Result<Self, Error> {
67        match mode {
68            CheckpointMode::Every(_) | CheckpointMode::Always => {
69                std::fs::create_dir_all(&directory)?
70            }
71            _ => {}
72        }
73        let name = "solver".to_string();
74        let directory = directory.to_string();
75        Ok(ArgminCheckpoint {
76            mode,
77            directory,
78            name,
79        })
80    }
81
82    /// Set directory of checkpoint
83    #[inline]
84    pub fn set_dir(&mut self, dir: &str) {
85        self.directory = dir.to_string();
86    }
87
88    /// Get directory of checkpoint
89    #[inline]
90    pub fn dir(&self) -> String {
91        self.directory.clone()
92    }
93
94    /// Set name of checkpoint
95    #[inline]
96    pub fn set_name(&mut self, name: &str) {
97        self.name = name.to_string();
98    }
99
100    /// Get name of checkpoint
101    #[inline]
102    pub fn name(&self) -> String {
103        self.name.clone()
104    }
105
106    /// Set mode of checkpoint
107    #[inline]
108    pub fn set_mode(&mut self, mode: CheckpointMode) {
109        self.mode = mode
110    }
111
112    /// Write checkpoint to disk
113    #[inline]
114    pub fn store<T: Serialize>(&self, executor: &T, filename: String) -> Result<(), Error> {
115        let dir = Path::new(&self.directory);
116        if !dir.exists() {
117            std::fs::create_dir_all(&dir)?
118        }
119        let fname = dir.join(Path::new(&filename));
120
121        let f = BufWriter::new(File::create(fname)?);
122        bincode::serialize_into(f, executor)?;
123        Ok(())
124    }
125
126    /// Write checkpoint based on the desired `CheckpointMode`
127    #[inline]
128    pub fn store_cond<T: Serialize>(&self, executor: &T, iter: u64) -> Result<(), Error> {
129        let mut filename = self.name();
130        filename.push_str(".arg");
131        match self.mode {
132            CheckpointMode::Always => self.store(executor, filename)?,
133            CheckpointMode::Every(it) if iter % it == 0 => self.store(executor, filename)?,
134            CheckpointMode::Never | CheckpointMode::Every(_) => {}
135        };
136        Ok(())
137    }
138}
139
140/// Load a checkpoint from disk
141pub fn load_checkpoint<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T, Error> {
142    let path = path.as_ref();
143    if !path.exists() {
144        return Err(ArgminError::CheckpointNotFound {
145            text: path.to_str().unwrap().to_string(),
146        }
147        .into());
148    }
149    let file = File::open(path)?;
150    let reader = BufReader::new(file);
151    Ok(bincode::deserialize_from(reader)?)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::nooperator::MinimalNoOperator;
    use crate::core::*;

    /// Minimal solver used only so that an `Executor` can be built and
    /// (de)serialized; it is never actually run.
    #[derive(Serialize, Deserialize, Clone, Debug)]
    pub struct PhonySolver {}

    impl PhonySolver {
        /// Constructor
        pub fn new() -> Self {
            PhonySolver {}
        }
    }

    impl<O> Solver<O> for PhonySolver
    where
        O: ArgminOp,
    {
        fn next_iter(
            &mut self,
            _op: &mut OpWrapper<O>,
            _state: &IterState<O>,
        ) -> Result<ArgminIterData<O>, Error> {
            unimplemented!()
        }
    }

    #[test]
    fn test_store() {
        // Use a process-unique temporary directory rather than the current
        // working directory, and remove it afterwards so no artifacts remain.
        let dir = std::env::temp_dir()
            .join(format!("argmin_checkpoint_test_{}", std::process::id()));
        let op: MinimalNoOperator = MinimalNoOperator::new();
        let solver = PhonySolver::new();
        let exec: Executor<MinimalNoOperator, PhonySolver> =
            Executor::new(op, solver, vec![0.0f64, 0.0]);
        let check = ArgminCheckpoint::new(
            dir.to_str().expect("temp dir path should be valid UTF-8"),
            CheckpointMode::Always,
        )
        .unwrap();
        check.store_cond(&exec, 20).unwrap();

        let loaded: Result<Executor<MinimalNoOperator, PhonySolver>, Error> =
            load_checkpoint(dir.join("solver.arg"));
        // Clean up before asserting so a failed load does not leak the directory.
        std::fs::remove_dir_all(&dir).unwrap();
        loaded.unwrap();
    }

    #[test]
    fn test_load_missing_checkpoint() {
        let path = std::env::temp_dir().join("argmin_checkpoint_does_not_exist.arg");
        let res: Result<Executor<MinimalNoOperator, PhonySolver>, Error> =
            load_checkpoint(&path);
        assert!(res.is_err());
    }

    #[test]
    fn test_checkpoint_mode_display() {
        assert_eq!(CheckpointMode::Never.to_string(), "Never");
        assert_eq!(CheckpointMode::Every(3).to_string(), "Every(3)");
        assert_eq!(CheckpointMode::Always.to_string(), "Always");
    }
}