argmin/solver/linesearch/
backtracking.rs1use crate::prelude::*;
11use crate::solver::linesearch::condition::*;
12use serde::de::DeserializeOwned;
13use serde::{Deserialize, Serialize};
14
15#[derive(Serialize, Deserialize, Clone)]
27pub struct BacktrackingLineSearch<P, L, F> {
28 init_param: P,
30 init_cost: F,
32 init_grad: P,
34 search_direction: Option<P>,
36 rho: F,
38 condition: Box<L>,
40 alpha: F,
42}
43
44impl<P: Default, L, F: ArgminFloat> BacktrackingLineSearch<P, L, F> {
45 pub fn new(condition: L) -> Self {
47 BacktrackingLineSearch {
48 init_param: P::default(),
49 init_cost: F::infinity(),
50 init_grad: P::default(),
51 search_direction: None,
52 rho: F::from_f64(0.9).unwrap(),
53 condition: Box::new(condition),
54 alpha: F::from_f64(1.0).unwrap(),
55 }
56 }
57
58 pub fn rho(mut self, rho: F) -> Result<Self, Error> {
60 if rho <= F::from_f64(0.0).unwrap() || rho >= F::from_f64(1.0).unwrap() {
61 return Err(ArgminError::InvalidParameter {
62 text: "BacktrackingLineSearch: Contraction factor rho must be in (0, 1)."
63 .to_string(),
64 }
65 .into());
66 }
67 self.rho = rho;
68 Ok(self)
69 }
70}
71
72impl<P, L, F> ArgminLineSearch<P, F> for BacktrackingLineSearch<P, L, F>
73where
74 P: Clone + Serialize + ArgminSub<P, P> + ArgminDot<P, f64> + ArgminScaledAdd<P, f64, P>,
75 L: LineSearchCondition<P, F>,
76 F: ArgminFloat + Serialize + DeserializeOwned,
77{
78 fn set_search_direction(&mut self, search_direction: P) {
80 self.search_direction = Some(search_direction);
81 }
82
83 fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
85 if alpha <= F::from_f64(0.0).unwrap() {
86 return Err(ArgminError::InvalidParameter {
87 text: "LineSearch: Inital alpha must be > 0.".to_string(),
88 }
89 .into());
90 }
91 self.alpha = alpha;
92 Ok(())
93 }
94}
95
96impl<O, P, L, F> Solver<O> for BacktrackingLineSearch<P, L, F>
97where
98 P: Clone
99 + Default
100 + Serialize
101 + DeserializeOwned
102 + ArgminSub<P, P>
103 + ArgminDot<P, F>
104 + ArgminScaledAdd<P, F, P>,
105 O: ArgminOp<Param = P, Output = F, Float = F>,
106 L: LineSearchCondition<P, F>,
107 F: ArgminFloat,
108{
109 const NAME: &'static str = "Backtracking Line search";
110
111 fn init(
112 &mut self,
113 op: &mut OpWrapper<O>,
114 state: &IterState<O>,
115 ) -> Result<Option<ArgminIterData<O>>, Error> {
116 self.init_param = state.get_param();
117 let cost = state.get_cost();
118 self.init_cost = if cost == F::infinity() {
119 op.apply(&self.init_param)?
120 } else {
121 cost
122 };
123
124 self.init_grad = state.get_grad().unwrap_or(op.gradient(&self.init_param)?);
125
126 if self.search_direction.is_none() {
127 return Err(ArgminError::NotInitialized {
128 text: "BacktrackingLineSearch: search_direction must be set.".to_string(),
129 }
130 .into());
131 }
132
133 Ok(None)
134 }
135
136 fn next_iter(
137 &mut self,
138 op: &mut OpWrapper<O>,
139 _state: &IterState<O>,
140 ) -> Result<ArgminIterData<O>, Error> {
141 let new_param = self
142 .init_param
143 .scaled_add(&self.alpha, self.search_direction.as_ref().unwrap());
144
145 let cur_cost = op.apply(&new_param)?;
146
147 self.alpha = self.alpha * self.rho;
148
149 let mut out = ArgminIterData::new()
150 .param(new_param.clone())
151 .cost(cur_cost);
152
153 if self.condition.requires_cur_grad() {
154 out = out.grad(op.gradient(&new_param)?);
155 }
156
157 Ok(out)
158 }
159
160 fn terminate(&mut self, state: &IterState<O>) -> TerminationReason {
161 if self.condition.eval(
162 state.get_cost(),
163 state.get_grad().unwrap_or_default(),
164 self.init_cost,
165 self.init_grad.clone(),
166 self.search_direction.clone().unwrap(),
167 self.alpha,
168 ) {
169 TerminationReason::LineSearchConditionMet
170 } else {
171 TerminationReason::NotTerminated
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::test_trait_impl;

    // Instantiates the crate's `test_trait_impl!` test generator for this line search
    // (see that macro for the exact checks it emits).
    test_trait_impl!(
        backtrackinglinesearch,
        BacktrackingLineSearch<MinimalNoOperator, ArmijoCondition<f64>, f64>
    );
}