Commit 9121856

Add test for zstandard

1 parent 86575e9 commit 9121856

File tree

1 file changed, +367 −0 lines

tests/zstandard-test.py

Lines changed: 367 additions & 0 deletions
import io
import os
import struct
import tempfile
import unittest

# The tests assume the package name is `zstandard` (python-zstandard)
import zstandard as zstd


class TestZstandardModule(unittest.TestCase):
    SAMPLE_TEXT = (
        b"The quick brown fox jumps over the lazy dog. "
        b"Pack my box with five dozen liquor jugs.\n" * 50
    )

    def setUp(self):
        self.sample = self.SAMPLE_TEXT
        self.samples_list = [
            b"alpha " * 100,
            b"bravo " * 80,
            b"charlie " * 120,
            b"delta " * 60,
        ]

    # ---------- Basics & constants ----------

    def test_backend_and_basic_constants(self):
        self.assertIn(zstd.backend, ("cext", "cffi"))
        self.assertIsInstance(zstd.ZSTD_VERSION, tuple)
        self.assertGreaterEqual(zstd.MAX_COMPRESSION_LEVEL, 1)
        # Recommended sizes should be positive
        self.assertGreater(zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE, 0)
        self.assertGreater(zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE, 0)
        self.assertGreater(zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, 0)
        self.assertGreater(zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, 0)
        # Frame header/magic number types
        self.assertIsInstance(zstd.FRAME_HEADER, (bytes, bytearray, memoryview))
        self.assertIsInstance(zstd.MAGIC_NUMBER, int)
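
    # For reference: the frame magic number defined by the zstd format
    # (RFC 8878) is 0xFD2FB528, so a stricter but still-portable check is
    # possible, e.g. (extra assertion not in the original test):
    #   self.assertEqual(zstd.MAGIC_NUMBER, 0xFD2FB528)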

    # ---------- One-shot APIs ----------

    def test_one_shot_compress_decompress(self):
        for lvl in (-5, 1, 3, min(6, zstd.MAX_COMPRESSION_LEVEL)):
            with self.subTest(level=lvl):
                comp = zstd.compress(self.sample, level=lvl)
                self.assertIsInstance(comp, bytes)
                decomp = zstd.decompress(comp)
                self.assertEqual(decomp, self.sample)

    def test_one_shot_decompress_requires_max_output_if_unknown(self):
        # Build a stream without content size in header:
        cctx = zstd.ZstdCompressor(write_content_size=False)
        frame = cctx.compress(self.sample)
        # frame_content_size should be unknown (-1)
        cs = zstd.frame_content_size(frame)
        self.assertIn(cs, (-1, zstd.CONTENTSIZE_UNKNOWN))
        # ZstdDecompressor().decompress without max_output_size should fail.
        dctx = zstd.ZstdDecompressor()
        with self.assertRaises(zstd.ZstdError):
            dctx.decompress(frame)
        # Works with explicit max_output_size
        out = dctx.decompress(frame, max_output_size=len(self.sample))
        self.assertEqual(out, self.sample)
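
    # Aside (illustrative sketch, not exercised above): with the default
    # write_content_size=True, the one-shot round trip needs no size hint:
    #   frame = zstd.ZstdCompressor().compress(data)
    #   assert zstd.ZstdDecompressor().decompress(frame) == data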

    # ---------- File API (zstandard.open) ----------

    def test_open_binary_and_text_modes(self):
        with tempfile.TemporaryDirectory() as td:
            bin_path = os.path.join(td, "bin.zst")
            txt_path = os.path.join(td, "text.zst")

            # Binary write/read
            with zstd.open(bin_path, mode="wb") as fh:
                fh.write(self.sample)
            with zstd.open(bin_path, mode="rb") as fh:
                data = fh.read()
            self.assertEqual(data, self.sample)

            # Text write/read
            text = "hello\nworld\n🙂\n" * 100
            with zstd.open(txt_path, mode="wt", encoding="utf-8") as fh:
                fh.write(text)
            with zstd.open(txt_path, mode="rt", encoding="utf-8") as fh:
                read_text = fh.read()
            self.assertEqual(text, read_text)

    # ---------- Streaming compression: writer/reader/copy_stream ----------

    def test_stream_writer_and_reader(self):
        raw = self.sample
        sink = io.BytesIO()
        cctx = zstd.ZstdCompressor(level=5)
        # Don't close the underlying sink so we can inspect it.
        with cctx.stream_writer(sink, write_size=32768, closefd=False) as w:
            for i in range(0, len(raw), 1234):
                w.write(raw[i : i + 1234])
            # finish the frame explicitly to ensure it is fully closed
            w.flush(zstd.FLUSH_FRAME)
            self.assertGreaterEqual(w.tell(), 0)
            self.assertGreater(w.memory_size(), 0)

        comp = sink.getvalue()
        self.assertTrue(comp)

        # Decompress via stream_reader (pull). Check capabilities while the
        # reader is still open, and check each one separately.
        dctx = zstd.ZstdDecompressor()
        with dctx.stream_reader(io.BytesIO(comp)) as r:
            out = r.read()
            self.assertTrue(r.readable())
            self.assertFalse(r.writable())
        self.assertEqual(out, raw)
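
    # Pointer (not exercised here): compression also has a pull-based
    # wrapper, ZstdCompressor.stream_reader, mirroring the decompression
    # reader used above, e.g.:
    #   with cctx.stream_reader(io.BytesIO(raw)) as cr:
    #       compressed = cr.read()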

    def test_copy_stream_roundtrip(self):
        cctx = zstd.ZstdCompressor()
        dctx = zstd.ZstdDecompressor()

        src = io.BytesIO(self.sample)
        comp = io.BytesIO()
        r, w = cctx.copy_stream(src, comp)
        self.assertGreater(r, 0)
        self.assertGreater(w, 0)

        comp.seek(0)
        out = io.BytesIO()
        r2, w2 = dctx.copy_stream(comp, out)
        self.assertGreater(r2, 0)
        self.assertGreater(w2, 0)
        self.assertEqual(out.getvalue(), self.sample)

    def test_read_to_iter_and_stream_reader_variants(self):
        cctx = zstd.ZstdCompressor()
        pieces = []
        for chunk in cctx.read_to_iter(io.BytesIO(self.sample), read_size=4096, write_size=8192):
            pieces.append(chunk)
        comp = b"".join(pieces)
        dctx = zstd.ZstdDecompressor()
        out = b"".join(dctx.read_to_iter(io.BytesIO(comp), read_size=2048, write_size=4096))
        self.assertEqual(out, self.sample)

    # ---------- compressobj / decompressobj ----------

    def test_standard_library_like_objects(self):
        cctx = zstd.ZstdCompressor()
        cobj = cctx.compressobj()
        part1 = cobj.compress(self.sample[:1000])
        part2 = cobj.compress(self.sample[1000:])
        final = cobj.flush()
        comp = part1 + part2 + final
        self.assertTrue(comp)

        # Decompress in chunks and test eof/unused_data/unconsumed_tail behavior
        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()
        out1 = dobj.decompress(comp[:50])
        out2 = dobj.decompress(comp[50:])
        out3 = dobj.flush()
        out = out1 + out2 + out3
        self.assertEqual(out, self.sample)
        self.assertTrue(dobj.eof)
        self.assertEqual(dobj.unconsumed_tail, b"")
        # Feed extra data after frame end
        dobj2 = dctx.decompressobj()
        dobj2.decompress(comp + b"EXTRA")
        self.assertNotEqual(dobj2.unused_data, b"")
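
    # Design note: compressobj()/decompressobj() deliberately mirror the
    # zlib/bz2 object APIs, so they can slot into code written against those
    # modules with only the context construction changed.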

    # ---------- chunker ----------

    def test_chunker_api(self):
        cctx = zstd.ZstdCompressor()
        chunker = cctx.chunker(chunk_size=32768)
        out_chunks = []
        # Feed in uneven piece sizes
        for i in range(0, len(self.sample), 777):
            for oc in chunker.compress(self.sample[i : i + 777]):
                out_chunks.append(oc)
        for oc in chunker.flush():
            out_chunks.append(oc)
        for oc in chunker.finish():
            out_chunks.append(oc)
        comp = b"".join(out_chunks)
        self.assertTrue(comp)
        # Use streaming decompression since content size may be unknown
        dctx = zstd.ZstdDecompressor()
        out = b"".join(dctx.read_to_iter(io.BytesIO(comp)))
        self.assertEqual(out, self.sample)
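
    # Per the documented chunker contract, every emitted chunk except the
    # final one should be exactly chunk_size bytes; a stricter variant of
    # this test could assert len(oc) == 32768 for all but the last chunk.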

    # ---------- Decompression writer/reader wrappers ----------

    def test_decompression_wrappers(self):
        comp = zstd.compress(self.sample)
        # stream_writer (push). Keep sink open.
        dctx = zstd.ZstdDecompressor()
        sink = io.BytesIO()
        with dctx.stream_writer(sink, write_size=8192, closefd=False) as w:
            written = w.write(comp[:50])
            self.assertGreaterEqual(written, 0)
            written += w.write(comp[50:])
            w.flush()
            self.assertGreater(w.memory_size(), 0)
        self.assertEqual(sink.getvalue(), self.sample)

        # stream_reader again (pull) with seek/tell forward-only behavior
        with dctx.stream_reader(io.BytesIO(comp)) as r:
            self.assertEqual(r.tell(), 0)
            r.read(10_000)
            pos = r.tell()
            self.assertGreater(pos, 0)
            # seeking backwards should fail; some builds raise OSError, others ValueError
            with self.assertRaises((ValueError, OSError)):
                r.seek(0)

    # ---------- Dictionaries (train/use), dict chaining ----------

    def test_dictionary_train_and_use(self):
        # Use many varied small samples and a modest dict size to avoid "Src size is incorrect".
        varied_samples = [
            (f"sample-{i:04d}-" + "abcde"[i % 5] * (20 + (i % 13))).encode("ascii")
            for i in range(200)
        ]
        dict_candidate = zstd.train_dictionary(1024, varied_samples)
        # Some versions return bytes; wrap into ZstdCompressionDict if needed.
        if isinstance(dict_candidate, (bytes, bytearray, memoryview)):
            dict_obj = zstd.ZstdCompressionDict(dict_candidate)
        else:
            dict_obj = dict_candidate  # already a ZstdCompressionDict

        self.assertGreaterEqual(len(dict_obj), 0)
        dict_id = dict_obj.dict_id()
        self.assertIsInstance(dict_id, int)
        raw_dict = dict_obj.as_bytes()
        self.assertIsInstance(raw_dict, (bytes, bytearray))

        # Precompute for a specific level to speed up repeated use
        dict_obj.precompute_compress(level=3)

        # Use the dict for both compression & decompression
        cctx = zstd.ZstdCompressor(dict_data=dict_obj)
        dctx = zstd.ZstdDecompressor(dict_data=dict_obj)
        frames = [cctx.compress(x) for x in self.samples_list]
        outs = []
        for fr in frames:
            buf = io.BytesIO()
            with dctx.stream_writer(buf, closefd=False) as dec:
                dec.write(fr)
            outs.append(buf.getvalue())
        self.assertEqual(outs, self.samples_list)
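
    # Illustrative follow-up (assumes the dictionary ID is embedded in the
    # frame header, which is the default): a dict-compressed frame should
    # not decompress without the dictionary, e.g.:
    #   with self.assertRaises(zstd.ZstdError):
    #       zstd.ZstdDecompressor().decompress(frames[0])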

    def test_decompress_content_dict_chain(self):
        # Build a content-dictionary chain per docs
        inputs = [b"input 1", b"input 2", b"input 3"]
        frames = []
        frames.append(zstd.ZstdCompressor().compress(inputs[0]))
        for i, raw in enumerate(inputs[1:]):
            dict_data = zstd.ZstdCompressionDict(
                inputs[i], dict_type=zstd.DICT_TYPE_RAWCONTENT
            )
            frames.append(zstd.ZstdCompressor(dict_data=dict_data).compress(raw))
        # Should yield last input's raw bytes
        last = zstd.ZstdDecompressor().decompress_content_dict_chain(frames)
        self.assertEqual(last, inputs[-1])

    # ---------- Multi (de)compress to buffer (experimental) ----------

    def test_multi_ops_if_supported(self):
        # Skip if not supported in this installed version.
        have_multi_comp = hasattr(zstd.ZstdCompressor(), "multi_compress_to_buffer")
        have_multi_decomp = hasattr(zstd.ZstdDecompressor(), "multi_decompress_to_buffer")
        if zstd.backend == "cffi" or not (have_multi_comp and have_multi_decomp):
            self.skipTest("multi_* APIs not supported on this backend/version")

        cctx = zstd.ZstdCompressor()
        comp_collection = cctx.multi_compress_to_buffer(self.samples_list, threads=-1)
        self.assertGreater(len(comp_collection), 0)
        frames = [bytes(comp_collection[i]) for i in range(len(comp_collection))]

        dctx = zstd.ZstdDecompressor()
        sizes = struct.pack("=" + "Q" * len(self.samples_list), *[len(x) for x in self.samples_list])
        out_collection = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes, threads=-1)
        self.assertEqual(len(out_collection), len(self.samples_list))
        recon = [bytes(out_collection[i]) for i in range(len(out_collection))]
        self.assertEqual(recon, self.samples_list)
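
    # Note on `decompressed_sizes` above: struct.pack("=" + "Q" * n, ...)
    # builds an array of n native-endian unsigned 64-bit integers, one
    # expected output size per input frame.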

    # ---------- Frame inspection & utilities ----------

    def test_frame_header_and_parameters(self):
        cctx = zstd.ZstdCompressor(write_checksum=True, write_content_size=True)
        frame = cctx.compress(self.sample)
        # Header size should be parseable
        header_len = zstd.frame_header_size(frame)
        self.assertGreaterEqual(header_len, 4)
        # get_frame_parameters needs at least 18 bytes according to docs
        params = zstd.get_frame_parameters(frame[: max(18, header_len)])
        self.assertIsInstance(params, zstd.FrameParameters)
        self.assertTrue(params.has_checksum)
        # content size should be embedded
        self.assertEqual(params.content_size, len(self.sample))
        # frame_content_size should match
        self.assertEqual(zstd.frame_content_size(frame), len(self.sample))
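
    # FrameParameters also exposes window_size and dict_id; asserting on
    # those here would tie the test to tuning defaults, so only the checksum
    # and content-size fields are checked above.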

    def test_estimate_context_sizes(self):
        self.assertGreater(zstd.estimate_decompression_context_size(), 0)
        params = zstd.ZstdCompressionParameters.from_level(4, source_size=len(self.sample))
        self.assertGreater(params.estimated_compression_context_size(), 0)
        # Override knobs
        params2 = zstd.ZstdCompressionParameters.from_level(
            3, window_log=10, threads=2, write_checksum=1
        )
        self.assertIsInstance(params2, zstd.ZstdCompressionParameters)

    # ---------- Progress & memory size ----------

    def test_frame_progression_and_memory_size(self):
        cctx = zstd.ZstdCompressor()
        _ = cctx.memory_size()
        # frame_progression() returns (ingested, consumed, produced);
        # assert each counter is a non-negative integer.
        ing, cons, prod = cctx.frame_progression()
        for value in (ing, cons, prod):
            self.assertIsInstance(value, int)
            self.assertGreaterEqual(value, 0)
        # Do some streaming work and check progression changes
        sink = io.BytesIO()
        with cctx.stream_writer(sink, closefd=False) as w:
            w.write(self.sample[:1000])
            a = cctx.frame_progression()
            w.write(self.sample[1000:])
            b = cctx.frame_progression()
        self.assertNotEqual(a, b)

    # ---------- read_across_frames / allow_extra_data ----------

    def test_decompress_across_frames_and_extra_data(self):
        cctx = zstd.ZstdCompressor()
        frame1 = cctx.compress(b"A" * 10)
        frame2 = cctx.compress(b"B" * 20)
        combined_with_trailing = frame1 + frame2 + b"TRAILING"

        # ZstdDecompressor.decompress (single frame) should ignore extras by default
        dctx = zstd.ZstdDecompressor()
        data1 = dctx.decompress(combined_with_trailing)
        self.assertEqual(data1, b"A" * 10)

        # When allow_extra_data=False, extra data should raise
        with self.assertRaises(zstd.ZstdError):
            dctx.decompress(combined_with_trailing, allow_extra_data=False)

        # For read_across_frames=True, feed only valid frames (no trailing garbage).
        dctx2 = zstd.ZstdDecompressor()
        combined_frames_only = frame1 + frame2
        with dctx2.stream_reader(io.BytesIO(combined_frames_only), read_across_frames=True) as r:
            out = r.read()
        self.assertEqual(out, b"A" * 10 + b"B" * 20)

    # ---------- zstandard.open with user-provided contexts ----------

    def test_open_with_custom_contexts(self):
        with tempfile.TemporaryDirectory() as td:
            p = os.path.join(td, "ctx.zst")
            cctx = zstd.ZstdCompressor(level=7)
            with zstd.open(p, "wb", cctx=cctx) as fh:
                fh.write(self.sample)
            dctx = zstd.ZstdDecompressor()
            with zstd.open(p, "rb", dctx=dctx) as fh:
                self.assertEqual(fh.read(), self.sample)


if __name__ == "__main__":
    # Helpful hint when running locally: the backend can be forced via an
    # env var, e.g. PYTHON_ZSTANDARD_IMPORT_POLICY=cext python zstandard-test.py
    print(f"Using python-zstandard backend: {zstd.backend}")
    unittest.main(verbosity=2)
