@@ -94,9 +94,7 @@ class ModelArguments:
    lora_rank: int = field(default=8, metadata={"help": "Lora rank"})
    lora_alpha: int = field(default=16, metadata={"help": "Lora alpha"})
    qat: bool = field(default=False, metadata={"help": "Whether to use QAT technique"})
-   qat_bit_length: int = field(
-       default=8, metadata={"help": "Number of bits to represent an quantized integer in binary"}
-   )
+   qat_type: str = field(default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4, A8W4"})


def convert_example(example, tokenizer, label_list, max_seq_length=512, is_test=False):
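The `qat_type` string replaces the old `qat_bit_length` integer, so the supported activation/weight bit-width combinations are now encoded in one flag. Below is a minimal sketch of how the value could be validated as soon as arguments are parsed; the trimmed `ModelArguments`, the `SUPPORTED_QAT_TYPES` constant, and the `__post_init__` check are illustrative additions, not part of this patch:

```python
from dataclasses import dataclass, field

SUPPORTED_QAT_TYPES = ("A8W8", "W4", "A8W4")  # illustrative constant, not in the patch


@dataclass
class ModelArguments:  # trimmed stand-in for the script's full ModelArguments
    qat: bool = field(default=False, metadata={"help": "Whether to use QAT technique"})
    qat_type: str = field(
        default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4, A8W4"}
    )

    def __post_init__(self):
        # Fail fast on an unsupported quantization type instead of waiting
        # for the QuantConfig branch inside main().
        if self.qat and self.qat_type not in SUPPORTED_QAT_TYPES:
            raise ValueError(f"qat_type should be one of {list(SUPPORTED_QAT_TYPES)}")
```

With such a check, `ModelArguments(qat=True, qat_type="W8")` would raise immediately rather than after the model has been built.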
@@ -195,7 +193,14 @@ def main():
    if model_args.lora:
        # TODO: hardcode parameters for now. Change after MergedLoRA is introduced
        lora_config = LoRAConfig(
-           target_modules=[".*q_proj.*", ".*v_proj.*"],
+           target_modules=[
+               ".*self_attn.q_proj.*",
+               ".*self_attn.k_proj.*",
+               ".*self_attn.v_proj.*",
+               ".*self_attn.out_proj.*",
+               ".*linear1.*",
+               ".*linear2.*",
+           ],
            trainable_modules=[".*classifier.*"],
            r=model_args.lora_rank,
            lora_alpha=model_args.lora_alpha,
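The widened `target_modules` list injects LoRA into every attention projection (q/k/v/out) and both feed-forward linears instead of only `q_proj`/`v_proj`. A quick way to see which sublayers a pattern set picks up is to match the regexes against candidate sublayer names. The names below are made up for illustration, and the exact matching rule inside the LoRA wrapper may differ from `re.fullmatch`:

```python
import re

# Hypothetical sublayer names following the transformer naming the patterns
# assume (self_attn.q_proj, linear1, ...); real names come from the model's
# named sublayers and may differ.
names = [
    "encoder.layers.0.self_attn.q_proj",
    "encoder.layers.0.self_attn.out_proj",
    "encoder.layers.0.linear1",
    "encoder.layers.0.linear2",
    "classifier",
]
target_modules = [
    ".*self_attn.q_proj.*",
    ".*self_attn.k_proj.*",
    ".*self_attn.v_proj.*",
    ".*self_attn.out_proj.*",
    ".*linear1.*",
    ".*linear2.*",
]

for name in names:
    hit = any(re.fullmatch(pattern, name) for pattern in target_modules)
    print(f"{name}: {'LoRA' if hit else 'not targeted'}")
# "classifier" is not matched here; it stays fully trainable via trainable_modules.
```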
@@ -218,20 +223,22 @@ def main():

        q_config = QuantConfig(activation=None, weight=None)
        q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear)
-       q_config.add_type_config(
-           LoRALinear,
-           weight=FakeQuanterChannelWiseAbsMaxObserver(bit_length=model_args.qat_bit_length, dtype=dtype),
-           activation=FakeQuanterWithAbsMaxObserver(
-               moving_rate=0.9, bit_length=model_args.qat_bit_length, dtype=dtype
-           ),
-       )
-       q_config.add_type_config(
-           nn.Linear,
-           weight=FakeQuanterChannelWiseAbsMaxObserver(bit_length=model_args.qat_bit_length, dtype=dtype),
-           activation=FakeQuanterWithAbsMaxObserver(
-               moving_rate=0.9, bit_length=model_args.qat_bit_length, dtype=dtype
-           ),
-       )
+
+       if model_args.qat_type == "A8W8":
+           activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype)
+           weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype=dtype)
+       elif model_args.qat_type == "W4":
+           activation = None
+           weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype=dtype)
+       elif model_args.qat_type == "A8W4":
+           activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype)
+           weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype=dtype)
+       else:
+           raise ValueError("qat_type should be one of ['A8W8', 'W4', 'A8W4']")
+
+       q_config.add_type_config(LoRALinear, weight=weight, activation=activation)
+       q_config.add_type_config(nn.Linear, weight=weight, activation=activation)
+

        qat = QAT(q_config)
        model = qat.quantize(model, inplace=True)
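The new branch maps each `qat_type` onto an (activation observer, weight observer) pair: A8W8 fake-quantizes both to 8 bits, W4 leaves activations in floating point (`activation = None`) and quantizes only weights to 4 bits, and A8W4 mixes 8-bit activations with 4-bit weights. The same mapping can also be written as a small table-driven helper that keeps the supported types and their bit widths in one place. This is a sketch of an alternative structure, not what the patch does; the helper name and the import path are assumptions:

```python
# Import path is an assumption; in practice, reuse whatever the training
# script already imports for these observer classes.
from paddle.quantization.quanters import (
    FakeQuanterChannelWiseAbsMaxObserver,
    FakeQuanterWithAbsMaxObserver,
)

# (activation bits, weight bits); None means activations stay in floating
# point, as in the W4 branch of the patch.
QAT_TYPE_BITS = {"A8W8": (8, 8), "W4": (None, 4), "A8W4": (8, 4)}


def build_observers(qat_type, dtype):
    if qat_type not in QAT_TYPE_BITS:
        raise ValueError(f"qat_type should be one of {list(QAT_TYPE_BITS)}")
    act_bits, weight_bits = QAT_TYPE_BITS[qat_type]
    activation = (
        None
        if act_bits is None
        else FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=act_bits, dtype=dtype)
    )
    weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=weight_bits, dtype=dtype)
    return activation, weight
```

`activation, weight = build_observers(model_args.qat_type, dtype)` would then feed the two `add_type_config` calls unchanged.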