1use crate::function::gamma;
5use crate::prec;
6use std::f64;
7
/// Domain errors reported by the checked beta-function routines in this module.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BetaFuncError {
    /// The `a` argument satisfied `a <= 0.0`.
    ANotGreaterThanZero,

    /// The `b` argument satisfied `b <= 0.0`.
    BNotGreaterThanZero,

    /// The `x` argument lay outside the closed interval `[0, 1]`.
    XOutOfRange,
}

impl std::fmt::Display for BetaFuncError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Plain message text; width/fill flags are intentionally not applied.
        let message = match *self {
            BetaFuncError::ANotGreaterThanZero => "a is zero or less than zero",
            BetaFuncError::BNotGreaterThanZero => "b is zero or less than zero",
            BetaFuncError::XOutOfRange => "x is not in [0, 1]",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BetaFuncError {}
34
35pub fn ln_beta(a: f64, b: f64) -> f64 {
45 checked_ln_beta(a, b).unwrap()
46}
47
48pub fn checked_ln_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> {
58 if a <= 0.0 {
59 Err(BetaFuncError::ANotGreaterThanZero)
60 } else if b <= 0.0 {
61 Err(BetaFuncError::BNotGreaterThanZero)
62 } else {
63 Ok(gamma::ln_gamma(a) + gamma::ln_gamma(b) - gamma::ln_gamma(a + b))
64 }
65}
66
67pub fn beta(a: f64, b: f64) -> f64 {
76 checked_beta(a, b).unwrap()
77}
78
79pub fn checked_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> {
88 checked_ln_beta(a, b).map(|x| x.exp())
89}
90
91pub fn beta_inc(a: f64, b: f64, x: f64) -> f64 {
100 checked_beta_inc(a, b, x).unwrap()
101}
102
103pub fn checked_beta_inc(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> {
112 checked_beta_reg(a, b, x).and_then(|x| checked_beta(a, b).map(|y| x * y))
113}
114
115pub fn beta_reg(a: f64, b: f64, x: f64) -> f64 {
125 checked_beta_reg(a, b, x).unwrap()
126}
127
128pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> {
138 if a <= 0.0 {
139 return Err(BetaFuncError::ANotGreaterThanZero);
140 }
141
142 if b <= 0.0 {
143 return Err(BetaFuncError::BNotGreaterThanZero);
144 }
145
146 if !(0.0..=1.0).contains(&x) {
147 return Err(BetaFuncError::XOutOfRange);
148 }
149
150 let bt = if x == 0.0 || ulps_eq!(x, 1.0) {
151 0.0
152 } else {
153 (gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b)
154 + a * x.ln()
155 + b * (1.0 - x).ln())
156 .exp()
157 };
158 let symm_transform = x >= (a + 1.0) / (a + b + 2.0);
159 let eps = prec::F64_PREC;
160 let fpmin = f64::MIN_POSITIVE / eps;
161
162 let mut a = a;
163 let mut b = b;
164 let mut x = x;
165 if symm_transform {
166 let swap = a;
167 x = 1.0 - x;
168 a = b;
169 b = swap;
170 }
171
172 let qab = a + b;
173 let qap = a + 1.0;
174 let qam = a - 1.0;
175 let mut c = 1.0;
176 let mut d = 1.0 - qab * x / qap;
177
178 if d.abs() < fpmin {
179 d = fpmin;
180 }
181 d = 1.0 / d;
182 let mut h = d;
183
184 for m in 1..141 {
185 let m = f64::from(m);
186 let m2 = m * 2.0;
187 let mut aa = m * (b - m) * x / ((qam + m2) * (a + m2));
188 d = 1.0 + aa * d;
189
190 if d.abs() < fpmin {
191 d = fpmin;
192 }
193
194 c = 1.0 + aa / c;
195 if c.abs() < fpmin {
196 c = fpmin;
197 }
198
199 d = 1.0 / d;
200 h = h * d * c;
201 aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
202 d = 1.0 + aa * d;
203
204 if d.abs() < fpmin {
205 d = fpmin;
206 }
207
208 c = 1.0 + aa / c;
209
210 if c.abs() < fpmin {
211 c = fpmin;
212 }
213
214 d = 1.0 / d;
215 let del = d * c;
216 h *= del;
217
218 if (del - 1.0).abs() <= eps {
219 return if symm_transform {
220 Ok(1.0 - bt * h / a)
221 } else {
222 Ok(bt * h / a)
223 };
224 }
225 }
226
227 if symm_transform {
228 Ok(1.0 - bt * h / a)
229 } else {
230 Ok(bt * h / a)
231 }
232}
233
234pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 {
264 let ln_beta = ln_beta(a, b);
299
300 const SAE: i32 = -30;
303 const FPU: f64 = 1e-30; debug_assert!((0.0..=1.0).contains(&x) && a > 0.0 && b > 0.0);
306
307 if x == 0.0 {
308 return 0.0;
309 }
310 if x == 1.0 {
311 return 1.0;
312 }
313
314 let mut p;
315 let mut q;
316
317 let flip = 0.5 < x;
318 if flip {
319 p = a;
320 a = b;
321 b = p;
322 x = 1.0 - x;
323 }
324
325 p = (-(x * x).ln()).sqrt();
326 q = p - (2.30753 + 0.27061 * p) / (1.0 + (0.99229 + 0.04481 * p) * p);
327
328 if 1.0 < a && 1.0 < b {
329 let r = (q * q - 3.0) / 6.0;
337 let s = 1.0 / (2.0 * a - 1.0);
338 let t = 1.0 / (2.0 * b - 1.0);
339 let h = 2.0 / (s + t);
340 let w = q * (h + r).sqrt() / h - (t - s) * (r + 5.0 / 6.0 - 2.0 / (3.0 * h));
341 p = a / (a + b * (2.0 * w).exp());
342 } else {
343 let mut t = 1.0 / (9.0 * b);
344 t = 2.0 * b * (1.0 - t + q * t.sqrt()).powf(3.0);
345 if t <= 0.0 {
346 p = 1.0 - ((((1.0 - x) * b).ln() + ln_beta) / b).exp();
347 } else {
348 t = 2.0 * (2.0 * a + b - 1.0) / t;
349 if t <= 1.0 {
350 p = (((x * a).ln() + ln_beta) / a).exp();
351 } else {
352 p = 1.0 - 2.0 / (t + 1.0);
353 }
354 }
355 }
356
357 p = p.clamp(0.0001, 0.9999);
358
359 let e = (-5.0 / a / a - 1.0 / x.powf(0.2) - 13.0) as i32;
362 let acu = if e > SAE { f64::powi(10.0, e) } else { FPU };
363
364 let mut pnext;
365 let mut qprev = 0.0;
366 let mut sq = 1.0;
367 let mut prev = 1.0;
368
369 'outer: loop {
370 q = beta_reg(a, b, p);
373 q = (q - x) * (ln_beta + (1.0 - a) * p.ln() + (1.0 - b) * (1.0 - p).ln()).exp();
374
375 if q * qprev <= 0.0 {
378 prev = if sq > FPU { sq } else { FPU };
379 }
380
381 let mut g = 1.0;
384 loop {
385 loop {
386 let adj = g * q;
387 sq = adj * adj;
388
389 if sq < prev {
390 pnext = p - adj;
391 if (0.0..=1.0).contains(&pnext) {
392 break;
393 }
394 }
395 g /= 3.0;
396 }
397
398 if prev <= acu || q * q <= acu {
399 p = pnext;
400 break 'outer;
401 }
402
403 if pnext != 0.0 && pnext != 1.0 {
404 break;
405 }
406
407 g /= 3.0;
408 }
409
410 if pnext == p {
411 break;
412 }
413
414 p = pnext;
415 qprev = q;
416 }
417
418 if flip {
419 1.0 - p
420 } else {
421 p
422 }
423}
424
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    // Unit tests for the beta-function routines. Interior (non-endpoint)
    // floating-point results are compared with a tolerance, never bit-exactly.
    use super::*;

    #[test]
    fn test_ln_beta() {
        assert_almost_eq!(super::ln_beta(0.5, 0.5), 1.144729885849400174144, 1e-15);
        assert_almost_eq!(super::ln_beta(1.0, 0.5), 0.6931471805599453094172, 1e-14);
        assert_almost_eq!(super::ln_beta(2.5, 0.5), 0.163900632837673937284, 1e-15);
        assert_almost_eq!(super::ln_beta(0.5, 1.0), 0.6931471805599453094172, 1e-14);
        assert_almost_eq!(super::ln_beta(1.0, 1.0), 0.0, 1e-15);
        assert_almost_eq!(super::ln_beta(2.5, 1.0), -0.9162907318741550651835, 1e-14);
        assert_almost_eq!(super::ln_beta(0.5, 2.5), 0.163900632837673937284, 1e-15);
        assert_almost_eq!(super::ln_beta(1.0, 2.5), -0.9162907318741550651835, 1e-14);
        assert_almost_eq!(super::ln_beta(2.5, 2.5), -2.608688089402107300388, 1e-14);
    }

    #[test]
    #[should_panic]
    fn test_ln_beta_a_lte_0() {
        super::ln_beta(0.0, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_ln_beta_b_lte_0() {
        super::ln_beta(0.5, 0.0);
    }

    #[test]
    fn test_checked_ln_beta_a_lte_0() {
        assert!(super::checked_ln_beta(0.0, 0.5).is_err());
    }

    #[test]
    fn test_checked_ln_beta_b_lte_0() {
        assert!(super::checked_ln_beta(0.5, 0.0).is_err());
    }

    #[test]
    #[should_panic]
    fn test_beta_a_lte_0() {
        super::beta(0.0, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_beta_b_lte_0() {
        super::beta(0.5, 0.0);
    }

    #[test]
    fn test_checked_beta_a_lte_0() {
        assert!(super::checked_beta(0.0, 0.5).is_err());
    }

    #[test]
    fn test_checked_beta_b_lte_0() {
        assert!(super::checked_beta(0.5, 0.0).is_err());
    }

    #[test]
    fn test_beta() {
        assert_almost_eq!(super::beta(0.5, 0.5), 3.141592653589793238463, 1e-15);
        assert_almost_eq!(super::beta(1.0, 0.5), 2.0, 1e-14);
        assert_almost_eq!(super::beta(2.5, 0.5), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta(0.5, 1.0), 2.0, 1e-14);
        assert_almost_eq!(super::beta(1.0, 1.0), 1.0, 1e-15);
        assert_almost_eq!(super::beta(2.5, 1.0), 0.4, 1e-14);
        assert_almost_eq!(super::beta(0.5, 2.5), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta(1.0, 2.5), 0.4, 1e-14);
        assert_almost_eq!(super::beta(2.5, 2.5), 0.073631077818510779026, 1e-15);
    }

    #[test]
    fn test_beta_inc() {
        assert_almost_eq!(super::beta_inc(0.5, 0.5, 0.5), 1.570796326794896619231, 1e-14);
        assert_almost_eq!(super::beta_inc(0.5, 0.5, 1.0), 3.141592653589793238463, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 0.5, 0.5), 0.5857864376269049511983, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 0.5, 1.0), 2.0, 1e-14);
        assert_almost_eq!(super::beta_inc(2.5, 0.5, 0.5), 0.0890486225480862322117, 1e-16);
        assert_almost_eq!(super::beta_inc(2.5, 0.5, 1.0), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta_inc(0.5, 1.0, 0.5), 1.414213562373095048802, 1e-14);
        assert_almost_eq!(super::beta_inc(0.5, 1.0, 1.0), 2.0, 1e-14);
        assert_almost_eq!(super::beta_inc(1.0, 1.0, 0.5), 0.5, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 1.0, 1.0), 1.0, 1e-15);
        // Interior value: compare with a tolerance rather than bit-exactly.
        assert_almost_eq!(super::beta_inc(2.5, 1.0, 0.5), 0.0707106781186547524401, 1e-16);
        assert_almost_eq!(super::beta_inc(2.5, 1.0, 1.0), 0.4, 1e-14);
        assert_almost_eq!(super::beta_inc(0.5, 2.5, 0.5), 1.08904862254808623221, 1e-15);
        assert_almost_eq!(super::beta_inc(0.5, 2.5, 1.0), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 2.5, 0.5), 0.32928932188134524756, 1e-14);
        assert_almost_eq!(super::beta_inc(1.0, 2.5, 1.0), 0.4, 1e-14);
        assert_almost_eq!(super::beta_inc(2.5, 2.5, 0.5), 0.03681553890925538951323, 1e-15);
        assert_almost_eq!(super::beta_inc(2.5, 2.5, 1.0), 0.073631077818510779026, 1e-15);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_a_lte_0() {
        super::beta_inc(0.0, 1.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_b_lte_0() {
        super::beta_inc(1.0, 0.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_x_lt_0() {
        super::beta_inc(1.0, 1.0, -1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_x_gt_1() {
        super::beta_inc(1.0, 1.0, 2.0);
    }

    #[test]
    fn test_checked_beta_inc_a_lte_0() {
        assert!(super::checked_beta_inc(0.0, 1.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_inc_b_lte_0() {
        assert!(super::checked_beta_inc(1.0, 0.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_inc_x_lt_0() {
        assert!(super::checked_beta_inc(1.0, 1.0, -1.0).is_err());
    }

    #[test]
    fn test_checked_beta_inc_x_gt_1() {
        assert!(super::checked_beta_inc(1.0, 1.0, 2.0).is_err());
    }

    #[test]
    fn test_beta_reg() {
        assert_almost_eq!(super::beta_reg(0.5, 0.5, 0.5), 0.5, 1e-15);
        assert_eq!(super::beta_reg(0.5, 0.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(1.0, 0.5, 0.5), 0.292893218813452475599, 1e-15);
        assert_eq!(super::beta_reg(1.0, 0.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(2.5, 0.5, 0.5), 0.07558681842161243795, 1e-16);
        assert_eq!(super::beta_reg(2.5, 0.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(0.5, 1.0, 0.5), 0.7071067811865475244, 1e-15);
        assert_eq!(super::beta_reg(0.5, 1.0, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(1.0, 1.0, 0.5), 0.5, 1e-15);
        assert_eq!(super::beta_reg(1.0, 1.0, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(2.5, 1.0, 0.5), 0.1767766952966368811, 1e-15);
        assert_eq!(super::beta_reg(2.5, 1.0, 1.0), 1.0);
        // Interior value: compare with a tolerance rather than bit-exactly.
        assert_almost_eq!(super::beta_reg(0.5, 2.5, 0.5), 0.92441318157838756205, 1e-15);
        assert_eq!(super::beta_reg(0.5, 2.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(1.0, 2.5, 0.5), 0.8232233047033631189, 1e-15);
        assert_eq!(super::beta_reg(1.0, 2.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(2.5, 2.5, 0.5), 0.5, 1e-15);
        assert_eq!(super::beta_reg(2.5, 2.5, 1.0), 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_a_lte_0() {
        super::beta_reg(0.0, 1.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_b_lte_0() {
        super::beta_reg(1.0, 0.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_x_lt_0() {
        super::beta_reg(1.0, 1.0, -1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_x_gt_1() {
        super::beta_reg(1.0, 1.0, 2.0);
    }

    #[test]
    fn test_checked_beta_reg_a_lte_0() {
        assert!(super::checked_beta_reg(0.0, 1.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_reg_b_lte_0() {
        assert!(super::checked_beta_reg(1.0, 0.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_reg_x_lt_0() {
        assert!(super::checked_beta_reg(1.0, 1.0, -1.0).is_err());
    }

    #[test]
    fn test_checked_beta_reg_x_gt_1() {
        assert!(super::checked_beta_reg(1.0, 1.0, 2.0).is_err());
    }

    #[test]
    fn test_inv_beta_reg_endpoints() {
        assert_eq!(super::inv_beta_reg(2.5, 0.5, 0.0), 0.0);
        assert_eq!(super::inv_beta_reg(2.5, 0.5, 1.0), 1.0);
    }

    #[test]
    fn test_inv_beta_reg_round_trip() {
        // I_x(1, 1) = x, so the inverse is the identity.
        assert_almost_eq!(super::inv_beta_reg(1.0, 1.0, 0.3), 0.3, 1e-12);
        // I_{1/2}(a, a) = 1/2 by symmetry.
        assert_almost_eq!(super::inv_beta_reg(2.5, 2.5, 0.5), 0.5, 1e-10);
    }

    #[test]
    fn test_error_is_sync_send() {
        fn assert_sync_send<T: Sync + Send>() {}
        assert_sync_send::<BetaFuncError>();
    }
}