Commit c542c23

hotfix #11 (#13)
1 parent a9c570e commit c542c23

2 files changed: +4, -46 lines

docker/build_image.sh

3 additions, 2 deletions

@@ -1,11 +1,12 @@
 # get cpu arch
 arch=$(uname -m)
 version=$1
+buildtool=$2
 # if version is not provided, raise error
 if [ -z "$version" ]; then
   echo "Please provide version number"
   exit 1
 fi
 echo "Building image for $arch, version $version"
-docker build -f docker/Dockerfile.$arch-cuda . -t ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch --build-arg ARCH=$arch
-docker push ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch
+$buildtool build -f docker/Dockerfile.$arch-cuda . -t ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch --build-arg ARCH=$arch
+$buildtool push ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch

scratchpad/nn/layers/layernorm.py

1 addition, 44 deletions

@@ -4,8 +4,6 @@
 import torch.nn as nn
 from flashinfer.norm import (
     fused_add_rmsnorm,
-    # gemma_fused_add_rmsnorm,
-    # gemma_rmsnorm,
     rmsnorm,
 )
 from scratchpad.model_executor.custom_op import CustomOp
@@ -50,45 +48,4 @@ def forward_native(
         if residual is None:
             return x
         else:
-            return x, residual
-
-
-# class GemmaRMSNorm(CustomOp):
-#     def __init__(
-#         self,
-#         hidden_size: int,
-#         eps: float = 1e-6,
-#     ) -> None:
-#         super().__init__()
-#         self.weight = nn.Parameter(torch.zeros(hidden_size))
-#         self.variance_epsilon = eps
-
-#     def forward_native(
-#         self,
-#         x: torch.Tensor,
-#         residual: Optional[torch.Tensor] = None,
-#     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-#         orig_dtype = x.dtype
-#         if residual is not None:
-#             x = x + residual
-#             residual = x
-
-#         x = x.float()
-#         variance = x.pow(2).mean(dim=-1, keepdim=True)
-#         x = x * torch.rsqrt(variance + self.variance_epsilon)
-#         x = x * (1.0 + self.weight.float())
-#         x = x.to(orig_dtype)
-#         return x if residual is None else (x, residual)
-
-#     def forward_cuda(
-#         self,
-#         x: torch.Tensor,
-#         residual: Optional[torch.Tensor] = None,
-#     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-#         if residual is not None:
-#             gemma_fused_add_rmsnorm(
-#                 x, residual, self.weight.data, self.variance_epsilon
-#             )
-#             return x, residual
-#         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-#         return out
+            return x, residual
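
For reference, the commented-out GemmaRMSNorm block removed here implemented the Gemma-style variant of RMSNorm, which scales by (1.0 + weight) rather than weight. A standalone restatement of that native math, not part of this commit and with an illustrative function name:

import torch

def gemma_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as the deleted forward_native: compute in float32, then cast back.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + weight.float())
    return x.to(orig_dtype)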
