// argmin/core/observers/file.rs
use crate::core::{ArgminKV, ArgminOp, Error, IterState, Observe};
use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
17#[derive(Copy, Clone, Serialize, Deserialize, Debug, Eq, PartialEq, Ord, PartialOrd)]
19pub enum WriteToFileSerializer {
20 Bincode,
22 JSON,
24}
25
26impl Default for WriteToFileSerializer {
27 fn default() -> Self {
28 WriteToFileSerializer::Bincode
29 }
30}
31
32#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, Ord, PartialOrd)]
34pub struct WriteToFile<O> {
35 dir: String,
37 prefix: String,
39 serializer: WriteToFileSerializer,
41 _param: std::marker::PhantomData<O>,
42}
43
44impl<O: ArgminOp> WriteToFile<O> {
45 pub fn new(dir: &str, prefix: &str) -> Self {
47 WriteToFile {
48 dir: dir.to_string(),
49 prefix: prefix.to_string(),
50 serializer: WriteToFileSerializer::Bincode,
51 _param: std::marker::PhantomData,
52 }
53 }
54
55 pub fn serializer(mut self, serializer: WriteToFileSerializer) -> Self {
57 self.serializer = serializer;
58 self
59 }
60}
61
62impl<O: ArgminOp> Observe<O> for WriteToFile<O> {
63 fn observe_iter(&mut self, state: &IterState<O>, _kv: &ArgminKV) -> Result<(), Error> {
64 let param = state.get_param();
65 let iter = state.get_iter();
66 let dir = Path::new(&self.dir);
67 if !dir.exists() {
68 std::fs::create_dir_all(&dir)?
69 }
70
71 let mut fname = self.prefix.clone();
72 fname.push_str("_");
73 fname.push_str(&iter.to_string());
74 fname.push_str(".arp");
75 let fname = dir.join(fname);
76
77 let f = BufWriter::new(File::create(fname)?);
78 match self.serializer {
79 WriteToFileSerializer::Bincode => {
80 bincode::serialize_into(f, ¶m)?;
81 }
82 WriteToFileSerializer::JSON => {
83 serde_json::to_writer_pretty(f, ¶m)?;
84 }
85 }
86 Ok(())
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    // `send_sync_test!` is a crate-level macro defined outside this file;
    // presumably it asserts that the given type is `Send` and `Sync` —
    // confirm against its definition.
    send_sync_test!(write_to_file, WriteToFile<Vec<f64>>);
}