Skip to content

Commit dfd163a

Browse files
authored
[feat] Support schema field type promotion (#159)
## Motivation

The Python client does not correctly follow [Avro's type promotion rules](https://avro.apache.org/docs/1.11.1/specification/#schema-resolution), which can lead to problems with data serialization and deserialization. The expected behavior is that the Python client follows Avro's type promotion rules and performs type conversion when necessary, ensuring compatibility. However, the actual behavior is that the Python client's schema deserialization is too strict, so type promotion does not happen as expected.

## Modification

- Support schema field type promotion when validating the Python type.
- Convert the field value to the desired compatible Python type.
1 parent 995e491 commit dfd163a

File tree

3 files changed

+83
-11
lines changed

3 files changed

+83
-11
lines changed

pulsar/schema/definition.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def validate_type(self, name, val):
228228
if val is None and not self._required:
229229
return self.default()
230230

231-
if type(val) != self.python_type():
231+
if not isinstance(val, self.python_type()):
232232
raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
233233
return val
234234

@@ -309,7 +309,7 @@ def type(self):
309309
return 'float'
310310

311311
def python_type(self):
312-
return float
312+
return float, int
313313

314314
def default(self):
315315
if self._default is not None:
@@ -323,7 +323,7 @@ def type(self):
323323
return 'double'
324324

325325
def python_type(self):
326-
return float
326+
return float, int
327327

328328
def default(self):
329329
if self._default is not None:
@@ -337,30 +337,37 @@ def type(self):
337337
return 'bytes'
338338

339339
def python_type(self):
340-
return bytes
340+
return bytes, str
341341

342342
def default(self):
343343
if self._default is not None:
344344
return self._default
345345
else:
346346
return None
347347

348+
def validate_type(self, name, val):
349+
if isinstance(val, str):
350+
return val.encode()
351+
return val
352+
348353

349354
class String(Field):
350355
def type(self):
351356
return 'string'
352357

353358
def python_type(self):
354-
return str
359+
return str, bytes
355360

356361
def validate_type(self, name, val):
357362
t = type(val)
358363

359364
if val is None and not self._required:
360365
return self.default()
361366

362-
if not (t is str or t.__name__ == 'unicode'):
367+
if not (isinstance(val, (str, bytes)) or t.__name__ == 'unicode'):
363368
raise TypeError("Invalid type '%s' for field '%s'. Expected a string" % (t, name))
369+
if isinstance(val, bytes):
370+
return val.decode()
364371
return val
365372

366373
def default(self):
@@ -406,7 +413,7 @@ def validate_type(self, name, val):
406413
else:
407414
raise TypeError(
408415
"Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, self.values.keys()))
409-
elif type(val) != self.python_type():
416+
elif not isinstance(val, self.python_type()):
410417
raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
411418
else:
412419
return val
@@ -450,7 +457,7 @@ def validate_type(self, name, val):
450457
super(Array, self).validate_type(name, val)
451458

452459
for x in val:
453-
if type(x) != self.array_type.python_type():
460+
if not isinstance(x, self.array_type.python_type()):
454461
raise TypeError('Array field ' + name + ' items should all be of type ' +
455462
_string_representation(self.array_type.type()))
456463
return val
@@ -493,7 +500,7 @@ def validate_type(self, name, val):
493500
for k, v in val.items():
494501
if type(k) != str and not is_unicode(k):
495502
raise TypeError('Map keys for field ' + name + ' should all be strings')
496-
if type(v) != self.value_type.python_type():
503+
if not isinstance(v, self.value_type.python_type()):
497504
raise TypeError('Map values for field ' + name + ' should all be of type '
498505
+ _string_representation(self.value_type.python_type()))
499506

pulsar/schema/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def __init__(self, record_cls):
101101
def _get_serialized_value(self, o):
102102
if isinstance(o, enum.Enum):
103103
return o.value
104+
elif isinstance(o, bytes):
105+
return o.decode()
104106
else:
105107
data = o.__dict__.copy()
106108
remove_reserved_key(data)

tests/schema_test.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
format='%(asctime)s %(levelname)-5s %(message)s')
3636

3737

38+
class ExampleRecord(Record):
39+
str_field = String()
40+
int_field = Integer()
41+
float_field = Float()
42+
bytes_field = Bytes()
43+
3844
class SchemaTest(TestCase):
3945

4046
serviceUrl = 'pulsar://localhost:6650'
@@ -87,6 +93,31 @@ class Example(Record):
8793
]
8894
})
8995

96+
def test_type_promotion(self):
97+
test_cases = [
98+
(20, int, 20), # No promotion necessary: int => int
99+
(20, float, 20.0), # Promotion: int => float
100+
(20.0, float, 20.0), # No Promotion necessary: float => float
101+
("Test text1", bytes, b"Test text1"), # Promotion: str => bytes
102+
(b"Test text1", str, "Test text1"), # Promotion: bytes => str
103+
]
104+
105+
for value_from, type_to, value_to in test_cases:
106+
if type_to == int:
107+
fieldType = Integer()
108+
elif type_to == float:
109+
fieldType = Double()
110+
elif type_to == str:
111+
fieldType = String()
112+
elif type_to == bytes:
113+
fieldType = Bytes()
114+
else:
115+
fieldType = String()
116+
117+
field_value = fieldType.validate_type("test_field", value_from)
118+
self.assertEqual(value_to, field_value)
119+
120+
90121
def test_complex(self):
91122
class Color(Enum):
92123
red = 1
@@ -229,7 +260,7 @@ class E3(Record):
229260
a = Float()
230261

231262
E3(a=1.0) # Ok
232-
self._expectTypeError(lambda: E3(a=1))
263+
E3(a=1) # Ok Type promotion: int -> float
233264

234265
class E4(Record):
235266
a = Null()
@@ -259,7 +290,7 @@ class E7(Record):
259290
a = Double()
260291

261292
E7(a=1.0) # Ok
262-
self._expectTypeError(lambda: E3(a=1))
293+
E7(a=1) # Ok Type promotion: int -> double
263294

264295
class Color(Enum):
265296
red = 1
@@ -1346,5 +1377,37 @@ def verify_messages(msgs: List[pulsar.Message]):
13461377

13471378
client.close()
13481379

1380+
def test_schema_type_promotion(self):
1381+
client = pulsar.Client(self.serviceUrl)
1382+
1383+
schemas = [("avro", AvroSchema(ExampleRecord)), ("json", JsonSchema(ExampleRecord))]
1384+
1385+
for schema_name, schema in schemas:
1386+
topic = f'test_schema_type_promotion_{schema_name}'
1387+
1388+
consumer = client.subscribe(
1389+
topic=topic,
1390+
subscription_name=f'my-sub-{schema_name}',
1391+
schema=schema
1392+
)
1393+
producer = client.create_producer(
1394+
topic=topic,
1395+
schema=schema
1396+
)
1397+
sendValue = ExampleRecord(str_field=b'test', int_field=1, float_field=3, bytes_field='str')
1398+
1399+
producer.send(sendValue)
1400+
1401+
msg = consumer.receive()
1402+
msg_value = msg.value()
1403+
self.assertEqual(msg_value.str_field, sendValue.str_field)
1404+
self.assertEqual(msg_value.int_field, sendValue.int_field)
1405+
self.assertEqual(msg_value.float_field, sendValue.float_field)
1406+
self.assertEqual(msg_value.bytes_field, sendValue.bytes_field)
1407+
consumer.acknowledge(msg)
1408+
1409+
client.close()
1410+
1411+
13491412
if __name__ == '__main__':
13501413
main()

0 commit comments

Comments
 (0)