Skip to content

Commit af28a45

Browse files
committed
fix pir llama vpp test
1 parent feebd79 commit af28a45

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

test/auto_parallel/hybrid_strategy/semi_auto_llama_pp_gradmerge.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,7 @@ def __init__(self):
120120
self.init_dist_env()
121121

122122
def init_dist_env(self):
123-
order = ["dp", "pp", "mp"]
124-
dp_degree = self.dp
125-
mp_degree = self.mp
126-
pp_degree = self.pp
127-
degree = [dp_degree, pp_degree, mp_degree]
128-
mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree))))
123+
mesh_dims = [("pp", self.pp), ("dp", self.dp), ("mp", self.mp)]
129124
if not mesh_dims:
130125
mesh_dims = [("dp", 1)]
131126
dim_names = [mesh_dim[0] for mesh_dim in mesh_dims]
@@ -212,6 +207,7 @@ def run_llama(self, to_static=0):
212207
strategy.pipeline.accumulate_steps = (
213208
self.gradient_accumulation_steps
214209
)
210+
strategy.pipeline.pp_degree = self.pp
215211
strategy.pipeline.micro_batch_size = micro_bsz
216212
strategy.pipeline.schedule_mode = self.schedule_mode
217213
strategy.pipeline.vpp_degree = self.config.virtual_pp_degree

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model_vpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import collective.test_communication_api_base as test_base
2222

23-
os.environ['FLAGS_enable_pir_api'] = '0'
23+
os.environ['FLAGS_enable_pir_api'] = '1'
2424

2525

2626
class TestSemiAutoParallelLlama3DVPP(test_base.CommunicationTestDistBase):

0 commit comments

Comments
 (0)