argmin/solver/trustregion/
trustregion_method.rs

1// Copyright 2018-2020 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # References:
9//!
10//! [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
11//! Springer. ISBN 0-387-30303-0.
12
13use crate::prelude::*;
14use crate::solver::trustregion::reduction_ratio;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18/// The trust region method approximates the cost function within a certain region around the
19/// current point in parameter space. Depending on the quality of this approximation, the region is
20/// either expanded or contracted.
21///
22/// The calculation of the actual step length and direction is done by one of the following
23/// methods:
24///
25/// * [Cauchy point](../cauchypoint/struct.CauchyPoint.html)
26/// * [Dogleg method](../dogleg/struct.Dogleg.html)
27/// * [Steihaug method](../steihaug/struct.Steihaug.html)
28///
29/// This subproblem can be set via `set_subproblem(...)`. If this is not provided, it will default
30/// to the Steihaug method.
31///
32/// [Example](https://github.com/argmin-rs/argmin/blob/master/examples/trustregion_nd.rs)
33///
34/// # References:
35///
36/// [0] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
37/// Springer. ISBN 0-387-30303-0.
38#[derive(Clone, Serialize, Deserialize)]
39pub struct TrustRegion<R, F> {
40    /// Radius
41    radius: F,
42    /// Maximum Radius
43    max_radius: F,
44    /// eta \in [0, 1/4)
45    eta: F,
46    /// subproblem
47    subproblem: R,
48    /// f(xk)
49    fxk: F,
50    /// mk(0)
51    mk0: F,
52}
53
54impl<R, F: ArgminFloat> TrustRegion<R, F> {
55    /// Constructor
56    pub fn new(subproblem: R) -> Self {
57        TrustRegion {
58            radius: F::from_f64(1.0).unwrap(),
59            max_radius: F::from_f64(100.0).unwrap(),
60            eta: F::from_f64(0.125).unwrap(),
61            subproblem,
62            fxk: F::nan(),
63            mk0: F::nan(),
64        }
65    }
66
67    /// set radius
68    pub fn radius(mut self, radius: F) -> Self {
69        self.radius = radius;
70        self
71    }
72
73    /// Set maximum radius
74    pub fn max_radius(mut self, max_radius: F) -> Self {
75        self.max_radius = max_radius;
76        self
77    }
78
79    /// Set eta
80    pub fn eta(mut self, eta: F) -> Result<Self, Error> {
81        if eta >= F::from_f64(0.25).unwrap() || eta < F::from_f64(0.0).unwrap() {
82            return Err(ArgminError::InvalidParameter {
83                text: "TrustRegion: eta must be in [0, 1/4).".to_string(),
84            }
85            .into());
86        }
87        self.eta = eta;
88        Ok(self)
89    }
90}
91
92impl<O, R, F> Solver<O> for TrustRegion<R, F>
93where
94    O: ArgminOp<Output = F, Float = F>,
95    O::Param: Default
96        + Clone
97        + Debug
98        + Serialize
99        + ArgminMul<F, O::Param>
100        + ArgminWeightedDot<O::Param, F, O::Hessian>
101        + ArgminNorm<F>
102        + ArgminDot<O::Param, F>
103        + ArgminAdd<O::Param, O::Param>
104        + ArgminSub<O::Param, O::Param>
105        + ArgminZeroLike
106        + ArgminMul<F, O::Param>,
107    O::Hessian: Default + Clone + Debug + Serialize + ArgminDot<O::Param, O::Param>,
108    R: ArgminTrustRegion<F> + Solver<OpWrapper<O>>,
109    F: ArgminFloat,
110{
111    const NAME: &'static str = "Trust region";
112
113    fn init(
114        &mut self,
115        op: &mut OpWrapper<O>,
116        state: &IterState<O>,
117    ) -> Result<Option<ArgminIterData<O>>, Error> {
118        let param = state.get_param();
119        let grad = op.gradient(&param)?;
120        let hessian = op.hessian(&param)?;
121        self.fxk = op.apply(&param)?;
122        self.mk0 = self.fxk;
123        Ok(Some(
124            ArgminIterData::new()
125                .param(param)
126                .cost(self.fxk)
127                .grad(grad)
128                .hessian(hessian),
129        ))
130    }
131
132    fn next_iter(
133        &mut self,
134        op: &mut OpWrapper<O>,
135        state: &IterState<O>,
136    ) -> Result<ArgminIterData<O>, Error> {
137        let param = state.get_param();
138        let grad = state
139            .get_grad()
140            .unwrap_or_else(|| op.gradient(&param).unwrap());
141        let hessian = state
142            .get_hessian()
143            .unwrap_or_else(|| op.hessian(&param).unwrap());
144
145        self.subproblem.set_radius(self.radius);
146
147        let ArgminResult {
148            operator: sub_op,
149            state: IterState { param: pk, .. },
150        } = Executor::new(
151            OpWrapper::new_from_wrapper(op),
152            self.subproblem.clone(),
153            param.clone(),
154        )
155        .grad(grad.clone())
156        .hessian(hessian.clone())
157        .ctrlc(false)
158        .run()?;
159
160        // Operator must be consumed again, otherwise the operator, which moved into the subproblem
161        // executor as well as the function evaluation counts are lost.
162        op.consume_op(sub_op);
163
164        let new_param = pk.add(&param);
165        let fxkpk = op.apply(&new_param)?;
166        let mkpk =
167            self.fxk + pk.dot(&grad) + F::from_f64(0.5).unwrap() * pk.weighted_dot(&hessian, &pk);
168
169        let rho = reduction_ratio(self.fxk, fxkpk, self.mk0, mkpk);
170
171        let pk_norm = pk.norm();
172
173        let cur_radius = self.radius;
174        self.radius = if rho < F::from_f64(0.25).unwrap() {
175            F::from_f64(0.25).unwrap() * pk_norm
176        } else if rho > F::from_f64(0.75).unwrap()
177            && (pk_norm - self.radius).abs() <= F::from_f64(10.0).unwrap() * F::epsilon()
178        {
179            self.max_radius.min(F::from_f64(2.0).unwrap() * self.radius)
180        } else {
181            self.radius
182        };
183
184        Ok(if rho > self.eta {
185            self.fxk = fxkpk;
186            self.mk0 = fxkpk;
187            let grad = op.gradient(&new_param)?;
188            let hessian = op.hessian(&new_param)?;
189            ArgminIterData::new()
190                .param(new_param)
191                .cost(fxkpk)
192                .grad(grad)
193                .hessian(hessian)
194        } else {
195            ArgminIterData::new().param(param).cost(self.fxk)
196        }
197        .kv(make_kv!("radius" => cur_radius;)))
198    }
199
200    fn terminate(&mut self, _state: &IterState<O>) -> TerminationReason {
201        // todo
202        TerminationReason::NotTerminated
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::trustregion::steihaug::Steihaug;
    use crate::test_trait_impl;

    // Runs the crate's `test_trait_impl!` checks against a concrete instantiation of the solver:
    // a Steihaug subproblem over the minimal no-op operator with `f64` precision.
    test_trait_impl!(trustregion, TrustRegion<Steihaug<MinimalNoOperator, f64>, f64>);
}