
Commit ca15d62

lugimzzz and sijunhe authored
[GLUE] update qat lora (PaddlePaddle#6278)
* update qat lora
* Update examples/benchmark/glue/run_glue_trainer.py
Co-authored-by: Sijun He <sijun.he@hotmail.com>
* Update examples/benchmark/glue/run_glue_trainer.py
Co-authored-by: Sijun He <sijun.he@hotmail.com>
---------
Co-authored-by: Sijun He <sijun.he@hotmail.com>
1 parent 7af4ee3 commit ca15d62

1 file changed, +25 -18 lines

examples/benchmark/glue/run_glue_trainer.py

Lines changed: 25 additions & 18 deletions
@@ -94,9 +94,7 @@ class ModelArguments:
     lora_rank: int = field(default=8, metadata={"help": "Lora rank"})
     lora_alpha: int = field(default=16, metadata={"help": "Lora alpha"})
     qat: bool = field(default=False, metadata={"help": "Whether to use QAT technique"})
-    qat_bit_length: int = field(
-        default=8, metadata={"help": "Number of bits to represent an quantized integer in binary"}
-    )
+    qat_type: str = field(default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4,A8W4"})
 
 
 def convert_example(example, tokenizer, label_list, max_seq_length=512, is_test=False):
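For context, the old integer qat_bit_length flag is folded into a single qat_type string. A minimal, dependency-free sketch of how such a field behaves, using a plain dataclasses stand-in (the QATArgs class, SUPPORTED_QAT_TYPES constant, and the validation in __post_init__ are illustrative assumptions, not part of the script):

from dataclasses import dataclass, field

SUPPORTED_QAT_TYPES = ("A8W8", "W4", "A8W4")  # illustrative constant mirroring the help string

@dataclass
class QATArgs:
    # Hypothetical stand-in for the two ModelArguments fields touched by this hunk.
    qat: bool = field(default=False, metadata={"help": "Whether to use QAT technique"})
    qat_type: str = field(default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4, A8W4"})

    def __post_init__(self):
        # Fail fast on an unsupported value instead of deep inside main().
        if self.qat and self.qat_type not in SUPPORTED_QAT_TYPES:
            raise ValueError(f"qat_type should be one of {list(SUPPORTED_QAT_TYPES)}, got {self.qat_type!r}")

args = QATArgs(qat=True, qat_type="A8W4")  # e.g. requested as --qat true --qat_type A8W4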
@@ -195,7 +193,14 @@ def main():
     if model_args.lora:
         # TODO: hardcode parameters for now. Change after MergedLoRA is introduced
         lora_config = LoRAConfig(
-            target_modules=[".*q_proj.*", ".*v_proj.*"],
+            target_modules=[
+                ".*self_attn.q_proj.*",
+                ".*self_attn.k_proj.*",
+                ".*self_attn.v_proj.*",
+                ".*self_attn.out_proj.*",
+                ".*linear1.*",
+                ".*linear2.*",
+            ],
             trainable_modules=[".*classifier.*"],
             r=model_args.lora_rank,
             lora_alpha=model_args.lora_alpha,
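A rough, self-contained illustration of what the widened patterns now pick up. The sublayer names below are hypothetical examples in the style of paddle.nn.TransformerEncoderLayer (self_attn.*_proj, linear1, linear2); the real matching happens inside PaddleNLP's LoRA implementation, not in this snippet:

import re

# target_modules as added by this commit
target_modules = [
    ".*self_attn.q_proj.*",
    ".*self_attn.k_proj.*",
    ".*self_attn.v_proj.*",
    ".*self_attn.out_proj.*",
    ".*linear1.*",
    ".*linear2.*",
]

# Hypothetical sublayer names, for illustration only.
candidates = [
    "encoder.layers.0.self_attn.q_proj",    # covered before and after the change
    "encoder.layers.0.self_attn.out_proj",  # newly covered: attention output projection
    "encoder.layers.0.linear1",             # newly covered: feed-forward linear
    "classifier",                           # left to trainable_modules, not LoRA
]

for name in candidates:
    matched = any(re.match(pattern, name) for pattern in target_modules)
    print(f"{name}: {'LoRA target' if matched else 'no match'}")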
@@ -218,20 +223,22 @@ def main():
 
         q_config = QuantConfig(activation=None, weight=None)
         q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear)
-        q_config.add_type_config(
-            LoRALinear,
-            weight=FakeQuanterChannelWiseAbsMaxObserver(bit_length=model_args.qat_bit_length, dtype=dtype),
-            activation=FakeQuanterWithAbsMaxObserver(
-                moving_rate=0.9, bit_length=model_args.qat_bit_length, dtype=dtype
-            ),
-        )
-        q_config.add_type_config(
-            nn.Linear,
-            weight=FakeQuanterChannelWiseAbsMaxObserver(bit_length=model_args.qat_bit_length, dtype=dtype),
-            activation=FakeQuanterWithAbsMaxObserver(
-                moving_rate=0.9, bit_length=model_args.qat_bit_length, dtype=dtype
-            ),
-        )
+
+        if model_args.qat_type == "A8W8":
+            activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype)
+            weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype=dtype)
+        elif model_args.qat_type == "W4":
+            activation = None
+            weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype=dtype)
+        elif model_args.qat_type == "A8W4":
+            activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype)
+            weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype=dtype)
+        else:
+            raise ValueError("qat_type should be one of ['A8W8', 'W4', 'A8W4']")
+
+        q_config.add_type_config(LoRALinear, weight=weight, activation=activation)
+        q_config.add_type_config(nn.Linear, weight=weight, activation=activation)
+
         qat = QAT(q_config)
         model = qat.quantize(model, inplace=True)
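Read as data, the branch above is just a lookup from the qat_type string to a pair of bit widths. A dependency-free sketch of that mapping, inferred from the diff (the helper name and table are illustrative; None means the activations are left unquantized):

# (activation_bit_length, weight_bit_length); None = no fake-quant observer for activations.
QAT_TYPE_TO_BITS = {
    "A8W8": (8, 8),    # 8-bit activations, 8-bit weights
    "W4": (None, 4),   # weight-only 4-bit quantization
    "A8W4": (8, 4),    # 8-bit activations, 4-bit weights
}

def resolve_qat_bits(qat_type: str):
    """Return (activation_bits, weight_bits) for a supported qat_type string."""
    try:
        return QAT_TYPE_TO_BITS[qat_type]
    except KeyError:
        raise ValueError(f"qat_type should be one of {list(QAT_TYPE_TO_BITS)}") from None

print(resolve_qat_bits("A8W4"))  # (8, 4)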
