1
+ import io
2
+ import os
3
+ import struct
4
+ import tempfile
5
+ import unittest
6
+
7
+ # The tests assume the package name is `zstandard` (python-zstandard)
8
+ import zstandard as zstd
9
+
10
+
11
class TestZstandardModule(unittest.TestCase):
    """End-to-end tests exercising the public python-zstandard API.

    Covers one-shot (de)compression, the file-like ``zstandard.open`` API,
    streaming readers/writers, stdlib-style (de)compression objects, the
    chunker, dictionary training/use, frame inspection helpers, and the
    experimental multi-buffer APIs (skipped where unsupported).
    """

    # Repeated pangram text: compressible but non-trivial sample payload.
    SAMPLE_TEXT = (
        b"The quick brown fox jumps over the lazy dog. "
        b"Pack my box with five dozen liquor jugs.\n " * 50
    )

    def setUp(self):
        self.sample = self.SAMPLE_TEXT
        # Small, varied samples for dictionary/multi-frame tests.
        self.samples_list = [
            b"alpha " * 100,
            b"bravo " * 80,
            b"charlie " * 120,
            b"delta " * 60,
        ]

    # ---------- Basics & constants ----------

    def test_backend_and_basic_constants(self):
        """Module-level constants exist and have sane types/values."""
        self.assertIn(zstd.backend, ("cext", "cffi"))
        self.assertIsInstance(zstd.ZSTD_VERSION, tuple)
        self.assertGreaterEqual(zstd.MAX_COMPRESSION_LEVEL, 1)
        # Recommended sizes should be positive
        self.assertGreater(zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE, 0)
        self.assertGreater(zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE, 0)
        self.assertGreater(zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, 0)
        self.assertGreater(zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, 0)
        # Frame header/magic number types
        self.assertIsInstance(zstd.FRAME_HEADER, (bytes, bytearray, memoryview))
        self.assertIsInstance(zstd.MAGIC_NUMBER, int)

    # ---------- One-shot APIs ----------

    def test_one_shot_compress_decompress(self):
        """compress()/decompress() round-trip at several levels, incl. negative."""
        for lvl in (-5, 1, 3, min(6, zstd.MAX_COMPRESSION_LEVEL)):
            with self.subTest(level=lvl):
                comp = zstd.compress(self.sample, level=lvl)
                self.assertIsInstance(comp, bytes)
                decomp = zstd.decompress(comp)
                self.assertEqual(decomp, self.sample)

    def test_one_shot_decompress_requires_max_output_if_unknown(self):
        """Frames lacking embedded content size need an explicit max_output_size."""
        # Build a stream without content size in header:
        cctx = zstd.ZstdCompressor(write_content_size=False)
        frame = cctx.compress(self.sample)
        # frame_content_size should be unknown (-1)
        cs = zstd.frame_content_size(frame)
        self.assertIn(cs, (-1, zstd.CONTENTSIZE_UNKNOWN))
        # ZstdDecompressor().decompress without max_output_size should fail.
        dctx = zstd.ZstdDecompressor()
        with self.assertRaises(zstd.ZstdError):
            dctx.decompress(frame)
        # Works with explicit max_output_size
        out = dctx.decompress(frame, max_output_size=len(self.sample))
        self.assertEqual(out, self.sample)

    # ---------- File API (zstandard.open) ----------

    def test_open_binary_and_text_modes(self):
        """zstandard.open round-trips in both binary and text modes."""
        with tempfile.TemporaryDirectory() as td:
            bin_path = os.path.join(td, "bin.zst")
            txt_path = os.path.join(td, "text.zst")

            # Binary write/read
            with zstd.open(bin_path, mode="wb") as fh:
                fh.write(self.sample)
            with zstd.open(bin_path, mode="rb") as fh:
                data = fh.read()
            self.assertEqual(data, self.sample)

            # Text write/read
            text = "hello\n world\n 🙂\n " * 100
            with zstd.open(txt_path, mode="wt", encoding="utf-8") as fh:
                fh.write(text)
            with zstd.open(txt_path, mode="rt", encoding="utf-8") as fh:
                read_text = fh.read()
            self.assertEqual(text, read_text)

    # ---------- Streaming compression: writer/reader/copy_stream ----------

    def test_stream_writer_and_reader(self):
        """Push-compress via stream_writer, pull-decompress via stream_reader."""
        raw = self.sample
        sink = io.BytesIO()
        cctx = zstd.ZstdCompressor(level=5)
        # Don't close the underlying sink so we can inspect it.
        with cctx.stream_writer(sink, write_size=32768, closefd=False) as w:
            for i in range(0, len(raw), 1234):
                w.write(raw[i : i + 1234])
            # finish frame explicitly to ensure full frame close
            w.flush(zstd.FLUSH_FRAME)
            self.assertGreaterEqual(w.tell(), 0)
            self.assertGreater(w.memory_size(), 0)

        comp = sink.getvalue()
        self.assertTrue(comp)

        # Decompress via stream_reader (pull)
        dctx = zstd.ZstdDecompressor()
        with dctx.stream_reader(io.BytesIO(comp)) as r:
            out = r.read()
            self.assertEqual(out, raw)
            self.assertFalse(r.readable() and r.writable())

    def test_copy_stream_roundtrip(self):
        """copy_stream reports byte counts and round-trips data."""
        cctx = zstd.ZstdCompressor()
        dctx = zstd.ZstdDecompressor()

        src = io.BytesIO(self.sample)
        comp = io.BytesIO()
        r, w = cctx.copy_stream(src, comp)
        self.assertGreater(r, 0)
        self.assertGreater(w, 0)

        comp.seek(0)
        out = io.BytesIO()
        r2, w2 = dctx.copy_stream(comp, out)
        self.assertGreater(r2, 0)
        self.assertGreater(w2, 0)
        self.assertEqual(out.getvalue(), self.sample)

    def test_read_to_iter_and_stream_reader_variants(self):
        """read_to_iter round-trips with custom read/write chunk sizes."""
        cctx = zstd.ZstdCompressor()
        pieces = []
        for chunk in cctx.read_to_iter(io.BytesIO(self.sample), read_size=4096, write_size=8192):
            pieces.append(chunk)
        comp = b"".join(pieces)
        dctx = zstd.ZstdDecompressor()
        out = b"".join(dctx.read_to_iter(io.BytesIO(comp), read_size=2048, write_size=4096))
        self.assertEqual(out, self.sample)

    # ---------- compressobj / decompressobj ----------

    def test_standard_library_like_objects(self):
        """compressobj()/decompressobj() mimic the zlib/bz2 stdlib interface."""
        cctx = zstd.ZstdCompressor()
        cobj = cctx.compressobj()
        part1 = cobj.compress(self.sample[:1000])
        part2 = cobj.compress(self.sample[1000:])
        final = cobj.flush()
        comp = part1 + part2 + final
        self.assertTrue(comp)

        # Decompress in chunks and test eof/unused_data/unconsumed_tail behavior
        dctx = zstd.ZstdDecompressor()
        dobj = dctx.decompressobj()
        out1 = dobj.decompress(comp[:50])
        out2 = dobj.decompress(comp[50:])
        out3 = dobj.flush()
        out = out1 + out2 + out3
        self.assertEqual(out, self.sample)
        self.assertTrue(dobj.eof)
        self.assertEqual(dobj.unconsumed_tail, b"")
        # Feed extra data after frame end
        dobj2 = dctx.decompressobj()
        dobj2.decompress(comp + b"EXTRA")
        self.assertNotEqual(dobj2.unused_data, b"")

    # ---------- chunker ----------

    def test_chunker_api(self):
        """chunker() emits fixed-size chunks; compress/flush/finish lifecycle."""
        cctx = zstd.ZstdCompressor()
        chunker = cctx.chunker(chunk_size=32768)
        out_chunks = []
        # Feed in uneven piece sizes
        for i in range(0, len(self.sample), 777):
            for oc in chunker.compress(self.sample[i : i + 777]):
                out_chunks.append(oc)
        for oc in chunker.flush():
            out_chunks.append(oc)
        for oc in chunker.finish():
            out_chunks.append(oc)
        comp = b"".join(out_chunks)
        self.assertTrue(comp)
        # Use streaming decompression since content size may be unknown
        dctx = zstd.ZstdDecompressor()
        out = b"".join(dctx.read_to_iter(io.BytesIO(comp)))
        self.assertEqual(out, self.sample)

    # ---------- Decompression writer/reader wrappers ----------

    def test_decompression_wrappers(self):
        """Decompression stream_writer (push) and stream_reader (pull) behavior."""
        comp = zstd.compress(self.sample)
        # stream_writer (push). Keep sink open.
        dctx = zstd.ZstdDecompressor()
        sink = io.BytesIO()
        with dctx.stream_writer(sink, write_size=8192, closefd=False) as w:
            written = w.write(comp[:50])
            self.assertGreaterEqual(written, 0)
            written += w.write(comp[50:])
            w.flush()
            self.assertGreater(w.memory_size(), 0)
        self.assertEqual(sink.getvalue(), self.sample)

        # stream_reader again (pull) with seek/tell forward-only behavior
        with dctx.stream_reader(io.BytesIO(comp)) as r:
            self.assertEqual(r.tell(), 0)
            r.read(10_000)
            pos = r.tell()
            self.assertGreater(pos, 0)
            # seeking backwards should fail — some builds raise OSError, others ValueError
            with self.assertRaises((ValueError, OSError)):
                r.seek(0)

    # ---------- Dictionaries (train/use), dict chaining ----------

    def test_dictionary_train_and_use(self):
        """train_dictionary produces a usable dict for compression contexts."""
        # Use many varied small samples and a modest dict size to avoid "Src size is incorrect".
        varied_samples = [
            (f"sample-{i:04d}-" + "abcde"[i % 5] * (20 + (i % 13))).encode("ascii")
            for i in range(200)
        ]
        dict_candidate = zstd.train_dictionary(1024, varied_samples)
        # Some versions return bytes; wrap into ZstdCompressionDict if needed.
        if isinstance(dict_candidate, (bytes, bytearray, memoryview)):
            dict_obj = zstd.ZstdCompressionDict(dict_candidate)
        else:
            dict_obj = dict_candidate  # already a ZstdCompressionDict

        self.assertGreaterEqual(len(dict_obj), 0)
        dict_id = dict_obj.dict_id()
        self.assertIsInstance(dict_id, int)
        raw_dict = dict_obj.as_bytes()
        self.assertIsInstance(raw_dict, (bytes, bytearray))

        # Precompute for a specific level to speed up multi use
        dict_obj.precompute_compress(level=3)

        # Use dict for compression & decompression
        cctx = zstd.ZstdCompressor(dict_data=dict_obj)
        dctx = zstd.ZstdDecompressor(dict_data=dict_obj)
        frames = [cctx.compress(x) for x in self.samples_list]
        outs = []
        for fr in frames:
            buf = io.BytesIO()
            with dctx.stream_writer(buf, closefd=False) as dec:
                dec.write(fr)
            outs.append(buf.getvalue())
        self.assertEqual(outs, self.samples_list)

    def test_decompress_content_dict_chain(self):
        """Content-dictionary chain decompression yields the last input."""
        # Build a content-dictionary chain per docs
        inputs = [b"input 1", b"input 2", b"input 3"]
        frames = []
        frames.append(zstd.ZstdCompressor().compress(inputs[0]))
        for i, raw in enumerate(inputs[1:]):
            dict_data = zstd.ZstdCompressionDict(
                inputs[i], dict_type=zstd.DICT_TYPE_RAWCONTENT
            )
            frames.append(zstd.ZstdCompressor(dict_data=dict_data).compress(raw))
        # Should yield last input's raw bytes
        last = zstd.ZstdDecompressor().decompress_content_dict_chain(frames)
        self.assertEqual(last, inputs[-1])

    # ---------- Multi (de)compress to buffer (experimental) ----------

    def test_multi_ops_if_supported(self):
        """multi_compress_to_buffer/multi_decompress_to_buffer, where available."""
        # Skip if not supported in this installed version.
        have_multi_comp = hasattr(zstd.ZstdCompressor(), "multi_compress_to_buffer")
        have_multi_decomp = hasattr(zstd.ZstdDecompressor(), "multi_decompress_to_buffer")
        if zstd.backend == "cffi" or not (have_multi_comp and have_multi_decomp):
            self.skipTest("multi_* APIs not supported on this backend/version")

        cctx = zstd.ZstdCompressor()
        comp_collection = cctx.multi_compress_to_buffer(self.samples_list, threads=-1)
        self.assertGreater(len(comp_collection), 0)
        frames = [bytes(comp_collection[i]) for i in range(len(comp_collection))]

        dctx = zstd.ZstdDecompressor()
        sizes = struct.pack("=" + "Q" * len(self.samples_list), *[len(x) for x in self.samples_list])
        out_collection = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes, threads=-1)
        self.assertEqual(len(out_collection), len(self.samples_list))
        recon = [bytes(out_collection[i]) for i in range(len(out_collection))]
        self.assertEqual(recon, self.samples_list)

    # ---------- Frame inspection & utilities ----------

    def test_frame_header_and_parameters(self):
        """frame_header_size/get_frame_parameters/frame_content_size agree."""
        cctx = zstd.ZstdCompressor(write_checksum=True, write_content_size=True)
        frame = cctx.compress(self.sample)
        # Header size should be parseable
        header_len = zstd.frame_header_size(frame)
        self.assertGreaterEqual(header_len, 4)
        # get_frame_parameters needs at least 18 bytes according to docs
        params = zstd.get_frame_parameters(frame[: max(18, header_len)])
        self.assertIsInstance(params, zstd.FrameParameters)
        self.assertTrue(params.has_checksum)
        # content size should be embedded
        self.assertEqual(params.content_size, len(self.sample))
        # frame_content_size should match
        self.assertEqual(zstd.frame_content_size(frame), len(self.sample))

    def test_estimate_context_sizes(self):
        """Context-size estimators and ZstdCompressionParameters knobs."""
        self.assertGreater(zstd.estimate_decompression_context_size(), 0)
        params = zstd.ZstdCompressionParameters.from_level(4, source_size=len(self.sample))
        self.assertGreater(params.estimated_compression_context_size(), 0)
        # Override knobs
        params2 = zstd.ZstdCompressionParameters.from_level(
            3, window_log=10, threads=2, write_checksum=1
        )
        self.assertIsInstance(params2, zstd.ZstdCompressionParameters)

    # ---------- Progress & memory size ----------

    def test_frame_progression_and_memory_size(self):
        """frame_progression() changes as data is streamed through a writer."""
        cctx = zstd.ZstdCompressor()
        _ = cctx.memory_size()
        ing, cons, prod = cctx.frame_progression()
        self.assertEqual(len((ing, cons, prod)), 3)
        # Do some streaming work and check progression changes
        sink = io.BytesIO()
        with cctx.stream_writer(sink, closefd=False) as w:
            w.write(self.sample[:1000])
            a = cctx.frame_progression()
            w.write(self.sample[1000:])
            b = cctx.frame_progression()
        self.assertNotEqual(a, b)

    # ---------- read_across_frames / allow_extra_data ----------

    def test_decompress_across_frames_and_extra_data(self):
        """Multi-frame input: allow_extra_data and read_across_frames behavior."""
        cctx = zstd.ZstdCompressor()
        frame1 = cctx.compress(b"A" * 10)
        frame2 = cctx.compress(b"B" * 20)
        combined_with_trailing = frame1 + frame2 + b"TRAILING"

        # ZstdDecompressor.decompress (single frame) should ignore extras by default
        dctx = zstd.ZstdDecompressor()
        data1 = dctx.decompress(combined_with_trailing)
        self.assertEqual(data1, b"A" * 10)

        # When allow_extra_data=False, extra should error
        with self.assertRaises(zstd.ZstdError):
            dctx.decompress(combined_with_trailing, allow_extra_data=False)

        # For read_across_frames=True, feed only valid frames (no trailing garbage).
        dctx2 = zstd.ZstdDecompressor()
        combined_frames_only = frame1 + frame2
        with dctx2.stream_reader(io.BytesIO(combined_frames_only), read_across_frames=True) as r:
            out = r.read()
        self.assertEqual(out, b"A" * 10 + b"B" * 20)

    # ---------- zstandard.open with user-provided contexts ----------

    def test_open_with_custom_contexts(self):
        """zstandard.open accepts user-supplied cctx/dctx objects."""
        with tempfile.TemporaryDirectory() as td:
            p = os.path.join(td, "ctx.zst")
            cctx = zstd.ZstdCompressor(level=7)
            with zstd.open(p, "wb", cctx=cctx) as fh:
                fh.write(self.sample)
            dctx = zstd.ZstdDecompressor()
            with zstd.open(p, "rb", dctx=dctx) as fh:
                self.assertEqual(fh.read(), self.sample)
if __name__ == "__main__":
    # Helpful hint when running locally: respect import policy via env var.
    # e.g. PYTHON_ZSTANDARD_IMPORT_POLICY=cext python test_zstandard.py
    print(f"Using python-zstandard backend: {zstd.backend}")
    unittest.main(verbosity=2)