argmin/solver/goldensectionsearch/
mod.rs1use crate::prelude::*;
13use serde::{Deserialize, Serialize};
14
/// The golden ratio, φ = (1 + √5) / 2.
///
/// Written with the 16 significant digits an `f64` can represent; longer literals round to
/// exactly this value anyway.
const GOLDEN_RATIO: f64 = 1.618_033_988_749_895;
/// Larger golden fraction of a unit interval: `φ - 1 ≈ 0.618`.
const G1: f64 = GOLDEN_RATIO - 1.0;
/// Smaller golden fraction of a unit interval: `1 - G1 = 2 - φ ≈ 0.382`.
const G2: f64 = 1.0 - G1;

19#[derive(Clone, Serialize, Deserialize)]
38pub struct GoldenSectionSearch<F> {
39 g1: F,
40 g2: F,
41 min_bound: F,
42 max_bound: F,
43 init_estimate: F,
44 tolerance: F,
45
46 x0: F,
47 x1: F,
48 x2: F,
49 x3: F,
50 f1: F,
51 f2: F,
52}
53
54impl<F> GoldenSectionSearch<F>
55where
56 F: ArgminFloat,
57{
58 pub fn new(min_bound: F, max_bound: F) -> Self {
60 GoldenSectionSearch {
61 g1: F::from(G1).unwrap(),
62 g2: F::from(G2).unwrap(),
63 min_bound,
64 max_bound,
65 init_estimate: F::zero(),
66 tolerance: F::from(0.01).unwrap(),
67 x0: min_bound,
68 x1: F::zero(),
69 x2: F::zero(),
70 x3: max_bound,
71 f1: F::zero(),
72 f2: F::zero(),
73 }
74 }
75
76 pub fn tolerance(mut self, tol: F) -> Self {
78 self.tolerance = tol;
79 self
80 }
81}
82
83impl<O, F> Solver<O> for GoldenSectionSearch<F>
84where
85 O: ArgminOp<Output = F, Param = F, Float = F>,
86 F: ArgminFloat,
87{
88 const NAME: &'static str = "Golden-section search";
89
90 fn init(
91 &mut self,
92 op: &mut OpWrapper<O>,
93 state: &IterState<O>,
94 ) -> Result<Option<ArgminIterData<O>>, Error> {
95 let init_estimate = state.param;
96 if init_estimate < self.min_bound || init_estimate > self.max_bound {
97 Err(ArgminError::InvalidParameter {
98 text: "Initial estimate must be ∈ [min_bound, max_bound].".to_string(),
99 }
100 .into())
101 } else {
102 let ie_min = init_estimate - self.min_bound;
103 let max_ie = self.max_bound - init_estimate;
104 let (x1, x2) = if max_ie.abs() > ie_min.abs() {
105 (init_estimate, init_estimate + self.g2 * max_ie)
106 } else {
107 (init_estimate - self.g2 * ie_min, init_estimate)
108 };
109 self.x1 = x1;
110 self.x2 = x2;
111 self.f1 = op.apply(&self.x1)?;
112 self.f2 = op.apply(&self.x2)?;
113 if self.f1 < self.f2 {
114 Ok(Some(ArgminIterData::new().param(self.x1).cost(self.f1)))
115 } else {
116 Ok(Some(ArgminIterData::new().param(self.x2).cost(self.f2)))
117 }
118 }
119 }
120
121 fn next_iter(
122 &mut self,
123 op: &mut OpWrapper<O>,
124 state: &IterState<O>,
125 ) -> Result<ArgminIterData<O>, Error> {
126 if self.tolerance * (self.x1.abs() + self.x2.abs()) >= (self.x3 - self.x0).abs() {
127 return Ok(ArgminIterData::new()
128 .param(state.param)
129 .cost(state.cost)
130 .termination_reason(TerminationReason::TargetToleranceReached));
131 }
132
133 if self.f2 < self.f1 {
134 self.x0 = self.x1;
135 self.x1 = self.x2;
136 self.x2 = self.g1 * self.x1 + self.g2 * self.x3;
137 self.f1 = self.f2;
138 self.f2 = op.apply(&self.x2)?;
139 } else {
140 self.x3 = self.x2;
141 self.x2 = self.x1;
142 self.x1 = self.g1 * self.x2 + self.g2 * self.x0;
143 self.f2 = self.f1;
144 self.f1 = op.apply(&self.x1)?;
145 }
146 if self.f1 < self.f2 {
147 Ok(ArgminIterData::new().param(self.x1).cost(self.f1))
148 } else {
149 Ok(ArgminIterData::new().param(self.x2).cost(self.f2))
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use crate::test_trait_impl;

    use super::*;

    // Generated trait-implementation checks for the `f64` instantiation of the solver.
    // NOTE(review): the exact traits checked are defined by the `test_trait_impl` macro.
    test_trait_impl!(golden_section_search, GoldenSectionSearch<f64>);
}