@@ -854,6 +854,9 @@ class ClassLabel:
854
854
* `names`: List of label strings.
855
855
* `names_file`: File containing the list of labels.
856
856
857
+ Under the hood, the labels are stored as integers.
858
+ You can use negative integers to represent unknown/missing labels.
859
+
857
860
Args:
858
861
num_classes (:obj:`int`, optional): Number of classes. All labels must be < `num_classes`.
859
862
names (:obj:`list` of :obj:`str`, optional): String names for the integer classes.
@@ -910,7 +913,7 @@ def __post_init__(self, names_file):
910
913
def __call__ (self ):
911
914
return self .pa_type
912
915
913
- def str2int (self , values : Union [str , Iterable ]):
916
+ def str2int (self , values : Union [str , Iterable ]) -> Union [ int , Iterable ] :
914
917
"""Convert class name string(s) to the corresponding integer label(s).
915
918
916
919
Example:
@@ -934,7 +937,7 @@ def str2int(self, values: Union[str, Iterable]):
934
937
output = [self ._strval2int (value ) for value in values ]
935
938
return output if return_list else output [0 ]
936
939
937
- def _strval2int (self , value : str ):
940
+ def _strval2int (self , value : str ) -> int :
938
941
failed_parse = False
939
942
value = str (value )
940
943
# first attempt - raw string value
@@ -955,9 +958,11 @@ def _strval2int(self, value: str):
955
958
raise ValueError (f"Invalid string class label { value } " )
956
959
return int_value
957
960
958
- def int2str (self , values : Union [int , Iterable ]):
961
+ def int2str (self , values : Union [int , Iterable ]) -> Union [ str , Iterable ] :
959
962
"""Convert integer label(s) to the corresponding class name string(s).
960
963
964
+ Negative integers, which represent unknown/missing labels, cannot be converted: passing one raises a ValueError.
965
+
961
966
Example:
962
967
963
968
```py
@@ -1014,16 +1019,12 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.In
1014
1019
"""
1015
1020
if isinstance (storage , pa .IntegerArray ):
1016
1021
min_max = pc .min_max (storage ).as_py ()
1017
- if min_max ["min" ] < - 1 :
1018
- raise ValueError (f"Class label { min_max ['min' ]} less than -1" )
1019
1022
if min_max ["max" ] >= self .num_classes :
1020
1023
raise ValueError (
1021
1024
f"Class label { min_max ['max' ]} greater than configured num_classes { self .num_classes } "
1022
1025
)
1023
1026
elif isinstance (storage , pa .StringArray ):
1024
- storage = pa .array (
1025
- [self ._strval2int (label ) if label is not None else None for label in storage .to_pylist ()]
1026
- )
1027
+ storage = pa .array (self .str2int (storage .to_pylist ()))
1027
1028
return array_cast (storage , self .pa_type )
1028
1029
1029
1030
@staticmethod
0 commit comments