Skip to content

Commit 9000902

Browse files
committed
start a crate with tgamma
0 parents  commit 9000902

File tree

6 files changed

+293
-0
lines changed

6 files changed

+293
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/target
2+
Cargo.lock

Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[package]
2+
name = "pymath"
3+
version = "0.1.0"
4+
edition = "2024"
5+
6+
[dev-dependencies]
7+
proptest = "1.6.0"
8+
pyo3 = "0.23.4"

proptest-regressions/gamma.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Seeds for failure cases proptest has generated in the past. It is
2+
# automatically read and these particular cases re-run before any
3+
# novel cases are generated.
4+
#
5+
# It is recommended to check this file in to source control so that
6+
# everyone who runs the test benefits from these saved cases.
7+
cc e8ed768221998086795d95c68921437e80c4b7fe68fe9da15ca40faa216391b5 # shrinks to x = 0.0
8+
cc 23c7f86ab299daa966772921d8c615afda11e1b77944bed40e88264a68e62ac3 # shrinks to x = -19.80948467648103
9+
cc f57954d91904549b9431755f196b630435a43cbefd558b932efad487a403c6c8 # shrinks to x = 0.003585187864492183

src/err.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// defined in libc
2+
#[derive(Debug, PartialEq, Eq)]
3+
pub enum Error {
4+
EDOM = 33,
5+
ERANGE = 34,
6+
}

src/gamma.rs

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
use crate::Error;
2+
use std::f64::consts::PI;
3+
4+
const LOG_PI: f64 = 1.144729885849400174143427351353058711647;
5+
6+
const LANCZOS_N: usize = 13;
7+
const LANCZOS_G: f64 = 6.024680040776729583740234375;
8+
const LANCZOS_G_MINUS_HALF: f64 = 5.524680040776729583740234375;
9+
const LANCZOS_NUM_COEFFS: [f64; LANCZOS_N] = [
10+
23531376880.410759688572007674451636754734846804940,
11+
42919803642.649098768957899047001988850926355848959,
12+
35711959237.355668049440185451547166705960488635843,
13+
17921034426.037209699919755754458931112671403265390,
14+
6039542586.3520280050642916443072979210699388420708,
15+
1439720407.3117216736632230727949123939715485786772,
16+
248874557.86205415651146038641322942321632125127801,
17+
31426415.585400194380614231628318205362874684987640,
18+
2876370.6289353724412254090516208496135991145378768,
19+
186056.26539522349504029498971604569928220784236328,
20+
8071.6720023658162106380029022722506138218516325024,
21+
210.82427775157934587250973392071336271166969580291,
22+
2.5066282746310002701649081771338373386264310793408,
23+
];
24+
const LANCZOS_DEN_COEFFS: [f64; LANCZOS_N] = [
25+
0.0,
26+
39916800.0,
27+
120543840.0,
28+
150917976.0,
29+
105258076.0,
30+
45995730.0,
31+
13339535.0,
32+
2637558.0,
33+
357423.0,
34+
32670.0,
35+
1925.0,
36+
66.0,
37+
1.0,
38+
];
39+
40+
fn lanczos_sum(x: f64) -> f64 {
41+
let mut num = 0.0;
42+
let mut den = 0.0;
43+
// evaluate the rational function lanczos_sum(x). For large
44+
// x, the obvious algorithm risks overflow, so we instead
45+
// rescale the denominator and numerator of the rational
46+
// function by x**(1-LANCZOS_N) and treat this as a
47+
// rational function in 1/x. This also reduces the error for
48+
// larger x values. The choice of cutoff point (5.0 below) is
49+
// somewhat arbitrary; in tests, smaller cutoff values than
50+
// this resulted in lower accuracy.
51+
if x < 5.0 {
52+
for i in (0..LANCZOS_N).rev() {
53+
num = num * x + LANCZOS_NUM_COEFFS[i];
54+
den = den * x + LANCZOS_DEN_COEFFS[i];
55+
}
56+
} else {
57+
for i in 0..LANCZOS_N {
58+
num = num / x + LANCZOS_NUM_COEFFS[i];
59+
den = den / x + LANCZOS_DEN_COEFFS[i];
60+
}
61+
}
62+
num / den
63+
}
64+
65+
fn m_sinpi(x: f64) -> f64 {
66+
// this function should only ever be called for finite arguments
67+
debug_assert!(x.is_finite());
68+
let y = x.abs() % 2.0;
69+
let n = (2.0 * y).round() as i32;
70+
let r = match n {
71+
0 => (PI * y).sin(),
72+
1 => (PI * (y - 0.5)).cos(),
73+
2 => {
74+
// N.B. -sin(pi*(y-1.0)) is *not* equivalent: it would give
75+
// -0.0 instead of 0.0 when y == 1.0.
76+
(PI * (1.0 - y)).sin()
77+
}
78+
3 => -(PI * (y - 1.5)).cos(),
79+
4 => (PI * (y - 2.0)).sin(),
80+
_ => unreachable!(),
81+
};
82+
(1.0f64).copysign(x) * r
83+
}
84+
85+
const NGAMMA_INTEGRAL: usize = 23;
86+
const GAMMA_INTEGRAL: [f64; NGAMMA_INTEGRAL] = [
87+
1.0,
88+
1.0,
89+
2.0,
90+
6.0,
91+
24.0,
92+
120.0,
93+
720.0,
94+
5040.0,
95+
40320.0,
96+
362880.0,
97+
3628800.0,
98+
39916800.0,
99+
479001600.0,
100+
6227020800.0,
101+
87178291200.0,
102+
1307674368000.0,
103+
20922789888000.0,
104+
355687428096000.0,
105+
6402373705728000.0,
106+
121645100408832000.0,
107+
2432902008176640000.0,
108+
51090942171709440000.0,
109+
1124000727777607680000.0,
110+
];
111+
112+
pub fn tgamma(x: f64) -> Result<f64, Error> {
113+
// special cases
114+
if !x.is_finite() {
115+
if x.is_nan() || x > 0.0 {
116+
// tgamma(nan) = nan, tgamma(inf) = inf
117+
return Ok(x);
118+
} else {
119+
// tgamma(-inf) = nan, invalid
120+
return Err((f64::NAN, Error::EDOM).1);
121+
}
122+
}
123+
if x == 0.0 {
124+
// tgamma(+-0.0) = +-inf, divide-by-zero
125+
let v = if x.is_sign_positive() {
126+
f64::INFINITY
127+
} else {
128+
f64::NEG_INFINITY
129+
};
130+
return Err((v, Error::EDOM).1);
131+
}
132+
// integer arguments
133+
if x == x.floor() {
134+
if x < 0.0 {
135+
// tgamma(n) = nan, invalid for
136+
return Err((f64::NAN, Error::EDOM).1);
137+
}
138+
if x < NGAMMA_INTEGRAL as f64 {
139+
return Ok(GAMMA_INTEGRAL[x as usize - 1]);
140+
}
141+
}
142+
let absx = x.abs();
143+
// tiny arguments: tgamma(x) ~ 1/x for x near 0
144+
if absx < 1e-20 {
145+
let r = 1.0 / x;
146+
if r.is_infinite() {
147+
return Err((f64::INFINITY, Error::ERANGE).1);
148+
} else {
149+
return Ok(r);
150+
}
151+
}
152+
// large arguments: assuming IEEE 754 doubles, tgamma(x) overflows for
153+
// x > 200, and underflows to +-0.0 for x < -200, not a negative
154+
// integer.
155+
if absx > 200.0 {
156+
if x < 0.0 {
157+
return Ok(0.0 / m_sinpi(x));
158+
} else {
159+
return Err((f64::INFINITY, Error::ERANGE).1);
160+
}
161+
}
162+
163+
let y = absx + LANCZOS_G_MINUS_HALF;
164+
let z = if absx > LANCZOS_G_MINUS_HALF {
165+
// note: the correction can be foiled by an optimizing
166+
// compiler that (incorrectly) thinks that an expression like
167+
// a + b - a - b can be optimized to 0.0. This shouldn't
168+
// happen in a standards-conforming compiler.
169+
let q = y - absx;
170+
q - LANCZOS_G_MINUS_HALF
171+
} else {
172+
let q = y - LANCZOS_G_MINUS_HALF;
173+
q - absx
174+
};
175+
let z = z * LANCZOS_G / y;
176+
let r = if x < 0.0 {
177+
let mut r = -PI / m_sinpi(absx) / absx * y.exp() / lanczos_sum(absx);
178+
r -= z * r;
179+
if absx < 140.0 {
180+
r /= y.powf(absx - 0.5);
181+
} else {
182+
let sqrtpow = y.powf(absx / 2.0 - 0.25);
183+
r /= sqrtpow;
184+
r /= sqrtpow;
185+
}
186+
r
187+
} else {
188+
let mut r = lanczos_sum(absx) / y.exp();
189+
r += z * r;
190+
if absx < 140.0 {
191+
r *= y.powf(absx - 0.5);
192+
} else {
193+
let sqrtpow = y.powf(absx / 2.0 - 0.25);
194+
r *= sqrtpow;
195+
r *= sqrtpow;
196+
}
197+
r
198+
};
199+
if r.is_infinite() {
200+
return Err((f64::INFINITY, Error::ERANGE).1);
201+
} else {
202+
return Ok(r);
203+
}
204+
}
205+
206+
#[cfg(test)]
207+
mod tests {
208+
use super::*;
209+
use pyo3::Python;
210+
use pyo3::prelude::*;
211+
212+
use proptest::prelude::*;
213+
214+
fn unwrap<'a, T: 'a>(
215+
py: Python,
216+
py_v: PyResult<Bound<'a, PyAny>>,
217+
v: Result<T, crate::Error>,
218+
) -> Option<(T, T)>
219+
where
220+
T: PartialEq + std::fmt::Debug + FromPyObject<'a>,
221+
{
222+
match py_v {
223+
Ok(py_v) => {
224+
let py_v: T = py_v.extract().unwrap();
225+
Some((py_v, v.unwrap()))
226+
}
227+
Err(e) => {
228+
if e.is_instance_of::<pyo3::exceptions::PyValueError>(py) {
229+
assert_eq!(v.err(), Some(Error::EDOM));
230+
} else if e.is_instance_of::<pyo3::exceptions::PyOverflowError>(py) {
231+
assert_eq!(v.err(), Some(Error::ERANGE));
232+
} else {
233+
panic!();
234+
}
235+
None
236+
}
237+
}
238+
}
239+
240+
proptest! {
241+
#[test]
242+
fn test_tgamma(x: f64) {
243+
let rs_gamma = tgamma(x);
244+
245+
pyo3::prepare_freethreaded_python();
246+
Python::with_gil(|py| {
247+
let math = PyModule::import(py, "math").unwrap();
248+
let py_gamma_func = math
249+
.getattr("gamma")
250+
.unwrap();
251+
let r = py_gamma_func.call1((x,));
252+
let Some((py_gamma, rs_gamma)) = unwrap(py, r, rs_gamma) else {
253+
return;
254+
};
255+
let py_gamma_repr = unsafe { std::mem::transmute::<f64, i64>(py_gamma) };
256+
let rs_gamma_repr = unsafe { std::mem::transmute::<f64, i64>(rs_gamma) };
257+
// assert_eq!(py_gamma_repr, rs_gamma_repr, "x = {x}, py_gamma = {py_gamma}, rs_gamma = {rs_gamma}");
258+
// allow 1 bit error for now
259+
assert!((py_gamma_repr - rs_gamma_repr).abs() <= 1, "x = {x} diff: {}, py_gamma = {py_gamma} ({py_gamma_repr:x}), rs_gamma = {rs_gamma} ({rs_gamma_repr:x})", py_gamma_repr ^ rs_gamma_repr);
260+
});
261+
}
262+
}
263+
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod err;
2+
mod gamma;
3+
4+
pub use err::Error;
5+
pub use gamma::tgamma as gamma;

0 commit comments

Comments
 (0)