argmin/core/observers/file.rs

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
//! # Write parameter vectors to files
9
use crate::core::{ArgminKV, ArgminOp, Error, IterState, Observe};
use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
16
17/// Different kinds of serializers
18#[derive(Copy, Clone, Serialize, Deserialize, Debug, Eq, PartialEq, Ord, PartialOrd)]
19pub enum WriteToFileSerializer {
20    /// Bincode
21    Bincode,
22    /// JSON
23    JSON,
24}
25
26impl Default for WriteToFileSerializer {
27    fn default() -> Self {
28        WriteToFileSerializer::Bincode
29    }
30}
31
32/// Write parameter vectors to file
33#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, Ord, PartialOrd)]
34pub struct WriteToFile<O> {
35    /// Directory
36    dir: String,
37    /// File prefix
38    prefix: String,
39    /// Chosen serializer
40    serializer: WriteToFileSerializer,
41    _param: std::marker::PhantomData<O>,
42}
43
44impl<O: ArgminOp> WriteToFile<O> {
45    /// Create a new `WriteToFile` struct
46    pub fn new(dir: &str, prefix: &str) -> Self {
47        WriteToFile {
48            dir: dir.to_string(),
49            prefix: prefix.to_string(),
50            serializer: WriteToFileSerializer::Bincode,
51            _param: std::marker::PhantomData,
52        }
53    }
54
55    /// Set serializer
56    pub fn serializer(mut self, serializer: WriteToFileSerializer) -> Self {
57        self.serializer = serializer;
58        self
59    }
60}
61
62impl<O: ArgminOp> Observe<O> for WriteToFile<O> {
63    fn observe_iter(&mut self, state: &IterState<O>, _kv: &ArgminKV) -> Result<(), Error> {
64        let param = state.get_param();
65        let iter = state.get_iter();
66        let dir = Path::new(&self.dir);
67        if !dir.exists() {
68            std::fs::create_dir_all(&dir)?
69        }
70
71        let mut fname = self.prefix.clone();
72        fname.push_str("_");
73        fname.push_str(&iter.to_string());
74        fname.push_str(".arp");
75        let fname = dir.join(fname);
76
77        let f = BufWriter::new(File::create(fname)?);
78        match self.serializer {
79            WriteToFileSerializer::Bincode => {
80                bincode::serialize_into(f, &param)?;
81            }
82            WriteToFileSerializer::JSON => {
83                serde_json::to_writer_pretty(f, &param)?;
84            }
85        }
86        Ok(())
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    // Presumably asserts that `WriteToFile<Vec<f64>>` is `Send + Sync`; the
    // `send_sync_test!` macro is defined elsewhere in the crate — confirm there.
    send_sync_test!(write_to_file, WriteToFile<Vec<f64>>);
}