@@ -190,6 +190,10 @@ def generate_expected_address_lookup_buffer(
190
190
191
191
return torch .tensor (address_lookup , dtype = torch .int64 )
192
192
193
+ @unittest .skipIf (
194
+ torch .cuda .device_count () <= 1 ,
195
+ "Not enough GPUs, this test requires at least two GPUs" ,
196
+ )
193
197
def test_init_itep_module (self ) -> None :
194
198
itep_module = GenericITEPModule (
195
199
table_name_to_unpruned_hash_sizes = self ._table_name_to_unpruned_hash_sizes ,
@@ -222,6 +226,10 @@ def test_init_itep_module(self) -> None:
222
226
equal_nan = True ,
223
227
)
224
228
229
+ @unittest .skipIf (
230
+ torch .cuda .device_count () <= 1 ,
231
+ "Not enough GPUs, this test requires at least two GPUs" ,
232
+ )
225
233
def test_init_itep_module_without_pruned_table (self ) -> None :
226
234
itep_module = GenericITEPModule (
227
235
table_name_to_unpruned_hash_sizes = {},
@@ -353,6 +361,10 @@ def test_eval_forward(
353
361
# Check that reset_weight_momentum is not called
354
362
self .assertEqual (mock_reset_weight_momentum .call_count , 0 )
355
363
364
+ @unittest .skipIf (
365
+ torch .cuda .device_count () <= 1 ,
366
+ "Not enough GPUs, this test requires at least two GPUs" ,
367
+ )
356
368
def test_iter_increment_per_forward (self ) -> None :
357
369
"""Test that the iteration counter increments correctly with each forward pass."""
358
370
itep_module = GenericITEPModule (
0 commit comments