1use crate::prelude::*;
16use crate::solver::conjugategradient::ConjugateGradient;
17use serde::{Deserialize, Serialize};
18
19#[derive(Clone, Serialize, Deserialize)]
29pub struct NewtonCG<L, F> {
30 linesearch: L,
32 curvature_threshold: F,
34 tol: F,
36}
37
38impl<L, F: ArgminFloat> NewtonCG<L, F> {
39 pub fn new(linesearch: L) -> Self {
41 NewtonCG {
42 linesearch,
43 curvature_threshold: F::from_f64(0.0).unwrap(),
44 tol: F::epsilon(),
45 }
46 }
47
48 pub fn curvature_threshold(mut self, threshold: F) -> Self {
50 self.curvature_threshold = threshold;
51 self
52 }
53
54 pub fn with_tol(mut self, tol: F) -> Result<Self, Error> {
56 if tol <= F::from_f64(0.0).unwrap() {
57 return Err(ArgminError::InvalidParameter {
58 text: "Newton-CG: tol must be positive.".to_string(),
59 }
60 .into());
61 }
62 self.tol = tol;
63 Ok(self)
64 }
65}
66
67impl<O, L, F> Solver<O> for NewtonCG<L, F>
68where
69 O: ArgminOp<Output = F, Float = F>,
70 O::Param: Send
71 + Sync
72 + Clone
73 + Serialize
74 + Default
75 + ArgminSub<O::Param, O::Param>
76 + ArgminAdd<O::Param, O::Param>
77 + ArgminDot<O::Param, O::Float>
78 + ArgminScaledAdd<O::Param, O::Float, O::Param>
79 + ArgminMul<F, O::Param>
80 + ArgminConj
81 + ArgminZeroLike
82 + ArgminNorm<O::Float>,
83 O::Hessian: Send
84 + Sync
85 + Default
86 + Clone
87 + Serialize
88 + Default
89 + ArgminInv<O::Hessian>
90 + ArgminDot<O::Param, O::Param>,
91 L: Clone + ArgminLineSearch<O::Param, O::Float> + Solver<OpWrapper<O>>,
92 F: ArgminFloat + Default + ArgminDiv<O::Float, O::Float> + ArgminNorm<O::Float> + ArgminConj,
93{
94 const NAME: &'static str = "Newton-CG";
95
96 fn next_iter(
97 &mut self,
98 op: &mut OpWrapper<O>,
99 state: &IterState<O>,
100 ) -> Result<ArgminIterData<O>, Error> {
101 let param = state.get_param();
102 let grad = op.gradient(¶m)?;
103 let hessian = op.hessian(¶m)?;
104
105 let cg_op: CGSubProblem<O::Param, O::Hessian, O::Float> =
107 CGSubProblem::new(hessian.clone());
108 let mut cg_op = OpWrapper::new(cg_op);
109
110 let mut x_p = param.zero_like();
111 let mut x: O::Param = param.zero_like();
112 let mut cg = ConjugateGradient::new(grad.mul(&(F::from_f64(-1.0).unwrap())))?;
113
114 let mut cg_state = IterState::new(x_p.clone());
115 cg.init(&mut cg_op, &cg_state)?;
116 let grad_norm = grad.norm();
117 for iter in 0.. {
118 let data = cg.next_iter(&mut cg_op, &cg_state)?;
119 x = data.get_param().unwrap();
120 let p = cg.p_prev();
121 let curvature = p.dot(&hessian.dot(&p));
122 if curvature <= self.curvature_threshold {
123 if iter == 0 {
124 x = grad.mul(&(F::from_f64(-1.0).unwrap()));
125 break;
126 } else {
127 x = x_p;
128 break;
129 }
130 }
131 if data.get_cost().unwrap()
132 <= F::from_f64(0.5).unwrap().min(grad_norm.sqrt()) * grad_norm
133 {
134 break;
135 }
136 cg_state.param(x.clone());
137 cg_state.cost(data.get_cost().unwrap());
138 x_p = x.clone();
139 }
140
141 self.linesearch.set_search_direction(x);
143
144 let ArgminResult {
146 operator: line_op,
147 state:
148 IterState {
149 param: next_param,
150 cost: next_cost,
151 ..
152 },
153 } = Executor::new(
154 OpWrapper::new_from_wrapper(op),
155 self.linesearch.clone(),
156 param,
157 )
158 .grad(grad)
159 .cost(state.get_cost())
160 .ctrlc(false)
161 .run()?;
162
163 op.consume_op(line_op);
164
165 Ok(ArgminIterData::new().param(next_param).cost(next_cost))
166 }
167
168 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
169 if (state.get_cost() - state.get_prev_cost()).abs() < self.tol {
170 TerminationReason::NoChangeInCost
171 } else {
172 TerminationReason::NotTerminated
173 }
174 }
175}
176
177#[derive(Clone, Default, Serialize, Deserialize)]
178struct CGSubProblem<T, H, F> {
179 hessian: H,
180 phantom: std::marker::PhantomData<T>,
181 float: std::marker::PhantomData<F>,
182}
183
184impl<T, H, F> CGSubProblem<T, H, F>
185where
186 T: Clone + Send + Sync,
187 H: Clone + Default + ArgminDot<T, T> + Send + Sync,
188{
189 pub fn new(hessian: H) -> Self {
191 CGSubProblem {
192 hessian,
193 phantom: std::marker::PhantomData,
194 float: std::marker::PhantomData,
195 }
196 }
197}
198
199impl<T, H, F> ArgminOp for CGSubProblem<T, H, F>
200where
201 T: Clone + Default + Send + Sync + Serialize + serde::de::DeserializeOwned,
202 H: Clone + Default + ArgminDot<T, T> + Send + Sync + Serialize + serde::de::DeserializeOwned,
203 F: ArgminFloat,
204{
205 type Param = T;
206 type Output = T;
207 type Hessian = ();
208 type Jacobian = ();
209 type Float = F;
210
211 fn apply(&self, p: &T) -> Result<T, Error> {
212 Ok(self.hessian.dot(&p))
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::linesearch::MoreThuenteLineSearch;
    use crate::test_trait_impl;

    test_trait_impl!(
        newton_cg,
        NewtonCG<MoreThuenteLineSearch<Vec<f64>, f64>, f64>
    );

    test_trait_impl!(cg_subproblem, CGSubProblem<Vec<f64>, Vec<Vec<f64>>, f64>);

    /// `with_tol` must store exactly the tolerance it was given.
    #[test]
    fn test_tolerance() {
        let expected: f64 = 1e-4;

        let ls: MoreThuenteLineSearch<Vec<f64>, f64> = MoreThuenteLineSearch::new();

        let solver: NewtonCG<MoreThuenteLineSearch<Vec<f64>, f64>, f64> =
            NewtonCG::new(ls).with_tol(expected).unwrap();
        let NewtonCG { tol, .. } = solver;

        assert!((tol - expected).abs() < std::f64::EPSILON);
    }
}
240}