Skip to content

Commit a61f31b

Browse files
committed
Fix bug #126: handle NaNs correctly in min() and max()
1 parent b1894e9 commit a61f31b

File tree

2 files changed

+96
-4
lines changed

2 files changed

+96
-4
lines changed

src/bfloat.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ impl bf16 {
511511
#[inline]
512512
#[must_use]
513513
pub fn max(self, other: bf16) -> bf16 {
514-
if other > self && !other.is_nan() {
514+
if self.is_nan() || other > self {
515515
other
516516
} else {
517517
self
@@ -534,7 +534,7 @@ impl bf16 {
534534
#[inline]
535535
#[must_use]
536536
pub fn min(self, other: bf16) -> bf16 {
537-
if other < self && !other.is_nan() {
537+
if self.is_nan() || other < self {
538538
other
539539
} else {
540540
self
@@ -1877,4 +1877,50 @@ mod test {
18771877
f.0 == roundtrip.0
18781878
}
18791879
}
1880+
1881+
#[test]
1882+
fn test_max() {
1883+
let a = bf16::from_f32(0.0);
1884+
let b = bf16::from_f32(42.0);
1885+
assert_eq!(a.max(b), b);
1886+
1887+
let a = bf16::from_f32(42.0);
1888+
let b = bf16::from_f32(0.0);
1889+
assert_eq!(a.max(b), a);
1890+
1891+
let a = bf16::NAN;
1892+
let b = bf16::from_f32(42.0);
1893+
assert_eq!(a.max(b), b);
1894+
1895+
let a = bf16::from_f32(42.0);
1896+
let b = bf16::NAN;
1897+
assert_eq!(a.max(b), a);
1898+
1899+
let a = bf16::NAN;
1900+
let b = bf16::NAN;
1901+
assert!(a.max(b).is_nan());
1902+
}
1903+
1904+
#[test]
1905+
fn test_min() {
1906+
let a = bf16::from_f32(0.0);
1907+
let b = bf16::from_f32(42.0);
1908+
assert_eq!(a.min(b), a);
1909+
1910+
let a = bf16::from_f32(42.0);
1911+
let b = bf16::from_f32(0.0);
1912+
assert_eq!(a.min(b), b);
1913+
1914+
let a = bf16::NAN;
1915+
let b = bf16::from_f32(42.0);
1916+
assert_eq!(a.min(b), b);
1917+
1918+
let a = bf16::from_f32(42.0);
1919+
let b = bf16::NAN;
1920+
assert_eq!(a.min(b), a);
1921+
1922+
let a = bf16::NAN;
1923+
let b = bf16::NAN;
1924+
assert!(a.min(b).is_nan());
1925+
}
18801926
}

src/binary16.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ impl f16 {
522522
#[inline]
523523
#[must_use]
524524
pub fn max(self, other: f16) -> f16 {
525-
if other > self && !other.is_nan() {
525+
if self.is_nan() || other > self {
526526
other
527527
} else {
528528
self
@@ -545,7 +545,7 @@ impl f16 {
545545
#[inline]
546546
#[must_use]
547547
pub fn min(self, other: f16) -> f16 {
548-
if other < self && !other.is_nan() {
548+
if self.is_nan() || other < self {
549549
other
550550
} else {
551551
self
@@ -1961,4 +1961,50 @@ mod test {
19611961
f.0 == roundtrip.0
19621962
}
19631963
}
1964+
1965+
#[test]
1966+
fn test_max() {
1967+
let a = f16::from_f32(0.0);
1968+
let b = f16::from_f32(42.0);
1969+
assert_eq!(a.max(b), b);
1970+
1971+
let a = f16::from_f32(42.0);
1972+
let b = f16::from_f32(0.0);
1973+
assert_eq!(a.max(b), a);
1974+
1975+
let a = f16::NAN;
1976+
let b = f16::from_f32(42.0);
1977+
assert_eq!(a.max(b), b);
1978+
1979+
let a = f16::from_f32(42.0);
1980+
let b = f16::NAN;
1981+
assert_eq!(a.max(b), a);
1982+
1983+
let a = f16::NAN;
1984+
let b = f16::NAN;
1985+
assert!(a.max(b).is_nan());
1986+
}
1987+
1988+
#[test]
1989+
fn test_min() {
1990+
let a = f16::from_f32(0.0);
1991+
let b = f16::from_f32(42.0);
1992+
assert_eq!(a.min(b), a);
1993+
1994+
let a = f16::from_f32(42.0);
1995+
let b = f16::from_f32(0.0);
1996+
assert_eq!(a.min(b), b);
1997+
1998+
let a = f16::NAN;
1999+
let b = f16::from_f32(42.0);
2000+
assert_eq!(a.min(b), b);
2001+
2002+
let a = f16::from_f32(42.0);
2003+
let b = f16::NAN;
2004+
assert_eq!(a.min(b), a);
2005+
2006+
let a = f16::NAN;
2007+
let b = f16::NAN;
2008+
assert!(a.min(b).is_nan());
2009+
}
19642010
}

0 commit comments

Comments
 (0)