Skip to content

Commit 74f8c65

Browse files
xzyaoi and Boyko Borisov authored
Support Llama MoE model (#2)
Support Llama MoE model (#2)

* checkpoint
* checkpoint
* Finalize naive MoE
* checkpoint
* checkpoint
* Finalize naive MoE
* integrate with triteia sbmm
* Remove TriteiaLinear layer
* Adds initial code for unquantised MoE
* Improve linear MoE implementation
* import attention from base llama

---------

Co-authored-by: Boyko Borisov <bborisov@sgs-gpu05.ethz.ch>
1 parent 47b0993 commit 74f8c65

File tree

7 files changed

+1382
-1
lines changed

7 files changed

+1382
-1
lines changed

scratchpad/nn/layers/linear.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
PackedvLLMParameter,
2323
PerTensorScaleParameter,
2424
)
25+
from triteia.python.nn.linear import sparse_low_precision_linear
2526

2627
WEIGHT_LOADER_V2_SUPPORTED = [
2728
"CompressedTensorsLinearMethod",
@@ -1161,3 +1162,21 @@ def extra_repr(self) -> str:
11611162
s += f", tp_size={self.tp_size}"
11621163
s += f", reduce_results={self.reduce_results}"
11631164
return s
1165+
1166+
class TritelaLinear(LinearBase):
    """Linear layer backed by Triteia's sparse low-precision kernel.

    Thin adapter: construction is delegated to ``LinearBase`` and the
    actual computation to ``sparse_low_precision_linear``.

    NOTE(review): the name looks like a typo of ``TriteiaLinear`` (the
    imported package is ``triteia``) — confirm with callers before
    renaming, since the identifier is part of the public interface.
    """

    def __init__(
        self,
        input_size,
        output_size,
        skip_bias_add=False,
        params_dtype=None,
        quant_config=None,
        prefix="",
    ):
        # Let LinearBase handle the shared bookkeeping, then build the
        # Triteia kernel for the requested dimensions.
        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
        )
        self.layer = sparse_low_precision_linear(input_size, output_size)

    def forward(self, x):
        """Apply the wrapped sparse low-precision linear transform to *x*."""
        return self.layer(x)

scratchpad/nn/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
# Registry of generation-model architectures: maps an architecture class
# name to the (module name, class name) pair used to locate the
# implementation. Presumably keyed by the architecture string from the
# model config — verify against the loader that consumes this table.
_GENERATION_MODELS = {
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "LlamaNaiveQuantisedMoEForCausalLM": (
        "llama_naive_moe",
        "LlamaNaiveQuantisedMoEForCausalLM",
    ),
    "LlamaQuantisedMoEForCausalLM": (
        "llama_quant_moe",
        "LlamaQuantisedMoEForCausalLM",
    ),
    "LlamaMoEForCausalLM": ("llama_moe", "LlamaMoEForCausalLM"),
}
1117

1218
_EMBEDDING_MODELS = {

0 commit comments

Comments
 (0)