// argmin/solver/conjugategradient/nonlinear_cg.rs

use crate::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::default::Default;

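/// Nonlinear conjugate gradient method.
///
/// Minimizes a cost function by following search directions that combine the
/// negative gradient with the previous direction, scaled by a beta factor
/// computed by a configurable update rule. A line search determines the step
/// length along each direction.
///
/// # Example (sketch)
///
/// A rough outline of typical usage. `MyProblem` is a placeholder for a
/// user-defined operator implementing `ArgminOp` with `Param = Vec<f64>` and
/// `Float = f64`; the parameterless constructors used below are assumed from
/// their definitions elsewhere in this crate.
///
/// ```ignore
/// use argmin::prelude::*;
/// use argmin::solver::conjugategradient::beta::PolakRibiere;
/// use argmin::solver::conjugategradient::NonlinearConjugateGradient;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
///
/// let linesearch = MoreThuenteLineSearch::new();
/// let beta_method = PolakRibiere::new();
/// let solver = NonlinearConjugateGradient::new(linesearch, beta_method)?
///     .restart_iters(10)
///     .restart_orthogonality(0.1);
///
/// let res = Executor::new(MyProblem {}, solver, vec![1.2f64, 1.2])
///     .max_iters(100)
///     .run()?;
/// ```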
#[derive(Clone, Serialize, Deserialize)]
pub struct NonlinearConjugateGradient<P, L, B, F> {
    /// Current search direction
    p: P,
    /// Beta update factor
    beta: F,
    /// Line search
    linesearch: L,
    /// Beta update method
    beta_method: B,
    /// Restart after this many iterations (defaults to `u64::MAX`, i.e. never)
    restart_iter: u64,
    /// Restart based on the orthogonality of successive gradients
    restart_orthogonality: Option<F>,
}

impl<P, L, B, F> NonlinearConjugateGradient<P, L, B, F>
where
    P: Default,
    F: ArgminFloat,
{
    /// Constructor: takes a line search and a beta update method.
    pub fn new(linesearch: L, beta_method: B) -> Result<Self, Error> {
        Ok(NonlinearConjugateGradient {
            p: P::default(),
            beta: F::nan(),
            linesearch,
            beta_method,
            restart_iter: std::u64::MAX,
            restart_orthogonality: None,
        })
    }

    /// Specify the number of iterations after which the search direction is
    /// reset to the negative gradient.
    pub fn restart_iters(mut self, iters: u64) -> Self {
        self.restart_iter = iters;
        self
    }

    /// Restart whenever the normalized absolute inner product of the current
    /// and previous gradient reaches the threshold `v`, i.e. when successive
    /// gradients are no longer sufficiently orthogonal.
    pub fn restart_orthogonality(mut self, v: F) -> Self {
        self.restart_orthogonality = Some(v);
        self
    }
}

impl<O, P, L, B, F> Solver<O> for NonlinearConjugateGradient<P, L, B, F>
where
    O: ArgminOp<Param = P, Output = F, Float = F>,
    P: Clone
        + Default
        + Serialize
        + DeserializeOwned
        + ArgminSub<O::Param, O::Param>
        + ArgminDot<O::Param, O::Float>
        + ArgminScaledAdd<O::Param, O::Float, O::Param>
        + ArgminAdd<O::Param, O::Param>
        + ArgminMul<F, O::Param>
        + ArgminNorm<O::Float>,
    O::Hessian: Default,
    L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
    B: ArgminNLCGBetaUpdate<O::Param, O::Float>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Nonlinear Conjugate Gradient";

    fn init(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<Option<ArgminIterData<O>>, Error> {
        let param = state.get_param();
        let cost = op.apply(&param)?;
        let grad = op.gradient(&param)?;
        self.p = grad.mul(&(F::from_f64(-1.0).unwrap()));
        Ok(Some(
            ArgminIterData::new().param(param).cost(cost).grad(grad),
        ))
    }

    fn next_iter(
        &mut self,
        op: &mut OpWrapper<O>,
        state: &IterState<O>,
    ) -> Result<ArgminIterData<O>, Error> {
        let xk = state.get_param();
        let grad = if let Some(grad) = state.get_grad() {
            grad
        } else {
            op.gradient(&xk)?
        };
        let cur_cost = state.get_cost();

        // Line search along the current search direction
        self.linesearch.set_search_direction(self.p.clone());

        let ArgminResult {
            operator: line_op,
            state: line_state,
        } = Executor::new(OpWrapper::new_from_wrapper(op), self.linesearch.clone(), xk)
            .grad(grad.clone())
            .cost(cur_cost)
            .ctrlc(false)
            .run()?;

        // Hand the function evaluation counts back to the main operator
        op.consume_op(line_op);

        let xk1 = line_state.get_param();

        let new_grad = op.gradient(&xk1)?;

        // Restart if successive gradients are no longer sufficiently orthogonal ...
        let restart_orthogonality = match self.restart_orthogonality {
            Some(v) => new_grad.dot(&grad).abs() / new_grad.norm().powi(2) >= v,
            None => false,
        };

        // ... or if the configured restart interval has been reached.
        let restart_iter: bool =
            (state.get_iter() % self.restart_iter == 0) && state.get_iter() != 0;

        if restart_iter || restart_orthogonality {
            self.beta = F::from_f64(0.0).unwrap();
        } else {
            self.beta = self.beta_method.update(&grad, &new_grad, &self.p);
        }

        // Update the search direction: p_{k+1} = -grad_{k+1} + beta * p_k
        self.p = new_grad
            .mul(&(F::from_f64(-1.0).unwrap()))
            .add(&self.p.mul(&self.beta));

        let cost = op.apply(&xk1)?;

        Ok(ArgminIterData::new()
            .param(xk1)
            .cost(cost)
            .grad(new_grad)
            .kv(make_kv!("beta" => self.beta;
                         "restart_iter" => restart_iter;
                         "restart_orthogonality" => restart_orthogonality;
            )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::solver::conjugategradient::beta::PolakRibiere;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    test_trait_impl!(
        nonlinear_cg,
        NonlinearConjugateGradient<
            MinimalNoOperator,
            MoreThuenteLineSearch<MinimalNoOperator, f64>,
            PolakRibiere,
            f64
        >
    );
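
    // A minimal construction sketch (not a convergence test): it exercises the
    // builder methods above. The parameterless `MoreThuenteLineSearch::new()` and
    // `PolakRibiere::new()` constructors are assumed from their definitions
    // elsewhere in this crate.
    #[test]
    fn nonlinear_cg_construction_sketch() {
        let linesearch: MoreThuenteLineSearch<Vec<f64>, f64> = MoreThuenteLineSearch::new();
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, PolakRibiere::new())
                .unwrap()
                .restart_iters(10)
                .restart_orthogonality(0.1);
        // The restart interval and orthogonality threshold should be stored as set.
        assert_eq!(nlcg.restart_iter, 10);
        assert!(nlcg.restart_orthogonality.is_some());
    }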
}