// argmin/core/serialization.rs

use crate::core::{ArgminError, Error};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fmt::Display;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

17#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd, Debug, Hash, Copy)]
19pub enum CheckpointMode {
20 Never,
22 Every(u64),
24 Always,
26}
27
28impl Default for CheckpointMode {
29 fn default() -> CheckpointMode {
30 CheckpointMode::Never
31 }
32}
33
34impl Display for CheckpointMode {
35 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
36 match *self {
37 CheckpointMode::Never => write!(f, "Never"),
38 CheckpointMode::Every(i) => write!(f, "Every({})", i),
39 CheckpointMode::Always => write!(f, "Always"),
40 }
41 }
42}
43
44#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd, Debug, Hash)]
48pub struct ArgminCheckpoint {
49 mode: CheckpointMode,
50 directory: String,
51 name: String,
52}
53
54impl Default for ArgminCheckpoint {
55 fn default() -> ArgminCheckpoint {
56 ArgminCheckpoint {
57 mode: CheckpointMode::Never,
58 directory: ".checkpoints".to_string(),
59 name: "default".to_string(),
60 }
61 }
62}
63
64impl ArgminCheckpoint {
65 pub fn new(directory: &str, mode: CheckpointMode) -> Result<Self, Error> {
67 match mode {
68 CheckpointMode::Every(_) | CheckpointMode::Always => {
69 std::fs::create_dir_all(&directory)?
70 }
71 _ => {}
72 }
73 let name = "solver".to_string();
74 let directory = directory.to_string();
75 Ok(ArgminCheckpoint {
76 mode,
77 directory,
78 name,
79 })
80 }
81
82 #[inline]
84 pub fn set_dir(&mut self, dir: &str) {
85 self.directory = dir.to_string();
86 }
87
88 #[inline]
90 pub fn dir(&self) -> String {
91 self.directory.clone()
92 }
93
94 #[inline]
96 pub fn set_name(&mut self, name: &str) {
97 self.name = name.to_string();
98 }
99
100 #[inline]
102 pub fn name(&self) -> String {
103 self.name.clone()
104 }
105
106 #[inline]
108 pub fn set_mode(&mut self, mode: CheckpointMode) {
109 self.mode = mode
110 }
111
112 #[inline]
114 pub fn store<T: Serialize>(&self, executor: &T, filename: String) -> Result<(), Error> {
115 let dir = Path::new(&self.directory);
116 if !dir.exists() {
117 std::fs::create_dir_all(&dir)?
118 }
119 let fname = dir.join(Path::new(&filename));
120
121 let f = BufWriter::new(File::create(fname)?);
122 bincode::serialize_into(f, executor)?;
123 Ok(())
124 }
125
126 #[inline]
128 pub fn store_cond<T: Serialize>(&self, executor: &T, iter: u64) -> Result<(), Error> {
129 let mut filename = self.name();
130 filename.push_str(".arg");
131 match self.mode {
132 CheckpointMode::Always => self.store(executor, filename)?,
133 CheckpointMode::Every(it) if iter % it == 0 => self.store(executor, filename)?,
134 CheckpointMode::Never | CheckpointMode::Every(_) => {}
135 };
136 Ok(())
137 }
138}
139
140pub fn load_checkpoint<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T, Error> {
142 let path = path.as_ref();
143 if !path.exists() {
144 return Err(ArgminError::CheckpointNotFound {
145 text: path.to_str().unwrap().to_string(),
146 }
147 .into());
148 }
149 let file = File::open(path)?;
150 let reader = BufReader::new(file);
151 Ok(bincode::deserialize_from(reader)?)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::nooperator::MinimalNoOperator;
    use crate::core::*;

    /// Solver stub that only exists so an `Executor` can be built and serialized.
    #[derive(Serialize, Deserialize, Clone, Debug, Default)]
    pub struct PhonySolver {}

    impl PhonySolver {
        /// Constructs the stub solver.
        pub fn new() -> Self {
            PhonySolver {}
        }
    }

    impl<O> Solver<O> for PhonySolver
    where
        O: ArgminOp,
    {
        fn next_iter(
            &mut self,
            _op: &mut OpWrapper<O>,
            _state: &IterState<O>,
        ) -> Result<ArgminIterData<O>, Error> {
            unimplemented!()
        }
    }

    /// Round-trips an `Executor` through a checkpoint file.
    #[test]
    fn test_store() {
        // Write into a per-process directory under the system temp dir rather than the
        // working directory, so the test leaves no artifacts and cannot clash with
        // concurrent runs.
        let dir = std::env::temp_dir().join(format!("argmin_test_store_{}", std::process::id()));
        let dir_str = dir.to_str().expect("temp dir path should be valid UTF-8");

        let op: MinimalNoOperator = MinimalNoOperator::new();
        let solver = PhonySolver::new();
        let exec: Executor<MinimalNoOperator, PhonySolver> =
            Executor::new(op, solver, vec![0.0f64, 0.0]);
        let check = ArgminCheckpoint::new(dir_str, CheckpointMode::Always).unwrap();
        check.store_cond(&exec, 20).unwrap();

        let loaded: Result<Executor<MinimalNoOperator, PhonySolver>, Error> =
            load_checkpoint(dir.join("solver.arg"));
        // Best-effort cleanup before asserting, so a failed load still removes the files.
        let _ = std::fs::remove_dir_all(&dir);
        let _loaded = loaded.unwrap();
    }
}
195}