1use crate::prelude::*;
22use serde::de::DeserializeOwned;
23use serde::{Deserialize, Serialize};
24use std::default::Default;
25
/// More-Thuente line search.
///
/// Iteratively searches for a step length `alpha` along a descent direction such that the
/// cost satisfies the sufficient decrease condition `f(alpha) <= finit + alpha * c1 * dginit`
/// and the curvature condition `|dg(alpha)| <= c2 * (-dginit)` (strong Wolfe conditions).
/// The step is chosen by safeguarded cubic/quadratic interpolation (see `cstep`) on an
/// interval of uncertainty that is shrunk every iteration.
#[derive(Serialize, Deserialize, Clone)]
pub struct MoreThuenteLineSearch<P, F> {
    /// Search direction provided via `set_search_direction`; `init` fails if it is unset.
    search_direction_b: Option<P>,
    /// Parameter vector at the start of the line search (taken from the solver state).
    init_param: P,
    /// Cost at `init_param`.
    finit: F,
    /// Gradient at `init_param`.
    init_grad: P,
    /// Search direction used during the current line search.
    search_direction: P,
    /// Directional derivative `init_grad · search_direction`; must be negative.
    dginit: F,
    /// `ftol * dginit`, the slope used by the sufficient decrease test.
    dgtest: F,
    /// Sufficient decrease parameter `c1` (default `1e-4`).
    ftol: F,
    /// Curvature parameter `c2` (default `0.9`).
    gtol: F,
    /// Extrapolation factor for the upper end of the trial interval while no bracket exists
    /// (default `4.0`).
    xtrapf: F,
    /// Current width of the interval of uncertainty.
    width: F,
    /// Width of the interval of uncertainty from the previous iteration.
    width1: F,
    /// Relative tolerance on the width of the interval of uncertainty (default `1e-10`).
    xtol: F,
    /// Initial step length (default `1.0`).
    alpha: F,
    /// Lower bound on the step length.
    stpmin: F,
    /// Upper bound on the step length.
    stpmax: F,
    /// Current trial step.
    stp: Step<F>,
    /// Endpoint of the interval of uncertainty with the lowest function value found so far.
    stx: Step<F>,
    /// Other endpoint of the interval of uncertainty.
    sty: Step<F>,
    /// Function value at the current trial step.
    f: F,
    /// Whether the interval of uncertainty brackets a minimizer.
    brackt: bool,
    /// Whether the search is still in its first stage (modified function is used).
    stage1: bool,
    /// Status returned by the last `cstep` call; `0` means the input was rejected and forces
    /// termination on the next iteration.
    infoc: usize,
}
88
/// A point along the search direction together with the values observed there.
#[derive(Clone, Serialize, Deserialize)]
struct Step<F> {
    /// Step length along the search direction.
    pub x: F,
    /// Function value at the step.
    pub fx: F,
    /// Directional derivative (gradient · search direction) at the step.
    pub gx: F,
}
95
96impl<F> Step<F> {
97 pub fn new(x: F, fx: F, gx: F) -> Self {
98 Step { x, fx, gx }
99 }
100}
101
102impl<F: ArgminFloat> Default for Step<F> {
103 fn default() -> Self {
104 Step {
105 x: F::from_f64(0.0).unwrap(),
106 fx: F::from_f64(0.0).unwrap(),
107 gx: F::from_f64(0.0).unwrap(),
108 }
109 }
110}
111
112impl<P: Default, F: ArgminFloat> MoreThuenteLineSearch<P, F> {
113 pub fn new() -> Self {
115 MoreThuenteLineSearch {
116 search_direction_b: None,
117 init_param: P::default(),
118 finit: F::infinity(),
119 init_grad: P::default(),
120 search_direction: P::default(),
121 dginit: F::from_f64(0.0).unwrap(),
122 dgtest: F::from_f64(0.0).unwrap(),
123 ftol: F::from_f64(1e-4).unwrap(),
124 gtol: F::from_f64(0.9).unwrap(),
125 xtrapf: F::from_f64(4.0).unwrap(),
126 width: F::nan(),
127 width1: F::nan(),
128 xtol: F::from_f64(1e-10).unwrap(),
129 alpha: F::from_f64(1.0).unwrap(),
130 stpmin: F::epsilon().sqrt(),
131 stpmax: F::infinity(),
132 stp: Step::default(),
133 stx: Step::default(),
134 sty: Step::default(),
135 f: F::nan(),
136 brackt: false,
137 stage1: true,
138 infoc: 1,
139 }
140 }
141
142 pub fn c(mut self, c1: F, c2: F) -> Result<Self, Error> {
144 if c1 <= F::from_f64(0.0).unwrap() || c1 >= c2 {
145 return Err(ArgminError::InvalidParameter {
146 text: "MoreThuenteLineSearch: Parameter c1 must be in (0, c2).".to_string(),
147 }
148 .into());
149 }
150 if c2 <= c1 || c2 >= F::from_f64(1.0).unwrap() {
151 return Err(ArgminError::InvalidParameter {
152 text: "MoreThuenteLineSearch: Parameter c2 must be in (c1, 1).".to_string(),
153 }
154 .into());
155 }
156 self.ftol = c1;
157 self.gtol = c2;
158 Ok(self)
159 }
160
161 pub fn alpha(mut self, alpha_min: F, alpha_max: F) -> Result<Self, Error> {
163 if alpha_min < F::from_f64(0.0).unwrap() {
164 return Err(ArgminError::InvalidParameter {
165 text: "MoreThuenteLineSearch: alpha_min must be >= 0.0.".to_string(),
166 }
167 .into());
168 }
169 if alpha_max <= alpha_min {
170 return Err(ArgminError::InvalidParameter {
171 text: "MoreThuenteLineSearch: alpha_min must be smaller than alpha_max."
172 .to_string(),
173 }
174 .into());
175 }
176 self.stpmin = alpha_min;
177 self.stpmax = alpha_max;
178 Ok(self)
179 }
180}
181
182impl<P: Default, F: ArgminFloat> Default for MoreThuenteLineSearch<P, F> {
183 fn default() -> Self {
184 MoreThuenteLineSearch::new()
185 }
186}
187
188impl<P, F> ArgminLineSearch<P, F> for MoreThuenteLineSearch<P, F>
189where
190 P: Clone + Serialize + ArgminSub<P, P> + ArgminDot<P, F> + ArgminScaledAdd<P, F, P>,
191 F: ArgminFloat,
192{
193 fn set_search_direction(&mut self, search_direction: P) {
195 self.search_direction_b = Some(search_direction);
196 }
197
198 fn set_init_alpha(&mut self, alpha: F) -> Result<(), Error> {
200 if alpha <= F::from_f64(0.0).unwrap() {
201 return Err(ArgminError::InvalidParameter {
202 text: "MoreThuenteLineSearch: Initial alpha must be > 0.".to_string(),
203 }
204 .into());
205 }
206 self.alpha = alpha;
207 Ok(())
208 }
209}
210
211impl<P, O, F> Solver<O> for MoreThuenteLineSearch<P, F>
212where
213 O: ArgminOp<Param = P, Output = F, Float = F>,
214 P: Clone
215 + Serialize
216 + DeserializeOwned
217 + ArgminSub<P, P>
218 + ArgminDot<P, F>
219 + ArgminScaledAdd<P, F, P>,
220 F: ArgminFloat,
221{
222 const NAME: &'static str = "More-Thuente Line search";
223
224 fn init(
225 &mut self,
226 op: &mut OpWrapper<O>,
227 state: &IterState<O>,
228 ) -> Result<Option<ArgminIterData<O>>, Error> {
229 self.search_direction = check_param!(
230 self.search_direction_b,
231 "MoreThuenteLineSearch: Search direction not initialized. Call `set_search_direction`."
232 );
233
234 self.init_param = state.get_param();
235
236 let cost = state.get_cost();
237 self.finit = if cost.is_infinite() {
238 op.apply(&self.init_param)?
239 } else {
240 cost
241 };
242
243 self.init_grad = state
244 .get_grad()
245 .unwrap_or_else(|| op.gradient(&self.init_param).unwrap());
246
247 self.dginit = self.init_grad.dot(&self.search_direction);
248
249 if self.dginit >= F::from_f64(0.0).unwrap() {
251 return Err(ArgminError::ConditionViolated {
252 text: "MoreThuenteLineSearch: Search direction must be a descent direction."
253 .to_string(),
254 }
255 .into());
256 }
257
258 self.stage1 = true;
259 self.brackt = false;
260
261 self.dgtest = self.ftol * self.dginit;
262 self.width = self.stpmax - self.stpmin;
263 self.width1 = F::from_f64(2.0).unwrap() * self.width;
264 self.f = self.finit;
265
266 self.stp = Step::new(self.alpha, F::nan(), F::nan());
267 self.stx = Step::new(F::from_f64(0.0).unwrap(), self.finit, self.dginit);
268 self.sty = Step::new(F::from_f64(0.0).unwrap(), self.finit, self.dginit);
269
270 Ok(None)
271 }
272
273 fn next_iter(
274 &mut self,
275 op: &mut OpWrapper<O>,
276 _state: &IterState<O>,
277 ) -> Result<ArgminIterData<O>, Error> {
278 let mut info = 0;
280 let (stmin, stmax) = if self.brackt {
281 (self.stx.x.min(self.sty.x), self.stx.x.max(self.sty.x))
282 } else {
283 (
284 self.stx.x,
285 self.stp.x + self.xtrapf * (self.stp.x - self.stx.x),
286 )
287 };
288
289 self.stp.x = self.stp.x.max(self.stpmin);
291 self.stp.x = self.stp.x.min(self.stpmax);
292
293 if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax))
296 || (self.brackt && (stmax - stmin) <= self.xtol * stmax)
297 || self.infoc == 0
298 {
299 self.stp.x = self.stx.x;
300 }
301
302 let new_param = self
304 .init_param
305 .scaled_add(&self.stp.x, &self.search_direction);
306 self.f = op.apply(&new_param)?;
307 let new_grad = op.gradient(&new_param)?;
308 let cur_cost = self.f;
309 let cur_param = new_param;
310 let cur_grad = new_grad.clone();
311 let dg = self.search_direction.dot(&new_grad);
313 let ftest1 = self.finit + self.stp.x * self.dgtest;
314 if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax)) || self.infoc == 0 {
318 info = 6;
319 }
320
321 if (self.stp.x - self.stpmax).abs() < F::epsilon() && self.f <= ftest1 && dg <= self.dgtest
322 {
323 info = 5;
324 }
325
326 if (self.stp.x - self.stpmin).abs() < F::epsilon() && (self.f > ftest1 || dg >= self.dgtest)
327 {
328 info = 4;
329 }
330
331 if self.brackt && stmax - stmin <= self.xtol * stmax {
332 info = 2;
333 }
334
335 if self.f <= ftest1 && dg.abs() <= self.gtol * (-self.dginit) {
336 info = 1;
337 }
338
339 if info != 0 {
340 return Ok(ArgminIterData::new()
341 .param(cur_param)
342 .cost(cur_cost)
343 .grad(cur_grad)
344 .termination_reason(TerminationReason::LineSearchConditionMet));
345 }
346
347 if self.stage1 && self.f <= ftest1 && dg >= self.ftol.min(self.gtol) * self.dginit {
348 self.stage1 = false;
349 }
350
351 if self.stage1 && self.f <= self.stp.fx && self.f > ftest1 {
352 let fm = self.f - self.stp.x * self.dgtest;
353 let fxm = self.stx.fx - self.stx.x * self.dgtest;
354 let fym = self.sty.fx - self.sty.x * self.dgtest;
355 let dgm = dg - self.dgtest;
356 let dgxm = self.stx.gx - self.dgtest;
357 let dgym = self.sty.gx - self.dgtest;
358
359 let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
360 Step::new(self.stx.x, fxm, dgxm),
361 Step::new(self.sty.x, fym, dgym),
362 Step::new(self.stp.x, fm, dgm),
363 self.brackt,
364 stmin,
365 stmax,
366 )?;
367
368 self.stx.x = stx1.x;
369 self.sty.x = sty1.x;
370 self.stp.x = stp1.x;
371 self.stx.fx = self.stx.fx + stx1.x * self.dgtest;
372 self.sty.fx = self.sty.fx + sty1.x * self.dgtest;
373 self.stx.gx = self.stx.gx + self.dgtest;
374 self.sty.gx = self.sty.gx + self.dgtest;
375 self.brackt = brackt1;
376 self.stp = stp1;
377 self.infoc = infoc;
378 } else {
379 let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
380 self.stx.clone(),
381 self.sty.clone(),
382 Step::new(self.stp.x, self.f, dg),
383 self.brackt,
384 stmin,
385 stmax,
386 )?;
387 self.stx = stx1;
388 self.sty = sty1;
389 self.stp = stp1;
390 self.f = self.stp.fx;
391 self.brackt = brackt1;
393 self.infoc = infoc;
394 }
395
396 if self.brackt {
397 if (self.sty.x - self.stx.x).abs() >= F::from_f64(0.66).unwrap() * self.width1 {
398 self.stp.x = self.stx.x + F::from_f64(0.5).unwrap() * (self.sty.x - self.stx.x);
399 }
400 self.width1 = self.width;
401 self.width = (self.sty.x - self.stx.x).abs();
402 }
403
404 Ok(ArgminIterData::new())
409 }
410}
411
412fn cstep<F: ArgminFloat>(
413 stx: Step<F>,
414 sty: Step<F>,
415 stp: Step<F>,
416 brackt: bool,
417 stpmin: F,
418 stpmax: F,
419) -> Result<(Step<F>, Step<F>, Step<F>, bool, F, F, usize), Error> {
420 let mut info: usize = 0;
421 let bound: bool;
422 let mut stpf: F;
423 let stpc: F;
424 let stpq: F;
425 let mut brackt = brackt;
426
427 if (brackt && (stp.x <= stx.x.min(sty.x) || stp.x >= stx.x.max(sty.x)))
429 || stx.gx * (stp.x - stx.x) >= F::from_f64(0.0).unwrap()
430 || stpmax < stpmin
431 {
432 return Ok((stx, sty, stp, brackt, stpmin, stpmax, info));
433 }
434
435 let sgnd = stp.gx * (stx.gx / stx.gx.abs());
437
438 if stp.fx > stx.fx {
439 info = 1;
443 bound = true;
444 let theta =
445 F::from_f64(3.0).unwrap() * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
446 let tmp = vec![theta, stx.gx, stp.gx];
447 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
449 return Err(ArgminError::ConditionViolated {
450 text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration".to_string(),
451 }
452 .into());
453 }
454 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
455 let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
456 if stp.x < stx.x {
457 gamma = -gamma;
458 }
459
460 let p = (gamma - stx.gx) + theta;
461 let q = ((gamma - stx.gx) + gamma) + stp.gx;
462 let r = p / q;
463 stpc = stx.x + r * (stp.x - stx.x);
464 stpq = stx.x
465 + ((stx.gx / ((stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx))
466 / F::from_f64(2.0).unwrap())
467 * (stp.x - stx.x);
468 if (stpc - stx.x).abs() < (stpq - stx.x).abs() {
469 stpf = stpc;
470 } else {
471 stpf = stpc + (stpq - stpc) / F::from_f64(2.0).unwrap();
472 }
473 brackt = true;
474 } else if sgnd < F::from_f64(0.0).unwrap() {
475 info = 2;
479 bound = false;
480 let theta =
481 F::from_f64(3.0).unwrap() * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
482 let tmp = vec![theta, stx.gx, stp.gx];
483 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
485 return Err(ArgminError::ConditionViolated {
486 text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration".to_string(),
487 }
488 .into());
489 }
490 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
491 let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
492 if stp.x > stx.x {
493 gamma = -gamma;
494 }
495 let p = (gamma - stp.gx) + theta;
496 let q = ((gamma - stp.gx) + gamma) + stx.gx;
497 let r = p / q;
498 stpc = stp.x + r * (stx.x - stp.x);
499 stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
500 if (stpc - stp.x).abs() > (stpq - stp.x).abs() {
501 stpf = stpc;
502 } else {
503 stpf = stpq;
504 }
505 brackt = true;
506 } else if stp.gx.abs() < stx.gx.abs() {
507 info = 3;
514 bound = true;
515 let theta =
516 F::from_f64(3.0).unwrap() * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
517 let tmp = vec![theta, stx.gx, stp.gx];
518 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
520 return Err(ArgminError::ConditionViolated {
521 text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration".to_string(),
522 }
523 .into());
524 }
525 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
526 let mut gamma = *s
530 * F::from_f64(0.0)
531 .unwrap()
532 .max((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s))
533 .sqrt();
534 if stp.x > stx.x {
535 gamma = -gamma;
536 }
537
538 let p = (gamma - stp.gx) + theta;
539 let q = (gamma + (stx.gx - stp.gx)) + gamma;
540 let r = p / q;
541 if r < F::from_f64(0.0).unwrap() && gamma != F::from_f64(0.0).unwrap() {
542 stpc = stp.x + r * (stx.x - stp.x);
543 } else if stp.x > stx.x {
544 stpc = stpmax;
545 } else {
546 stpc = stpmin;
547 }
548 stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
549 if brackt {
550 if (stp.x - stpc).abs() < (stp.x - stpq).abs() {
551 stpf = stpc;
552 } else {
553 stpf = stpq;
554 }
555 } else if (stp.x - stpc).abs() > (stp.x - stpq).abs() {
556 stpf = stpc;
557 } else {
558 stpf = stpq;
559 }
560 } else {
561 info = 4;
565 bound = false;
566 if brackt {
567 let theta =
568 F::from_f64(3.0).unwrap() * (stp.fx - sty.fx) / (sty.x - stp.x) + sty.gx + stp.gx;
569 let tmp = vec![theta, sty.gx, stp.gx];
570 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
572 return Err(ArgminError::ConditionViolated {
573 text: "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
574 .to_string(),
575 }
576 .into());
577 }
578 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
579 let mut gamma = *s * ((theta / *s).powi(2) - (sty.gx / *s) * (stp.gx / *s)).sqrt();
580 if stp.x > sty.x {
581 gamma = -gamma;
582 }
583 let p = (gamma - stp.gx) + theta;
584 let q = ((gamma - stp.gx) + gamma) + sty.gx;
585 let r = p / q;
586 stpc = stp.x + r * (sty.x - stp.x);
587 stpf = stpc;
588 } else if stp.x > stx.x {
589 stpf = stpmax;
590 } else {
591 stpf = stpmin;
592 }
593 }
594 let mut stx_o = stx.clone();
598 let mut sty_o = sty.clone();
599 let mut stp_o = stp.clone();
600 if stp_o.fx > stx_o.fx {
601 sty_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
602 } else {
603 if sgnd < F::from_f64(0.0).unwrap() {
604 sty_o = Step::new(stx_o.x, stx_o.fx, stx_o.gx);
605 }
606 stx_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
607 }
608
609 stpf = stpmax.min(stpf);
612 stpf = stpmin.max(stpf);
613
614 stp_o.x = stpf;
615 if brackt && bound {
616 if sty_o.x > stx_o.x {
617 stp_o.x = stp_o
618 .x
619 .min(stx_o.x + F::from_f64(0.66).unwrap() * (sty_o.x - stx_o.x));
620 } else {
621 stp_o.x = stp_o
622 .x
623 .max(stx_o.x + F::from_f64(0.66).unwrap() * (sty_o.x - stx_o.x));
624 }
625 }
626
627 Ok((stx_o, sty_o, stp_o, brackt, stpmin, stpmax, info))
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::MinimalNoOperator;
    use crate::test_trait_impl;

    // Runs the crate-wide `test_trait_impl!` checks (macro defined elsewhere in the crate)
    // for `MoreThuenteLineSearch` instantiated with a no-op operator and `f64`.
    test_trait_impl!(morethuente, MoreThuenteLineSearch<MinimalNoOperator, f64>);
}