1use crate::error::StatsError;
5use crate::function::gamma;
6use crate::is_zero;
7use crate::prec;
8use crate::Result;
9use std::f64;
10
11pub fn ln_beta(a: f64, b: f64) -> f64 {
21 checked_ln_beta(a, b).unwrap()
22}
23
24pub fn checked_ln_beta(a: f64, b: f64) -> Result<f64> {
34 if a <= 0.0 {
35 Err(StatsError::ArgMustBePositive("a"))
36 } else if b <= 0.0 {
37 Err(StatsError::ArgMustBePositive("b"))
38 } else {
39 Ok(gamma::ln_gamma(a) + gamma::ln_gamma(b) - gamma::ln_gamma(a + b))
40 }
41}
42
43pub fn beta(a: f64, b: f64) -> f64 {
52 checked_beta(a, b).unwrap()
53}
54
55pub fn checked_beta(a: f64, b: f64) -> Result<f64> {
64 checked_ln_beta(a, b).map(|x| x.exp())
65}
66
67pub fn beta_inc(a: f64, b: f64, x: f64) -> f64 {
76 checked_beta_inc(a, b, x).unwrap()
77}
78
79pub fn checked_beta_inc(a: f64, b: f64, x: f64) -> Result<f64> {
88 checked_beta_reg(a, b, x).and_then(|x| checked_beta(a, b).map(|y| x * y))
89}
90
91pub fn beta_reg(a: f64, b: f64, x: f64) -> f64 {
101 checked_beta_reg(a, b, x).unwrap()
102}
103
104pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result<f64> {
114 if a <= 0.0 {
115 Err(StatsError::ArgMustBePositive("a"))
116 } else if b <= 0.0 {
117 Err(StatsError::ArgMustBePositive("b"))
118 } else if !(0.0..=1.0).contains(&x) {
119 Err(StatsError::ArgIntervalIncl("x", 0.0, 1.0))
120 } else {
121 let bt = if is_zero(x) || ulps_eq!(x, 1.0) {
122 0.0
123 } else {
124 (gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b)
125 + a * x.ln()
126 + b * (1.0 - x).ln())
127 .exp()
128 };
129 let symm_transform = x >= (a + 1.0) / (a + b + 2.0);
130 let eps = prec::F64_PREC;
131 let fpmin = f64::MIN_POSITIVE / eps;
132
133 let mut a = a;
134 let mut b = b;
135 let mut x = x;
136 if symm_transform {
137 let swap = a;
138 x = 1.0 - x;
139 a = b;
140 b = swap;
141 }
142
143 let qab = a + b;
144 let qap = a + 1.0;
145 let qam = a - 1.0;
146 let mut c = 1.0;
147 let mut d = 1.0 - qab * x / qap;
148
149 if d.abs() < fpmin {
150 d = fpmin;
151 }
152 d = 1.0 / d;
153 let mut h = d;
154
155 for m in 1..141 {
156 let m = f64::from(m);
157 let m2 = m * 2.0;
158 let mut aa = m * (b - m) * x / ((qam + m2) * (a + m2));
159 d = 1.0 + aa * d;
160
161 if d.abs() < fpmin {
162 d = fpmin;
163 }
164
165 c = 1.0 + aa / c;
166 if c.abs() < fpmin {
167 c = fpmin;
168 }
169
170 d = 1.0 / d;
171 h = h * d * c;
172 aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
173 d = 1.0 + aa * d;
174
175 if d.abs() < fpmin {
176 d = fpmin;
177 }
178
179 c = 1.0 + aa / c;
180
181 if c.abs() < fpmin {
182 c = fpmin;
183 }
184
185 d = 1.0 / d;
186 let del = d * c;
187 h *= del;
188
189 if (del - 1.0).abs() <= eps {
190 return if symm_transform {
191 Ok(1.0 - bt * h / a)
192 } else {
193 Ok(bt * h / a)
194 };
195 }
196 }
197
198 if symm_transform {
199 Ok(1.0 - bt * h / a)
200 } else {
201 Ok(bt * h / a)
202 }
203 }
204}
205
206pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 {
237 let ln_beta = ln_beta(a, b);
272
273 const SAE: i32 = -30;
276 const FPU: f64 = 1e-30; debug_assert!((0.0..=1.0).contains(&x) && a > 0.0 && b > 0.0);
279
280 if x == 0.0 {
281 return 0.0;
282 }
283 if x == 1.0 {
284 return 1.0;
285 }
286
287 let mut p;
288 let mut q;
289
290 let flip = 0.5 < x;
291 if flip {
292 p = a;
293 a = b;
294 b = p;
295 x = 1.0 - x;
296 }
297
298 p = (-(x * x).ln()).sqrt();
299 q = p - (2.30753 + 0.27061 * p) / (1.0 + (0.99229 + 0.04481 * p) * p);
300
301 if 1.0 < a && 1.0 < b {
302 let r = (q * q - 3.0) / 6.0;
310 let s = 1.0 / (2.0 * a - 1.0);
311 let t = 1.0 / (2.0 * b - 1.0);
312 let h = 2.0 / (s + t);
313 let w = q * (h + r).sqrt() / h - (t - s) * (r + 5.0 / 6.0 - 2.0 / (3.0 * h));
314 p = a / (a + b * (2.0 * w).exp());
315 } else {
316 let mut t = 1.0 / (9.0 * b);
317 t = 2.0 * b * (1.0 - t + q * t.sqrt()).powf(3.0);
318 if t <= 0.0 {
319 p = 1.0 - ((((1.0 - x) * b).ln() + ln_beta) / b).exp();
320 } else {
321 t = 2.0 * (2.0 * a + b - 1.0) / t;
322 if t <= 1.0 {
323 p = (((x * a).ln() + ln_beta) / a).exp();
324 } else {
325 p = 1.0 - 2.0 / (t + 1.0);
326 }
327 }
328 }
329
330 if p < 0.0001 {
331 p = 0.0001;
332 } else if 0.9999 < p {
333 p = 0.9999;
334 }
335
336 let e = (-5.0 / a / a - 1.0 / x.powf(0.2) - 13.0) as i32;
339 let acu = if e > SAE { f64::powi(10.0, e) } else { FPU };
340
341 let mut pnext;
342 let mut qprev = 0.0;
343 let mut sq = 1.0;
344 let mut prev = 1.0;
345
346 'outer: loop {
347 q = beta_reg(a, b, p);
350 q = (q - x) * (ln_beta + (1.0 - a) * p.ln() + (1.0 - b) * (1.0 - p).ln()).exp();
351
352 if q * qprev <= 0.0 {
355 prev = if sq > FPU { sq } else { FPU };
356 }
357
358 let mut g = 1.0;
361 loop {
362 loop {
363 let adj = g * q;
364 sq = adj * adj;
365
366 if sq < prev {
367 pnext = p - adj;
368 if 0.0 <= pnext && pnext <= 1.0 {
369 break;
370 }
371 }
372 g /= 3.0;
373 }
374
375 if prev <= acu || q * q <= acu {
376 p = pnext;
377 break 'outer;
378 }
379
380 if pnext != 0.0 && pnext != 1.0 {
381 break;
382 }
383
384 g /= 3.0;
385 }
386
387 if pnext == p {
388 break;
389 }
390
391 p = pnext;
392 qprev = q;
393 }
394
395 if flip {
396 1.0 - p
397 } else {
398 p
399 }
400}
401
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    // Results of ln/exp and of the continued fraction can differ in the last bits
    // across platforms' libm, so non-trivial values are compared with a tolerance.
    // Only results that are exact by construction (the x = 1 endpoint) use assert_eq!.

    #[test]
    fn test_ln_beta() {
        assert_almost_eq!(super::ln_beta(0.5, 0.5), 1.144729885849400174144, 1e-15);
        assert_almost_eq!(super::ln_beta(1.0, 0.5), 0.6931471805599453094172, 1e-14);
        assert_almost_eq!(super::ln_beta(2.5, 0.5), 0.163900632837673937284, 1e-15);
        assert_almost_eq!(super::ln_beta(0.5, 1.0), 0.6931471805599453094172, 1e-14);
        assert_almost_eq!(super::ln_beta(1.0, 1.0), 0.0, 1e-15);
        assert_almost_eq!(super::ln_beta(2.5, 1.0), -0.9162907318741550651835, 1e-14);
        assert_almost_eq!(super::ln_beta(0.5, 2.5), 0.163900632837673937284, 1e-15);
        assert_almost_eq!(super::ln_beta(1.0, 2.5), -0.9162907318741550651835, 1e-14);
        assert_almost_eq!(super::ln_beta(2.5, 2.5), -2.608688089402107300388, 1e-14);
    }

    #[test]
    #[should_panic]
    fn test_ln_beta_a_lte_0() {
        super::ln_beta(0.0, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_ln_beta_b_lte_0() {
        super::ln_beta(0.5, 0.0);
    }

    #[test]
    fn test_checked_ln_beta_a_lte_0() {
        assert!(super::checked_ln_beta(0.0, 0.5).is_err());
    }

    #[test]
    fn test_checked_ln_beta_b_lte_0() {
        assert!(super::checked_ln_beta(0.5, 0.0).is_err());
    }

    #[test]
    #[should_panic]
    fn test_beta_a_lte_0() {
        super::beta(0.0, 0.5);
    }

    #[test]
    #[should_panic]
    fn test_beta_b_lte_0() {
        super::beta(0.5, 0.0);
    }

    #[test]
    fn test_checked_beta_a_lte_0() {
        assert!(super::checked_beta(0.0, 0.5).is_err());
    }

    #[test]
    fn test_checked_beta_b_lte_0() {
        assert!(super::checked_beta(0.5, 0.0).is_err());
    }

    #[test]
    fn test_beta() {
        assert_almost_eq!(super::beta(0.5, 0.5), 3.141592653589793238463, 1e-15);
        assert_almost_eq!(super::beta(1.0, 0.5), 2.0, 1e-14);
        assert_almost_eq!(super::beta(2.5, 0.5), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta(0.5, 1.0), 2.0, 1e-14);
        assert_almost_eq!(super::beta(1.0, 1.0), 1.0, 1e-15);
        assert_almost_eq!(super::beta(2.5, 1.0), 0.4, 1e-14);
        assert_almost_eq!(super::beta(0.5, 2.5), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta(1.0, 2.5), 0.4, 1e-14);
        assert_almost_eq!(super::beta(2.5, 2.5), 0.073631077818510779026, 1e-15);
    }

    #[test]
    fn test_beta_inc() {
        assert_almost_eq!(super::beta_inc(0.5, 0.5, 0.5), 1.570796326794896619231, 1e-14);
        assert_almost_eq!(super::beta_inc(0.5, 0.5, 1.0), 3.141592653589793238463, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 0.5, 0.5), 0.5857864376269049511983, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 0.5, 1.0), 2.0, 1e-14);
        assert_almost_eq!(super::beta_inc(2.5, 0.5, 0.5), 0.0890486225480862322117, 1e-16);
        assert_almost_eq!(super::beta_inc(2.5, 0.5, 1.0), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta_inc(0.5, 1.0, 0.5), 1.414213562373095048802, 1e-14);
        assert_almost_eq!(super::beta_inc(0.5, 1.0, 1.0), 2.0, 1e-14);
        assert_almost_eq!(super::beta_inc(1.0, 1.0, 0.5), 0.5, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 1.0, 1.0), 1.0, 1e-15);
        assert_almost_eq!(super::beta_inc(2.5, 1.0, 0.5), 0.0707106781186547524401, 1e-15);
        assert_almost_eq!(super::beta_inc(2.5, 1.0, 1.0), 0.4, 1e-14);
        assert_almost_eq!(super::beta_inc(0.5, 2.5, 0.5), 1.08904862254808623221, 1e-15);
        assert_almost_eq!(super::beta_inc(0.5, 2.5, 1.0), 1.17809724509617246442, 1e-15);
        assert_almost_eq!(super::beta_inc(1.0, 2.5, 0.5), 0.32928932188134524756, 1e-14);
        assert_almost_eq!(super::beta_inc(1.0, 2.5, 1.0), 0.4, 1e-14);
        assert_almost_eq!(super::beta_inc(2.5, 2.5, 0.5), 0.03681553890925538951323, 1e-15);
        assert_almost_eq!(super::beta_inc(2.5, 2.5, 1.0), 0.073631077818510779026, 1e-15);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_a_lte_0() {
        super::beta_inc(0.0, 1.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_b_lte_0() {
        super::beta_inc(1.0, 0.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_x_lt_0() {
        super::beta_inc(1.0, 1.0, -1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_inc_x_gt_1() {
        super::beta_inc(1.0, 1.0, 2.0);
    }

    #[test]
    fn test_checked_beta_inc_a_lte_0() {
        assert!(super::checked_beta_inc(0.0, 1.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_inc_b_lte_0() {
        assert!(super::checked_beta_inc(1.0, 0.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_inc_x_lt_0() {
        assert!(super::checked_beta_inc(1.0, 1.0, -1.0).is_err());
    }

    #[test]
    fn test_checked_beta_inc_x_gt_1() {
        assert!(super::checked_beta_inc(1.0, 1.0, 2.0).is_err());
    }

    #[test]
    fn test_beta_reg() {
        assert_almost_eq!(super::beta_reg(0.5, 0.5, 0.5), 0.5, 1e-15);
        assert_eq!(super::beta_reg(0.5, 0.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(1.0, 0.5, 0.5), 0.292893218813452475599, 1e-15);
        assert_eq!(super::beta_reg(1.0, 0.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(2.5, 0.5, 0.5), 0.07558681842161243795, 1e-16);
        assert_eq!(super::beta_reg(2.5, 0.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(0.5, 1.0, 0.5), 0.7071067811865475244, 1e-15);
        assert_eq!(super::beta_reg(0.5, 1.0, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(1.0, 1.0, 0.5), 0.5, 1e-15);
        assert_eq!(super::beta_reg(1.0, 1.0, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(2.5, 1.0, 0.5), 0.1767766952966368811, 1e-15);
        assert_eq!(super::beta_reg(2.5, 1.0, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(0.5, 2.5, 0.5), 0.92441318157838756205, 1e-15);
        assert_eq!(super::beta_reg(0.5, 2.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(1.0, 2.5, 0.5), 0.8232233047033631189, 1e-15);
        assert_eq!(super::beta_reg(1.0, 2.5, 1.0), 1.0);
        assert_almost_eq!(super::beta_reg(2.5, 2.5, 0.5), 0.5, 1e-15);
        assert_eq!(super::beta_reg(2.5, 2.5, 1.0), 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_a_lte_0() {
        super::beta_reg(0.0, 1.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_b_lte_0() {
        super::beta_reg(1.0, 0.0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_x_lt_0() {
        super::beta_reg(1.0, 1.0, -1.0);
    }

    #[test]
    #[should_panic]
    fn test_beta_reg_x_gt_1() {
        super::beta_reg(1.0, 1.0, 2.0);
    }

    #[test]
    fn test_checked_beta_reg_a_lte_0() {
        assert!(super::checked_beta_reg(0.0, 1.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_reg_b_lte_0() {
        assert!(super::checked_beta_reg(1.0, 0.0, 1.0).is_err());
    }

    #[test]
    fn test_checked_beta_reg_x_lt_0() {
        assert!(super::checked_beta_reg(1.0, 1.0, -1.0).is_err());
    }

    #[test]
    fn test_checked_beta_reg_x_gt_1() {
        assert!(super::checked_beta_reg(1.0, 1.0, 2.0).is_err());
    }

    #[test]
    fn test_inv_beta_reg_endpoints() {
        assert_eq!(super::inv_beta_reg(2.0, 3.0, 0.0), 0.0);
        assert_eq!(super::inv_beta_reg(2.0, 3.0, 1.0), 1.0);
    }

    #[test]
    fn test_inv_beta_reg_closed_form() {
        // I_x(2, 1) = x^2, so the inverse at 0.25 is 0.5.
        assert_almost_eq!(super::inv_beta_reg(2.0, 1.0, 0.25), 0.5, 1e-10);
        // I_x(1, 2) = 1 - (1 - x)^2, so the inverse at 0.75 is 0.5 (exercises the flip).
        assert_almost_eq!(super::inv_beta_reg(1.0, 2.0, 0.75), 0.5, 1e-10);
    }

    #[test]
    fn test_inv_beta_reg_round_trip() {
        // inv_beta_reg must invert beta_reg in its third argument.
        let shapes = [(0.5, 0.5), (1.0, 1.0), (2.5, 0.5), (0.5, 2.5), (2.5, 2.5), (5.0, 2.0)];
        for &(a, b) in shapes.iter() {
            for &x in [0.1, 0.3, 0.5, 0.7, 0.9].iter() {
                let p = super::inv_beta_reg(a, b, x);
                assert!((0.0..=1.0).contains(&p));
                assert_almost_eq!(super::beta_reg(a, b, p), x, 1e-10);
            }
        }
    }
}