Skip to content

Commit b335481

Browse files
lhoestq authored and khushmeeet committed
Support all negative values in ClassLabel (huggingface#4511)
* support all negative values in ClassLabel
* support None in cast_storage
1 parent e9879b9 commit b335481

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

src/datasets/features/features.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -854,6 +854,9 @@ class ClassLabel:
854854
* `names`: List of label strings.
855855
* `names_file`: File containing the list of labels.
856856
857+
Under the hood the labels are stored as integers.
858+
You can use negative integers to represent unknown/missing labels.
859+
857860
Args:
858861
num_classes (:obj:`int`, optional): Number of classes. All labels must be < `num_classes`.
859862
names (:obj:`list` of :obj:`str`, optional): String names for the integer classes.
@@ -910,7 +913,7 @@ def __post_init__(self, names_file):
910913
def __call__(self):
911914
return self.pa_type
912915

913-
def str2int(self, values: Union[str, Iterable]):
916+
def str2int(self, values: Union[str, Iterable]) -> Union[int, Iterable]:
914917
"""Conversion class name string => integer.
915918
916919
Example:
@@ -934,7 +937,7 @@ def str2int(self, values: Union[str, Iterable]):
934937
output = [self._strval2int(value) for value in values]
935938
return output if return_list else output[0]
936939

937-
def _strval2int(self, value: str):
940+
def _strval2int(self, value: str) -> int:
938941
failed_parse = False
939942
value = str(value)
940943
# first attempt - raw string value
@@ -955,9 +958,11 @@ def _strval2int(self, value: str):
955958
raise ValueError(f"Invalid string class label {value}")
956959
return int_value
957960

958-
def int2str(self, values: Union[int, Iterable]):
961+
def int2str(self, values: Union[int, Iterable]) -> Union[str, Iterable]:
959962
"""Conversion integer => class name string.
960963
964+
Regarding unknown/missing labels: passing negative integers raises ValueError.
965+
961966
Example:
962967
963968
```py
@@ -1014,8 +1019,6 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In
10141019
"""
10151020
if isinstance(storage, pa.IntegerArray):
10161021
min_max = pc.min_max(storage).as_py()
1017-
if min_max["min"] < -1:
1018-
raise ValueError(f"Class label {min_max['min']} less than -1")
10191022
if min_max["max"] >= self.num_classes:
10201023
raise ValueError(
10211024
f"Class label {min_max['max']} greater than configured num_classes {self.num_classes}"

tests/features/test_features.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ def test_classlabel_str2int():
288288
classlabel.str2int("__bad_label_name__")
289289
with pytest.raises(ValueError):
290290
classlabel.str2int(1)
291+
with pytest.raises(ValueError):
292+
classlabel.str2int(None)
291293

292294

293295
def test_classlabel_int2str():
@@ -297,6 +299,45 @@ def test_classlabel_int2str():
297299
assert classlabel.int2str(i) == names[i]
298300
with pytest.raises(ValueError):
299301
classlabel.int2str(len(names))
302+
with pytest.raises(ValueError):
303+
classlabel.int2str(-1)
304+
with pytest.raises(ValueError):
305+
classlabel.int2str(None)
306+
307+
308+
def test_classlabel_cast_storage():
309+
names = ["negative", "positive"]
310+
classlabel = ClassLabel(names=names)
311+
# from integers
312+
arr = pa.array([0, 1, -1, -100], type=pa.int64())
313+
result = classlabel.cast_storage(arr)
314+
assert result.type == pa.int64()
315+
assert result.to_pylist() == [0, 1, -1, -100]
316+
arr = pa.array([0, 1, -1, -100], type=pa.int32())
317+
result = classlabel.cast_storage(arr)
318+
assert result.type == pa.int64()
319+
assert result.to_pylist() == [0, 1, -1, -100]
320+
arr = pa.array([3])
321+
with pytest.raises(ValueError):
322+
classlabel.cast_storage(arr)
323+
# from strings
324+
arr = pa.array(["negative", "positive"])
325+
result = classlabel.cast_storage(arr)
326+
assert result.type == pa.int64()
327+
assert result.to_pylist() == [0, 1]
328+
arr = pa.array(["__label_that_doesnt_exist__"])
329+
with pytest.raises(ValueError):
330+
classlabel.cast_storage(arr)
331+
# from nulls
332+
arr = pa.array([None])
333+
result = classlabel.cast_storage(arr)
334+
assert result.type == pa.int64()
335+
assert result.to_pylist() == [None]
336+
# from empty
337+
arr = pa.array([])
338+
result = classlabel.cast_storage(arr)
339+
assert result.type == pa.int64()
340+
assert result.to_pylist() == []
300341

301342

302343
@pytest.mark.parametrize("class_label_arg", ["names", "names_file"])

0 commit comments

Comments
 (0)