argmin/solver/conjugategradient/cg.rs

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::default::Default;
use std::fmt::Debug;

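/// The conjugate gradient method is an iterative solver for systems of linear
/// equations `A * x = b` with a symmetric, positive-definite operator `A`.
///
/// # Example
///
/// A minimal usage sketch, assuming the `Executor` driver and `ArgminOp` trait
/// from this crate's prelude; the 2x2 operator `MyProblem` below is made up
/// for illustration and is not part of this module:
///
/// ```no_run
/// use argmin::prelude::*;
/// use argmin::solver::conjugategradient::ConjugateGradient;
/// use serde::{Deserialize, Serialize};
///
/// // Hypothetical linear operator representing A = [[4, 1], [1, 3]].
/// #[derive(Clone, Default, Serialize, Deserialize)]
/// struct MyProblem {}
///
/// impl ArgminOp for MyProblem {
///     type Param = Vec<f64>;
///     type Output = Vec<f64>;
///     type Hessian = ();
///     type Jacobian = ();
///     type Float = f64;
///
///     // Computes the matrix-vector product A * p.
///     fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(vec![4.0 * p[0] + p[1], p[0] + 3.0 * p[1]])
///     }
/// }
///
/// fn main() -> Result<(), Error> {
///     // Solve A * x = b for x, starting from (2, 1).
///     let b = vec![1.0, 2.0];
///     let solver: ConjugateGradient<_, f64> = ConjugateGradient::new(b)?;
///     let res = Executor::new(MyProblem {}, solver, vec![2.0, 1.0])
///         .max_iters(5)
///         .run()?;
///     println!("{}", res);
///     Ok(())
/// }
/// ```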
#[derive(Clone, Serialize, Deserialize)]
pub struct ConjugateGradient<P, S> {
    /// Right-hand side `b` of the linear system
    b: P,
    /// Residual
    r: P,
    /// Current search direction
    p: P,
    /// Previous search direction
    p_prev: P,
    /// Squared norm of the residual, `r^H * r`
    #[serde(skip)]
    rtr: S,
    /// Step length
    #[serde(skip)]
    alpha: S,
    /// Factor for updating the search direction
    #[serde(skip)]
    beta: S,
}

impl<P, S> ConjugateGradient<P, S>
where
    P: Clone + Default,
    S: Default,
{
    /// Constructor
    ///
    /// `b` is the right-hand side of the linear system `A * x = b`.
    pub fn new(b: P) -> Result<Self, Error> {
        Ok(ConjugateGradient {
            b,
            r: P::default(),
            p: P::default(),
            p_prev: P::default(),
            rtr: S::default(),
            alpha: S::default(),
            beta: S::default(),
        })
    }

    /// Returns the current search direction
    pub fn p(&self) -> P {
        self.p.clone()
    }

    /// Returns the previous search direction
    pub fn p_prev(&self) -> P {
        self.p_prev.clone()
    }

    /// Returns the current residual
    pub fn residual(&self) -> P {
        self.r.clone()
    }
}

impl<P, O, S, F> Solver<O> for ConjugateGradient<P, S>
where
    O: ArgminOp<Param = P, Output = P, Float = F>,
    P: Clone
        + Serialize
        + DeserializeOwned
        + ArgminDot<O::Param, S>
        + ArgminSub<O::Param, O::Param>
        + ArgminScaledAdd<O::Param, S, O::Param>
        + ArgminAdd<O::Param, O::Param>
        + ArgminConj
        + ArgminMul<O::Float, O::Param>,
    S: Debug + ArgminDiv<S, S> + ArgminNorm<O::Float> + ArgminConj,
    F: ArgminFloat,
{
    const NAME: &'static str = "Conjugate Gradient";

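    // `init` and `next_iter` implement the standard CG recurrences in this
    // file's sign convention r = A * x - b (so the search direction starts at
    // p_0 = -r_0 = b - A * x_0); the `.conj()` calls keep the inner products
    // well-defined for complex-valued operands:
    //
    //   alpha_k = (r_k^H r_k) / (p_k^H A p_k)
    //   x_{k+1} = x_k + alpha_k * p_k
    //   r_{k+1} = r_k + alpha_k * A p_k
    //   beta_k  = (r_{k+1}^H r_{k+1}) / (r_k^H r_k)
    //   p_{k+1} = -r_{k+1} + beta_k * p_k
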
    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let init_param = state.get_param();
        let ap = op.apply(&init_param)?;
        // r_0 = A * x_0 - b, i.e. the negative of the usual residual b - A * x_0
        let r0 = self.b.sub(&ap).mul(&(F::from_f64(-1.0).unwrap()));
        self.r = r0.clone();
        // p_0 = -r_0
        self.p = r0.mul(&(F::from_f64(-1.0).unwrap()));
        self.rtr = self.r.dot(&self.r.conj());
        Ok(None)
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        self.p_prev = self.p.clone();
        // A * p_k
        let apk = op.apply(&self.p)?;
        self.alpha = self.rtr.div(&self.p.dot(&apk.conj()));
        let new_param = state.get_param().scaled_add(&self.alpha, &self.p);
        self.r = self.r.scaled_add(&self.alpha, &apk);
        let rtr_n = self.r.dot(&self.r.conj());
        self.beta = rtr_n.div(&self.rtr);
        self.rtr = rtr_n;
        // Update the search direction: p_{k+1} = -r_{k+1} + beta_k * p_k
        self.p = self
            .r
            .mul(&(F::from_f64(-1.0).unwrap()))
            .scaled_add(&self.beta, &self.p);
        // The squared residual norm serves as the reported cost.
        let norm = self.r.dot(&self.r.conj());

        Ok(ArgminIterData::new()
            .param(new_param)
            .cost(norm.norm())
            .kv(make_kv!("alpha" => self.alpha; "beta" => self.beta;)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(conjugate_gradient, ConjugateGradient<Vec<f64>, f64>);
}