From 07adbd81d95afd3c4016dadbe04cbfd9f3feef65 Mon Sep 17 00:00:00 2001
From: ranzhejiang
Date: Thu, 28 Aug 2025 03:07:29 +0000
Subject: [PATCH] add chunked MoE support in INC

---
 .../fp8_quant/_quant_common/helper_modules.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 9e95c756140..26e760031b5 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -841,6 +841,12 @@ def post_process(self):
             self.w2_list[i].weight = torch.nn.Parameter(self.w2_list[i].weight.squeeze().t().contiguous())
         htcore.mark_step()
 
+    def _get_extra_kwargs(self, tokens_num: int):
+        kwargs = {}
+        if hasattr(self.orig_mod, "_get_extra_kwargs"):
+            kwargs = self.orig_mod._get_extra_kwargs(tokens_num)
+        return kwargs
+
     def forward_measure(
         self,
         x,
@@ -900,6 +906,9 @@ def forward_quant(
         scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
         scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
         qinput = self.quant_input(hidden_states)
+
+        tokens_num, hidden_dim = hidden_states.shape
+        extra_kwargs = self._get_extra_kwargs(tokens_num)
         output = self.dynamic_moe_op(
             hidden_states=qinput,
             expert_routing_table=expert_routing_table,
@@ -914,6 +923,7 @@ def forward_quant(
             activation=activation,
             experts_min=self.experts_min,
             experts_max=self.experts_max,
+            **extra_kwargs
         )
         return output
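
Note on the hook contract (commentary, not part of the patch): forward_quant
now reads tokens_num from hidden_states.shape and passes it to
_get_extra_kwargs, which delegates to the wrapped original module
(self.orig_mod) when that module exposes a hook of the same name; otherwise
it returns an empty dict, so **extra_kwargs expands to nothing and existing
modules keep working unchanged. Below is a minimal runnable sketch of how an
original module might implement the hook to opt into chunked MoE execution.
The ChunkedMoE class, the moe_chunk_size knob, and the "chunk_size" kwarg
name are illustrative assumptions, not the actual INC contract.

import torch


class ChunkedMoE(torch.nn.Module):
    """Hypothetical original module that opts into chunked MoE execution."""

    def __init__(self, moe_chunk_size: int = 4096):
        super().__init__()
        self.moe_chunk_size = moe_chunk_size  # assumed tuning knob

    def _get_extra_kwargs(self, tokens_num: int) -> dict:
        # Request chunking only when the batch is large enough to benefit;
        # returning {} leaves the dynamic MoE op call unchanged.
        if tokens_num > self.moe_chunk_size:
            return {"chunk_size": self.moe_chunk_size}  # assumed kwarg name
        return {}


class PatchedMoEStub:
    """Stand-in for the patched helper module; only the hook path is shown."""

    def __init__(self, orig_mod):
        self.orig_mod = orig_mod

    def _get_extra_kwargs(self, tokens_num: int):
        # Mirrors the patch: delegate when the original module has the hook,
        # and fall back to an empty dict so **kwargs is a no-op.
        kwargs = {}
        if hasattr(self.orig_mod, "_get_extra_kwargs"):
            kwargs = self.orig_mod._get_extra_kwargs(tokens_num)
        return kwargs


if __name__ == "__main__":
    patched = PatchedMoEStub(ChunkedMoE(moe_chunk_size=4096))
    print(patched._get_extra_kwargs(1024))  # {} -> small batch, no chunking
    print(patched._get_extra_kwargs(8192))  # {'chunk_size': 4096}

The hasattr guard is the key design choice: the patched module stays
backward compatible with original modules written before the hook existed,
and chunking policy lives entirely in the model integration rather than in
the quantization wrapper.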