// argmin/solver/particleswarm/mod.rs
use crate::prelude::*;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std;
use std::default::Default;
/// Particle Swarm Optimization (PSO) solver.
///
/// A swarm of particles moves through `search_region`. In every iteration each particle's
/// velocity is rebuilt from its previous velocity, a randomized pull toward its own best
/// position and a randomized pull toward the swarm's best position; the resulting position is
/// clamped to `search_region` and evaluated with the cost function.
#[derive(Serialize, Deserialize)]
pub struct ParticleSwarm<P, F> {
    /// All particles of the swarm; empty until the solver has been initialized.
    particles: Vec<Particle<P, F>>,
    /// Best position found by the swarm so far.
    best_position: P,
    /// Cost at `best_position`; starts out as `F::infinity()` in `new`.
    best_cost: F,
    /// Factor applied to a particle's previous velocity.
    weight_momentum: F,
    /// Factor applied to the randomized pull toward a particle's own best position.
    weight_particle: F,
    /// Factor applied to the randomized pull toward the swarm's best position.
    weight_swarm: F,
    /// `(min, max)` bounds: initial positions are drawn from here and updated positions are
    /// clamped to it.
    search_region: (P, P),
    /// Number of particles created when the solver is initialized.
    num_particles: usize,
}
impl<P, F> ParticleSwarm<P, F>
where
P: Position<F> + DeserializeOwned + Serialize,
F: ArgminFloat,
{
pub fn new(
search_region: (P, P),
num_particles: usize,
weight_momentum: F,
weight_particle: F,
weight_swarm: F,
) -> Result<Self, Error> {
let particle_swarm = ParticleSwarm {
particles: vec![],
best_position: P::rand_from_range(
&search_region.0,
&search_region.1,
),
best_cost: F::infinity(),
weight_momentum,
weight_particle,
weight_swarm,
search_region,
num_particles,
};
Ok(particle_swarm)
}
fn initialize_particles<O: ArgminOp<Param = P, Output = F, Float = F>>(
&mut self,
op: &mut OpWrapper<O>,
) {
self.particles = (0..self.num_particles)
.map(|_| self.initialize_particle(op))
.collect();
self.best_position = self.get_best_position();
self.best_cost = op.apply(&self.best_position).unwrap();
}
fn initialize_particle<O: ArgminOp<Param = P, Output = F, Float = F>>(
&mut self,
op: &mut OpWrapper<O>,
) -> Particle<P, F> {
let (min, max) = &self.search_region;
let delta = max.sub(min);
let delta_neg = delta.mul(&F::from_f64(-1.0).unwrap());
let initial_position = O::Param::rand_from_range(min, max);
let initial_cost = op.apply(&initial_position).unwrap(); Particle {
position: initial_position.clone(),
velocity: O::Param::rand_from_range(&delta_neg, &delta),
cost: initial_cost,
best_position: initial_position,
best_cost: initial_cost,
}
}
fn get_best_position(&self) -> P {
let mut best: Option<(&P, F)> = None;
for p in &self.particles {
match best {
Some(best_sofar) => {
if p.cost < best_sofar.1 {
best = Some((&p.position, p.cost))
}
}
None => best = Some((&p.position, p.cost)),
}
}
match best {
Some(best_sofar) => best_sofar.0.clone(),
None => panic!("Particles not initialized"),
}
}
}
impl<O, P, F> Solver<O> for ParticleSwarm<P, F>
where
O: ArgminOp<Output = F, Param = P, Float = F>,
O::Param: Position<F> + DeserializeOwned + Serialize,
O::Hessian: Clone + Default,
F: ArgminFloat,
{
const NAME: &'static str = "Particle Swarm Optimization";
fn init(
&mut self,
_op: &mut OpWrapper<O>,
_state: &IterState<O>,
) -> Result<Option<ArgminIterData<O>>, Error> {
self.initialize_particles(_op);
Ok(None)
}
fn next_iter(
&mut self,
_op: &mut OpWrapper<O>,
_state: &IterState<O>,
) -> Result<ArgminIterData<O>, Error> {
let zero = O::Param::zero_like(&self.best_position);
for p in self.particles.iter_mut() {
let momentum = p.velocity.mul(&self.weight_momentum);
let to_optimum = p.best_position.sub(&p.position);
let pull_to_optimum = O::Param::rand_from_range(&zero, &to_optimum);
let pull_to_optimum = pull_to_optimum.mul(&self.weight_particle);
let to_global_optimum = self.best_position.sub(&p.position);
let pull_to_global_optimum =
O::Param::rand_from_range(&zero, &to_global_optimum).mul(&self.weight_swarm);
p.velocity = momentum.add(&pull_to_optimum).add(&pull_to_global_optimum);
let new_position = p.position.add(&p.velocity);
p.position = O::Param::min(
&O::Param::max(&new_position, &self.search_region.0),
&self.search_region.1,
);
p.cost = _op.apply(&p.position)?;
if p.cost < p.best_cost {
p.best_position = p.position.clone();
p.best_cost = p.cost;
if p.cost < self.best_cost {
self.best_position = p.position.clone();
self.best_cost = p.cost;
}
}
}
let population = self
.particles
.iter()
.map(|particle| (particle.position.clone(), particle.cost))
.collect();
let out = ArgminIterData::new()
.param(self.best_position.clone())
.cost(self.best_cost)
.population(population)
.kv(make_kv!(
"particles" => &self.particles;
));
Ok(out)
}
}
/// Capabilities a type needs to serve as a particle position (and velocity) in
/// [`ParticleSwarm`].
///
/// `F` is the float type that positions can be scaled by. This trait has no methods of its
/// own; it only bundles the required arithmetic, comparison and random-sampling bounds.
pub trait Position<F: ArgminFloat>:
    std::fmt::Debug
    + Clone
    + Default
    + ArgminZeroLike
    + ArgminAdd<Self, Self>
    + ArgminSub<Self, Self>
    + ArgminMul<F, Self>
    + ArgminMinMax
    + ArgminRandom
{
}
/// Blanket implementation: any type meeting all of [`Position`]'s bounds is a `Position`, so
/// the trait never has to be implemented by hand.
impl<T, F> Position<F> for T
where
    T: Clone
        + Default
        + ArgminAdd<T, T>
        + ArgminSub<T, T>
        + ArgminMul<F, T>
        + ArgminZeroLike
        + ArgminRandom
        + ArgminMinMax
        + std::fmt::Debug,
    F: ArgminFloat,
{
}
/// A single particle of a [`ParticleSwarm`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Particle<T, F> {
    /// Current position.
    pub position: T,
    /// Current velocity; added to `position` each iteration before clamping to the search
    /// region.
    velocity: T,
    /// Cost at the current `position`.
    pub cost: F,
    /// Best position this particle has visited so far.
    best_position: T,
    /// Cost at `best_position`.
    best_cost: F,
}