argmin/solver/conjugategradient/

// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # References:
//!
//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fmt::Debug;

/// The conjugate gradient method is a solver for systems of linear equations with a symmetric and
/// positive-definite matrix.
///
/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/conjugategradient.rs)
///
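/// The matrix `A` is supplied implicitly: applying the operator (`ArgminOp::apply`) to a vector
/// `p` is expected to return the product `A * p`. The iteration performed by this solver is the
/// standard CG recurrence (cf. [0]), written here with `'` denoting the (conjugate) transpose:
///
/// ```text
/// r_0 = A * x_0 - b,    p_0 = -r_0
/// alpha_k = (r_k' * r_k) / (p_k' * A * p_k)
/// x_{k+1} = x_k + alpha_k * p_k
/// r_{k+1} = r_k + alpha_k * A * p_k
/// beta_k  = (r_{k+1}' * r_{k+1}) / (r_k' * r_k)
/// p_{k+1} = -r_{k+1} + beta_k * p_k
/// ```
///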
/// # References:
///
/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Serialize, Deserialize)]
pub struct ConjugateGradient<P, S> {
    /// b (right hand side)
    b: P,
    /// residual
    r: P,
    /// search direction
    p: P,
    /// previous search direction
    p_prev: P,
    /// r^T * r
    #[serde(skip)]
    rtr: S,
    /// step length alpha
    #[serde(skip)]
    alpha: S,
    /// beta (used to update the search direction)
    #[serde(skip)]
    beta: S,
}

impl<P, S> ConjugateGradient<P, S>
where
    P: Clone + Default,
    S: Default,
{
    /// Constructor
    ///
    /// Parameters:
    ///
    /// `b`: right hand side of `A * x = b`
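    ///
    /// # Example
    ///
    /// A minimal sketch; the concrete parameter and scalar types are chosen by the caller (here
    /// `Vec<f64>` and `f64`, as in the tests below):
    ///
    /// ```ignore
    /// // Right hand side of `A * x = b`
    /// let b: Vec<f64> = vec![1.0, 2.0];
    /// let solver: ConjugateGradient<Vec<f64>, f64> = ConjugateGradient::new(b)?;
    /// ```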
    pub fn new(b: P) -> Result<Self, Error> {
        Ok(ConjugateGradient {
            b,
            r: P::default(),
            p: P::default(),
            p_prev: P::default(),
            rtr: S::default(),
            alpha: S::default(),
            beta: S::default(),
        })
    }

    /// Returns the current search direction (needed by NewtonCG, for instance)
    pub fn p(&self) -> P {
        self.p.clone()
    }

    /// Returns the previous search direction (needed by NewtonCG, for instance)
    pub fn p_prev(&self) -> P {
        self.p_prev.clone()
    }

    /// Returns the current residual (needed by NewtonCG, for instance)
    pub fn residual(&self) -> P {
        self.r.clone()
    }
}

impl<P, O, S, F> Solver<O> for ConjugateGradient<P, S>
where
    O: ArgminOp<Param = P, Output = P, Float = F>,
    P: Clone
        + Serialize
        + DeserializeOwned
        + ArgminDot<O::Param, S>
        + ArgminSub<O::Param, O::Param>
        + ArgminScaledAdd<O::Param, S, O::Param>
        + ArgminAdd<O::Param, O::Param>
        + ArgminConj
        + ArgminMul<O::Float, O::Param>,
    S: Debug + ArgminDiv<S, S> + ArgminNorm<O::Float> + ArgminConj,
    F: ArgminFloat,
{
    const NAME: &'static str = "Conjugate Gradient";

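    /// Initialize the solver: compute the initial residual `r_0 = A * x_0 - b` and the first
    /// search direction `p_0 = -r_0`.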
    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let init_param = state.get_param();
        let ap = op.apply(&init_param)?;
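        // r_0 = A * x_0 - b: `b.sub(&ap)` yields b - A * x_0, hence the multiplication by -1.0.
        // The first search direction is p_0 = -r_0.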
        let r0 = self.b.sub(&ap).mul(&(F::from_f64(-1.0).unwrap()));
        self.r = r0.clone();
        self.p = r0.mul(&(F::from_f64(-1.0).unwrap()));
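        // r_0^T * conj(r_0), reused as the numerator of alpha in the first iteration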
        self.rtr = self.r.dot(&self.r.conj());
        Ok(None)
    }

    /// Perform one iteration of the CG algorithm.
    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        self.p_prev = self.p.clone();
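        // Matrix-vector product A * p_k computed via the operator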
        let apk = op.apply(&self.p)?;
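        // Step length alpha_k = (r_k^T r_k) / (p_k^T A p_k) (conjugated dot product)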
        self.alpha = self.rtr.div(&self.p.dot(&apk.conj()));
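        // x_{k+1} = x_k + alpha_k * p_k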
        let new_param = state.get_param().scaled_add(&self.alpha, &self.p);
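        // r_{k+1} = r_k + alpha_k * A * p_k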
        self.r = self.r.scaled_add(&self.alpha, &apk);
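        // beta_k = (r_{k+1}^T r_{k+1}) / (r_k^T r_k)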
        let rtr_n = self.r.dot(&self.r.conj());
        self.beta = rtr_n.div(&self.rtr);
        self.rtr = rtr_n;
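        // p_{k+1} = -r_{k+1} + beta_k * p_k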
        self.p = self
            .r
            .mul(&(F::from_f64(-1.0).unwrap()))
            .scaled_add(&self.beta, &self.p);
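        // Squared residual norm r^T * conj(r), reported as the cost below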
        let norm = self.r.dot(&self.r.conj());

        Ok(ArgminIterData::new()
            .param(new_param)
            // .cost(norm.sqrt())
            .cost(norm.norm())
            .kv(make_kv!("alpha" => self.alpha; "beta" => self.beta;)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(conjugate_gradient, ConjugateGradient<Vec<f64>, f64>);
}