1use {
2 Symbol,
3 SymbolType,
4 Constraint,
5 Variable,
6 Expression,
7 Term,
8 Row,
9 AddConstraintError,
10 RemoveConstraintError,
11 InternalSolverError,
12 SuggestValueError,
13 AddEditVariableError,
14 RemoveEditVariableError,
15 RelationalOperator,
16 near_zero
17};
18
19use ::std::rc::Rc;
20use ::std::cell::RefCell;
21use ::std::collections::{ HashMap, HashSet };
22use ::std::collections::hash_map::Entry;
23
24#[derive(Copy, Clone)]
25struct Tag {
26 marker: Symbol,
27 other: Symbol
28}
29
30#[derive(Clone)]
31struct EditInfo {
32 tag: Tag,
33 constraint: Constraint,
34 constant: f64
35}
36
37pub struct Solver {
39 cns: HashMap<Constraint, Tag>,
40 var_data: HashMap<Variable, (f64, Symbol, usize)>,
41 var_for_symbol: HashMap<Symbol, Variable>,
42 public_changes: Vec<(Variable, f64)>,
43 changed: HashSet<Variable>,
44 should_clear_changes: bool,
45 rows: HashMap<Symbol, Box<Row>>,
46 edits: HashMap<Variable, EditInfo>,
47 infeasible_rows: Vec<Symbol>, objective: Rc<RefCell<Row>>,
49 artificial: Option<Rc<RefCell<Row>>>,
50 id_tick: usize
51}
52
53impl Solver {
54 pub fn new() -> Solver {
56 Solver {
57 cns: HashMap::new(),
58 var_data: HashMap::new(),
59 var_for_symbol: HashMap::new(),
60 public_changes: Vec::new(),
61 changed: HashSet::new(),
62 should_clear_changes: false,
63 rows: HashMap::new(),
64 edits: HashMap::new(),
65 infeasible_rows: Vec::new(),
66 objective: Rc::new(RefCell::new(Row::new(0.0))),
67 artificial: None,
68 id_tick: 1
69 }
70 }
71
72 pub fn add_constraints<'a, I: IntoIterator<Item = &'a Constraint>>(
73 &mut self,
74 constraints: I) -> Result<(), AddConstraintError>
75 {
76 for constraint in constraints {
77 try!(self.add_constraint(constraint.clone()));
78 }
79 Ok(())
80 }
81
82 pub fn add_constraint(&mut self, constraint: Constraint) -> Result<(), AddConstraintError> {
84 if self.cns.contains_key(&constraint) {
85 return Err(AddConstraintError::DuplicateConstraint);
86 }
87
88 let (mut row, tag) = self.create_row(&constraint);
95 let mut subject = Solver::choose_subject(&row, &tag);
96
97 if subject.type_() == SymbolType::Invalid && Solver::all_dummies(&row) {
104 if !near_zero(row.constant) {
105 return Err(AddConstraintError::UnsatisfiableConstraint);
106 } else {
107 subject = tag.marker;
108 }
109 }
110
111 if subject.type_() == SymbolType::Invalid {
115 if !try!(self.add_with_artificial_variable(&row)
116 .map_err(|e| AddConstraintError::InternalSolverError(e.0))) {
117 return Err(AddConstraintError::UnsatisfiableConstraint);
118 }
119 } else {
120 row.solve_for_symbol(subject);
121 self.substitute(subject, &row);
122 if subject.type_() == SymbolType::External && row.constant != 0.0 {
123 let v = self.var_for_symbol[&subject];
124 self.var_changed(v);
125 }
126 self.rows.insert(subject, row);
127 }
128
129 self.cns.insert(constraint, tag);
130
131 let objective = self.objective.clone();
135 try!(self.optimise(&objective).map_err(|e| AddConstraintError::InternalSolverError(e.0)));
136 Ok(())
137 }
138
139 pub fn remove_constraint(&mut self, constraint: &Constraint) -> Result<(), RemoveConstraintError> {
141 let tag = try!(self.cns.remove(constraint).ok_or(RemoveConstraintError::UnknownConstraint));
142
143 self.remove_constraint_effects(constraint, &tag);
147
148 if let None = self.rows.remove(&tag.marker) {
151 let (leaving, mut row) = try!(self.get_marker_leaving_row(tag.marker)
152 .ok_or(
153 RemoveConstraintError::InternalSolverError(
154 "Failed to find leaving row.")));
155 row.solve_for_symbols(leaving, tag.marker);
156 self.substitute(tag.marker, &row);
157 }
158
159 let objective = self.objective.clone();
163 try!(self.optimise(&objective).map_err(|e| RemoveConstraintError::InternalSolverError(e.0)));
164
165 for term in &constraint.expr().terms {
168 if !near_zero(term.coefficient) {
169 let mut should_remove = false;
170 if let Some(&mut (_, _, ref mut count)) = self.var_data.get_mut(&term.variable) {
171 *count -= 1;
172 should_remove = *count == 0;
173 }
174 if should_remove {
175 self.var_for_symbol.remove(&self.var_data[&term.variable].1);
176 self.var_data.remove(&term.variable);
177 }
178 }
179 }
180 Ok(())
181 }
182
183 pub fn has_constraint(&self, constraint: &Constraint) -> bool {
185 self.cns.contains_key(constraint)
186 }
187
188 pub fn add_edit_variable(&mut self, v: Variable, strength: f64) -> Result<(), AddEditVariableError> {
193 if self.edits.contains_key(&v) {
194 return Err(AddEditVariableError::DuplicateEditVariable);
195 }
196 let strength = ::strength::clip(strength);
197 if strength == ::strength::REQUIRED {
198 return Err(AddEditVariableError::BadRequiredStrength);
199 }
200 let cn = Constraint::new(Expression::from_term(Term::new(v.clone(), 1.0)),
201 RelationalOperator::Equal,
202 strength);
203 self.add_constraint(cn.clone()).unwrap();
204 self.edits.insert(v.clone(), EditInfo {
205 tag: self.cns[&cn].clone(),
206 constraint: cn,
207 constant: 0.0
208 });
209 Ok(())
210 }
211
212 pub fn remove_edit_variable(&mut self, v: Variable) -> Result<(), RemoveEditVariableError> {
214 if let Some(constraint) = self.edits.remove(&v).map(|e| e.constraint) {
215 try!(self.remove_constraint(&constraint)
216 .map_err(|e| match e {
217 RemoveConstraintError::UnknownConstraint =>
218 RemoveEditVariableError::InternalSolverError("Edit constraint not in system"),
219 RemoveConstraintError::InternalSolverError(s) =>
220 RemoveEditVariableError::InternalSolverError(s)
221 }));
222 Ok(())
223 } else {
224 Err(RemoveEditVariableError::UnknownEditVariable)
225 }
226 }
227
228 pub fn has_edit_variable(&self, v: &Variable) -> bool {
230 self.edits.contains_key(v)
231 }
232
233 pub fn suggest_value(&mut self, variable: Variable, value: f64) -> Result<(), SuggestValueError> {
238 let (info_tag_marker, info_tag_other, delta) = {
239 let info = try!(self.edits.get_mut(&variable).ok_or(SuggestValueError::UnknownEditVariable));
240 let delta = value - info.constant;
241 info.constant = value;
242 (info.tag.marker, info.tag.other, delta)
243 };
244 {
249 let infeasible_rows = &mut self.infeasible_rows;
250 if self.rows.get_mut(&info_tag_marker)
251 .map(|row|
252 if row.add(-delta) < 0.0 {
253 infeasible_rows.push(info_tag_marker);
254 }).is_some()
255 {
256
257 } else if self.rows.get_mut(&info_tag_other)
258 .map(|row|
259 if row.add(delta) < 0.0 {
260 infeasible_rows.push(info_tag_other);
261 }).is_some()
262 {
263
264 } else {
265 for (symbol, row) in &mut self.rows {
266 let coeff = row.coefficient_for(info_tag_marker);
267 let diff = delta * coeff;
268 if diff != 0.0 && symbol.type_() == SymbolType::External {
269 let v = self.var_for_symbol[symbol];
270 if self.should_clear_changes {
272 self.changed.clear();
273 self.should_clear_changes = false;
274 }
275 self.changed.insert(v);
276 }
277 if coeff != 0.0 &&
278 row.add(diff) < 0.0 &&
279 symbol.type_() != SymbolType::External
280 {
281 infeasible_rows.push(*symbol);
282 }
283 }
284 }
285 }
286 try!(self.dual_optimise().map_err(|e| SuggestValueError::InternalSolverError(e.0)));
287 return Ok(());
288 }
289
290 fn var_changed(&mut self, v: Variable) {
291 if self.should_clear_changes {
292 self.changed.clear();
293 self.should_clear_changes = false;
294 }
295 self.changed.insert(v);
296 }
297
298 pub fn fetch_changes(&mut self) -> &[(Variable, f64)] {
303 if self.should_clear_changes {
304 self.changed.clear();
305 self.should_clear_changes = false;
306 } else {
307 self.should_clear_changes = true;
308 }
309 self.public_changes.clear();
310 for &v in &self.changed {
311 if let Some(var_data) = self.var_data.get_mut(&v) {
312 let new_value = self.rows.get(&var_data.1).map(|r| r.constant).unwrap_or(0.0);
313 let old_value = var_data.0;
314 if old_value != new_value {
315 self.public_changes.push((v, new_value));
316 var_data.0 = new_value;
317 }
318 }
319 }
320 &self.public_changes
321 }
322
323 pub fn reset(&mut self) {
331 self.rows.clear();
332 self.cns.clear();
333 self.var_data.clear();
334 self.var_for_symbol.clear();
335 self.changed.clear();
336 self.should_clear_changes = false;
337 self.edits.clear();
338 self.infeasible_rows.clear();
339 *self.objective.borrow_mut() = Row::new(0.0);
340 self.artificial = None;
341 self.id_tick = 1;
342 }
343
344 fn get_var_symbol(&mut self, v: Variable) -> Symbol {
348 let id_tick = &mut self.id_tick;
349 let var_for_symbol = &mut self.var_for_symbol;
350 let value = self.var_data.entry(v).or_insert_with(|| {
351 let s = Symbol(*id_tick, SymbolType::External);
352 var_for_symbol.insert(s, v);
353 *id_tick += 1;
354 (::std::f64::NAN, s, 0)
355 });
356 value.2 += 1;
357 value.1
358 }
359
360 fn create_row(&mut self, constraint: &Constraint) -> (Box<Row>, Tag) {
376 let expr = constraint.expr();
377 let mut row = Row::new(expr.constant);
378 for term in &expr.terms {
380 if !near_zero(term.coefficient) {
381 let symbol = self.get_var_symbol(term.variable);
382 if let Some(other_row) = self.rows.get(&symbol) {
383 row.insert_row(other_row, term.coefficient);
384 } else {
385 row.insert_symbol(symbol, term.coefficient);
386 }
387 }
388 }
389
390 let mut objective = self.objective.borrow_mut();
391
392 let tag = match constraint.op() {
394 RelationalOperator::GreaterOrEqual |
395 RelationalOperator::LessOrEqual => {
396 let coeff = if constraint.op() == RelationalOperator::LessOrEqual {
397 1.0
398 } else {
399 -1.0
400 };
401 let slack = Symbol(self.id_tick, SymbolType::Slack);
402 self.id_tick += 1;
403 row.insert_symbol(slack, coeff);
404 if constraint.strength() < ::strength::REQUIRED {
405 let error = Symbol(self.id_tick, SymbolType::Error);
406 self.id_tick += 1;
407 row.insert_symbol(error, -coeff);
408 objective.insert_symbol(error, constraint.strength());
409 Tag {
410 marker: slack,
411 other: error
412 }
413 } else {
414 Tag {
415 marker: slack,
416 other: Symbol::invalid()
417 }
418 }
419 }
420 RelationalOperator::Equal => {
421 if constraint.strength() < ::strength::REQUIRED {
422 let errplus = Symbol(self.id_tick, SymbolType::Error);
423 self.id_tick += 1;
424 let errminus = Symbol(self.id_tick, SymbolType::Error);
425 self.id_tick += 1;
426 row.insert_symbol(errplus, -1.0); row.insert_symbol(errminus, 1.0); objective.insert_symbol(errplus, constraint.strength());
429 objective.insert_symbol(errminus, constraint.strength());
430 Tag {
431 marker: errplus,
432 other: errminus
433 }
434 } else {
435 let dummy = Symbol(self.id_tick, SymbolType::Dummy);
436 self.id_tick += 1;
437 row.insert_symbol(dummy, 1.0);
438 Tag {
439 marker: dummy,
440 other: Symbol::invalid()
441 }
442 }
443 }
444 };
445
446 if row.constant < 0.0 {
448 row.reverse_sign();
449 }
450 (Box::new(row), tag)
451 }
452
453 fn choose_subject(row: &Row, tag: &Tag) -> Symbol {
466 for s in row.cells.keys() {
467 if s.type_() == SymbolType::External {
468 return *s
469 }
470 }
471 if tag.marker.type_() == SymbolType::Slack || tag.marker.type_() == SymbolType::Error {
472 if row.coefficient_for(tag.marker) < 0.0 {
473 return tag.marker;
474 }
475 }
476 if tag.other.type_() == SymbolType::Slack || tag.other.type_() == SymbolType::Error {
477 if row.coefficient_for(tag.other) < 0.0 {
478 return tag.other;
479 }
480 }
481 Symbol::invalid()
482 }
483
484 fn add_with_artificial_variable(&mut self, row: &Row) -> Result<bool, InternalSolverError> {
488 let art = Symbol(self.id_tick, SymbolType::Slack);
490 self.id_tick += 1;
491 self.rows.insert(art, Box::new(row.clone()));
492 self.artificial = Some(Rc::new(RefCell::new(row.clone())));
493
494 let artificial = self.artificial.as_ref().unwrap().clone();
497 try!(self.optimise(&artificial));
498 let success = near_zero(artificial.borrow().constant);
499 self.artificial = None;
500
501 if let Some(mut row) = self.rows.remove(&art) {
504 if row.cells.is_empty() {
505 return Ok(success);
506 }
507 let entering = Solver::any_pivotable_symbol(&row); if entering.type_() == SymbolType::Invalid {
509 return Ok(false); }
511 row.solve_for_symbols(art, entering);
512 self.substitute(entering, &row);
513 self.rows.insert(entering, row);
514 }
515
516 for (_, row) in &mut self.rows {
518 row.remove(art);
519 }
520 self.objective.borrow_mut().remove(art);
521 Ok(success)
522 }
523
524 fn substitute(&mut self, symbol: Symbol, row: &Row) {
529 for (&other_symbol, other_row) in &mut self.rows {
530 let constant_changed = other_row.substitute(symbol, row);
531 if other_symbol.type_() == SymbolType::External && constant_changed {
532 let v = self.var_for_symbol[&other_symbol];
533 if self.should_clear_changes {
535 self.changed.clear();
536 self.should_clear_changes = false;
537 }
538 self.changed.insert(v);
539 }
540 if other_symbol.type_() != SymbolType::External && other_row.constant < 0.0 {
541 self.infeasible_rows.push(other_symbol);
542 }
543 }
544 self.objective.borrow_mut().substitute(symbol, row);
545 if let Some(artificial) = self.artificial.as_ref() {
546 artificial.borrow_mut().substitute(symbol, row);
547 }
548 }
549
550 fn optimise(&mut self, objective: &RefCell<Row>) -> Result<(), InternalSolverError> {
555 loop {
556 let entering = Solver::get_entering_symbol(&objective.borrow());
557 if entering.type_() == SymbolType::Invalid {
558 return Ok(());
559 }
560 let (leaving, mut row) = try!(self.get_leaving_row(entering)
561 .ok_or(InternalSolverError("The objective is unbounded")));
562 row.solve_for_symbols(leaving, entering);
564 self.substitute(entering, &row);
565 if entering.type_() == SymbolType::External && row.constant != 0.0 {
566 let v = self.var_for_symbol[&entering];
567 self.var_changed(v);
568 }
569 self.rows.insert(entering, row);
570 }
571 }
572
573 fn dual_optimise(&mut self) -> Result<(), InternalSolverError> {
580 while !self.infeasible_rows.is_empty() {
581 let leaving = self.infeasible_rows.pop().unwrap();
582
583 let row = if let Entry::Occupied(entry) = self.rows.entry(leaving) {
584 if entry.get().constant < 0.0 {
585 Some(entry.remove())
586 } else {
587 None
588 }
589 } else {
590 None
591 };
592 if let Some(mut row) = row {
593 let entering = self.get_dual_entering_symbol(&row);
594 if entering.type_() == SymbolType::Invalid {
595 return Err(InternalSolverError("Dual optimise failed."));
596 }
597 row.solve_for_symbols(leaving, entering);
599 self.substitute(entering, &row);
600 if entering.type_() == SymbolType::External && row.constant != 0.0 {
601 let v = self.var_for_symbol[&entering];
602 self.var_changed(v);
603 }
604 self.rows.insert(entering, row);
605 }
606 }
607 Ok(())
608 }
609
610 fn get_entering_symbol(objective: &Row) -> Symbol {
618 for (symbol, value) in &objective.cells {
619 if symbol.type_() != SymbolType::Dummy && *value < 0.0 {
620 return *symbol;
621 }
622 }
623 Symbol::invalid()
624 }
625
626 fn get_dual_entering_symbol(&self, row: &Row) -> Symbol {
635 let mut entering = Symbol::invalid();
636 let mut ratio = ::std::f64::INFINITY;
637 let objective = self.objective.borrow();
638 for (symbol, value) in &row.cells {
639 if *value > 0.0 && symbol.type_() != SymbolType::Dummy {
640 let coeff = objective.coefficient_for(*symbol);
641 let r = coeff / *value;
642 if r < ratio {
643 ratio = r;
644 entering = *symbol;
645 }
646 }
647 }
648 entering
649 }
650
651 fn any_pivotable_symbol(row: &Row) -> Symbol {
656 for symbol in row.cells.keys() {
657 if symbol.type_() == SymbolType::Slack || symbol.type_() == SymbolType::Error {
658 return *symbol;
659 }
660 }
661 Symbol::invalid()
662 }
663
664 fn get_leaving_row(&mut self, entering: Symbol) -> Option<(Symbol, Box<Row>)> {
672 let mut ratio = ::std::f64::INFINITY;
673 let mut found = None;
674 for (symbol, row) in &self.rows {
675 if symbol.type_() != SymbolType::External {
676 let temp = row.coefficient_for(entering);
677 if temp < 0.0 {
678 let temp_ratio = -row.constant / temp;
679 if temp_ratio < ratio {
680 ratio = temp_ratio;
681 found = Some(*symbol);
682 }
683 }
684 }
685 }
686 found.map(|s| (s, self.rows.remove(&s).unwrap()))
687 }
688
689 fn get_marker_leaving_row(&mut self, marker: Symbol) -> Option<(Symbol, Box<Row>)> {
707 let mut r1 = ::std::f64::INFINITY;
708 let mut r2 = r1;
709 let mut first = None;
710 let mut second = None;
711 let mut third = None;
712 for (symbol, row) in &self.rows {
713 let c = row.coefficient_for(marker);
714 if c == 0.0 {
715 continue;
716 }
717 if symbol.type_() == SymbolType::External {
718 third = Some(*symbol);
719 } else if c < 0.0 {
720 let r = -row.constant / c;
721 if r < r1 {
722 r1 = r;
723 first = Some(*symbol);
724 }
725 } else {
726 let r = row.constant / c;
727 if r < r2 {
728 r2 = r;
729 second = Some(*symbol);
730 }
731 }
732 }
733 first
734 .or(second)
735 .or(third)
736 .and_then(|s| {
737 if s.type_() == SymbolType::External && self.rows[&s].constant != 0.0 {
738 let v = self.var_for_symbol[&s];
739 self.var_changed(v);
740 }
741 self.rows
742 .remove(&s)
743 .map(|r| (s, r))
744 })
745 }
746
747 fn remove_constraint_effects(&mut self, cn: &Constraint, tag: &Tag) {
749 if tag.marker.type_() == SymbolType::Error {
750 self.remove_marker_effects(tag.marker, cn.strength());
751 } else if tag.other.type_() == SymbolType::Error {
752 self.remove_marker_effects(tag.other, cn.strength());
753 }
754 }
755
756 fn remove_marker_effects(&mut self, marker: Symbol, strength: f64) {
758 if let Some(row) = self.rows.get(&marker) {
759 self.objective.borrow_mut().insert_row(row, -strength);
760 } else {
761 self.objective.borrow_mut().insert_symbol(marker, -strength);
762 }
763 }
764
765 fn all_dummies(row: &Row) -> bool {
767 for symbol in row.cells.keys() {
768 if symbol.type_() != SymbolType::Dummy {
769 return false;
770 }
771 }
772 true
773 }
774
775 pub fn get_value(&self, v: Variable) -> f64 {
780 self.var_data.get(&v).and_then(|s| {
781 self.rows.get(&s.1).map(|r| r.constant)
782 }).unwrap_or(0.0)
783 }
784}