argmin/solver/trustregion/
trustregion_method.rs1use crate::prelude::*;
14use crate::solver::trustregion::reduction_ratio;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18#[derive(Clone, Serialize, Deserialize)]
39pub struct TrustRegion<R, F> {
40 radius: F,
42 max_radius: F,
44 eta: F,
46 subproblem: R,
48 fxk: F,
50 mk0: F,
52}
53
54impl<R, F: ArgminFloat> TrustRegion<R, F> {
55 pub fn new(subproblem: R) -> Self {
57 TrustRegion {
58 radius: F::from_f64(1.0).unwrap(),
59 max_radius: F::from_f64(100.0).unwrap(),
60 eta: F::from_f64(0.125).unwrap(),
61 subproblem,
62 fxk: F::nan(),
63 mk0: F::nan(),
64 }
65 }
66
67 pub fn radius(mut self, radius: F) -> Self {
69 self.radius = radius;
70 self
71 }
72
73 pub fn max_radius(mut self, max_radius: F) -> Self {
75 self.max_radius = max_radius;
76 self
77 }
78
79 pub fn eta(mut self, eta: F) -> Result<Self, Error> {
81 if eta >= F::from_f64(0.25).unwrap() || eta < F::from_f64(0.0).unwrap() {
82 return Err(ArgminError::InvalidParameter {
83 text: "TrustRegion: eta must be in [0, 1/4).".to_string(),
84 }
85 .into());
86 }
87 self.eta = eta;
88 Ok(self)
89 }
90}
91
92impl<O, R, F> Solver<O> for TrustRegion<R, F>
93where
94 O: ArgminOp<Output = F, Float = F>,
95 O::Param: Default
96 + Clone
97 + Debug
98 + Serialize
99 + ArgminMul<F, O::Param>
100 + ArgminWeightedDot<O::Param, F, O::Hessian>
101 + ArgminNorm<F>
102 + ArgminDot<O::Param, F>
103 + ArgminAdd<O::Param, O::Param>
104 + ArgminSub<O::Param, O::Param>
105 + ArgminZeroLike
106 + ArgminMul<F, O::Param>,
107 O::Hessian: Default + Clone + Debug + Serialize + ArgminDot<O::Param, O::Param>,
108 R: ArgminTrustRegion<F> + Solver<OpWrapper<O>>,
109 F: ArgminFloat,
110{
111 const NAME: &'static str = "Trust region";
112
113 fn init(
114 &mut self,
115 op: &mut OpWrapper<O>,
116 state: &IterState<O>,
117 ) -> Result<Option<ArgminIterData<O>>, Error> {
118 let param = state.get_param();
119 let grad = op.gradient(¶m)?;
120 let hessian = op.hessian(¶m)?;
121 self.fxk = op.apply(¶m)?;
122 self.mk0 = self.fxk;
123 Ok(Some(
124 ArgminIterData::new()
125 .param(param)
126 .cost(self.fxk)
127 .grad(grad)
128 .hessian(hessian),
129 ))
130 }
131
132 fn next_iter(
133 &mut self,
134 op: &mut OpWrapper<O>,
135 state: &IterState<O>,
136 ) -> Result<ArgminIterData<O>, Error> {
137 let param = state.get_param();
138 let grad = state
139 .get_grad()
140 .unwrap_or_else(|| op.gradient(¶m).unwrap());
141 let hessian = state
142 .get_hessian()
143 .unwrap_or_else(|| op.hessian(¶m).unwrap());
144
145 self.subproblem.set_radius(self.radius);
146
147 let ArgminResult {
148 operator: sub_op,
149 state: IterState { param: pk, .. },
150 } = Executor::new(
151 OpWrapper::new_from_wrapper(op),
152 self.subproblem.clone(),
153 param.clone(),
154 )
155 .grad(grad.clone())
156 .hessian(hessian.clone())
157 .ctrlc(false)
158 .run()?;
159
160 op.consume_op(sub_op);
163
164 let new_param = pk.add(¶m);
165 let fxkpk = op.apply(&new_param)?;
166 let mkpk =
167 self.fxk + pk.dot(&grad) + F::from_f64(0.5).unwrap() * pk.weighted_dot(&hessian, &pk);
168
169 let rho = reduction_ratio(self.fxk, fxkpk, self.mk0, mkpk);
170
171 let pk_norm = pk.norm();
172
173 let cur_radius = self.radius;
174 self.radius = if rho < F::from_f64(0.25).unwrap() {
175 F::from_f64(0.25).unwrap() * pk_norm
176 } else if rho > F::from_f64(0.75).unwrap()
177 && (pk_norm - self.radius).abs() <= F::from_f64(10.0).unwrap() * F::epsilon()
178 {
179 self.max_radius.min(F::from_f64(2.0).unwrap() * self.radius)
180 } else {
181 self.radius
182 };
183
184 Ok(if rho > self.eta {
185 self.fxk = fxkpk;
186 self.mk0 = fxkpk;
187 let grad = op.gradient(&new_param)?;
188 let hessian = op.hessian(&new_param)?;
189 ArgminIterData::new()
190 .param(new_param)
191 .cost(fxkpk)
192 .grad(grad)
193 .hessian(hessian)
194 } else {
195 ArgminIterData::new().param(param).cost(self.fxk)
196 }
197 .kv(make_kv!("radius" => cur_radius;)))
198 }
199
200 fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
201 TerminationReason::NotTerminated
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::trustregion::steihaug::Steihaug;
    use crate::test_trait_impl;

    // Instantiate the crate's `test_trait_impl!` checks for a trust region solver
    // that uses the Steihaug subproblem solver on the minimal no-op operator.
    test_trait_impl!(
        trustregion,
        TrustRegion<Steihaug<MinimalNoOperator, f64>, f64>
    );
}