|
| 1 | +# Copyright (c) 2024, Tri Dao, Albert Gu. |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | + |
| 9 | +@triton.autotune( |
| 10 | + configs=[ |
| 11 | + triton.Config({'BLOCK_N': 32}), |
| 12 | + triton.Config({'BLOCK_N': 64}), |
| 13 | + triton.Config({'BLOCK_N': 128}), |
| 14 | + triton.Config({'BLOCK_N': 256}), |
| 15 | + triton.Config({'BLOCK_N': 512}), |
| 16 | + triton.Config({'BLOCK_N': 1024}), |
| 17 | + ], |
| 18 | + key=['ncols'], |
| 19 | +) |
| 20 | +@triton.jit |
| 21 | +def _swiglu_fwd_kernel( |
| 22 | + X, |
| 23 | + Y, |
| 24 | + OUT, |
| 25 | + stride_x_row, # how much to increase the pointer when moving by 1 row |
| 26 | + stride_y_row, |
| 27 | + stride_out_row, |
| 28 | + ncols, |
| 29 | + BLOCK_N: tl.constexpr, |
| 30 | +): |
| 31 | + # Map the program id to the row of X and Y it should compute. |
| 32 | + row = tl.program_id(0) |
| 33 | + start_col = tl.program_id(1) * BLOCK_N |
| 34 | + X += row * stride_x_row |
| 35 | + Y += row * stride_y_row |
| 36 | + OUT += row * stride_out_row |
| 37 | + cols = start_col + tl.arange(0, BLOCK_N) |
| 38 | + x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) |
| 39 | + y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) |
| 40 | + out = x * tl.sigmoid(x) * y |
| 41 | + tl.store(OUT + cols, out, mask=cols < ncols) |
| 42 | + |
| 43 | + |
| 44 | +def _swiglu_fwd(xy, out=None): |
| 45 | + if xy.stride(-1) != 1: |
| 46 | + xy = xy.contiguous() |
| 47 | + batch_shape = xy.shape[:-1] |
| 48 | + xy = xy.reshape(-1, xy.shape[-1]) |
| 49 | + x, y = xy.chunk(2, dim=-1) |
| 50 | + if out is None: |
| 51 | + out = torch.empty_like(x) |
| 52 | + else: |
| 53 | + out = out.reshape(-1, out.shape[-1]) |
| 54 | + assert out.shape == x.shape |
| 55 | + assert out.stride(-1) == 1 |
| 56 | + M, N = x.shape |
| 57 | + grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) |
| 58 | + with torch.cuda.device(x.device.index): |
| 59 | + _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N) |
| 60 | + return out.reshape(*batch_shape, out.shape[-1]) |
| 61 | + |
| 62 | + |
| 63 | +@triton.autotune( |
| 64 | + configs=[ |
| 65 | + triton.Config({'BLOCK_N': 32}), |
| 66 | + triton.Config({'BLOCK_N': 64}), |
| 67 | + triton.Config({'BLOCK_N': 128}), |
| 68 | + triton.Config({'BLOCK_N': 256}), |
| 69 | + triton.Config({'BLOCK_N': 512}), |
| 70 | + triton.Config({'BLOCK_N': 1024}), |
| 71 | + ], |
| 72 | + key=['ncols'], |
| 73 | +) |
| 74 | +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) |
| 75 | +@triton.jit |
| 76 | +def _swiglu_bwd_kernel( |
| 77 | + X, |
| 78 | + Y, |
| 79 | + DOUT, |
| 80 | + OUT, |
| 81 | + DX, |
| 82 | + DY, |
| 83 | + stride_x_row, # how much to increase the pointer when moving by 1 row |
| 84 | + stride_y_row, |
| 85 | + stride_dout_row, |
| 86 | + stride_out_row, |
| 87 | + stride_dx_row, |
| 88 | + stride_dy_row, |
| 89 | + ncols, |
| 90 | + BLOCK_N: tl.constexpr, |
| 91 | + RECOMPUTE_OUTPUT: tl.constexpr, |
| 92 | +): |
| 93 | + # Map the program id to the row of X and Y it should compute. |
| 94 | + row = tl.program_id(0) |
| 95 | + start_col = tl.program_id(1) * BLOCK_N |
| 96 | + X += row * stride_x_row |
| 97 | + Y += row * stride_y_row |
| 98 | + DOUT += row * stride_dout_row |
| 99 | + if RECOMPUTE_OUTPUT: |
| 100 | + OUT += row * stride_out_row |
| 101 | + DX += row * stride_dx_row |
| 102 | + DY += row * stride_dy_row |
| 103 | + cols = start_col + tl.arange(0, BLOCK_N) |
| 104 | + x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) |
| 105 | + y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) |
| 106 | + dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32) |
| 107 | + x_sigmoid = tl.sigmoid(x) |
| 108 | + dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout |
| 109 | + dy = x * x_sigmoid * dout |
| 110 | + tl.store(DX + cols, dx, mask=cols < ncols) |
| 111 | + tl.store(DY + cols, dy, mask=cols < ncols) |
| 112 | + if RECOMPUTE_OUTPUT: |
| 113 | + out = x * x_sigmoid * y |
| 114 | + tl.store(OUT + cols, out, mask=cols < ncols) |
| 115 | + |
| 116 | + |
| 117 | +def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None): |
| 118 | + if xy.stride(-1) != 1: |
| 119 | + xy = xy.contiguous() |
| 120 | + if dout.stride(-1) != 1: |
| 121 | + dout = dout.contiguous() |
| 122 | + batch_shape = xy.shape[:-1] |
| 123 | + xy = xy.reshape(-1, xy.shape[-1]) |
| 124 | + x, y = xy.chunk(2, dim=-1) |
| 125 | + dout = dout.reshape(-1, dout.shape[-1]) |
| 126 | + assert dout.shape == x.shape |
| 127 | + if dxy is None: |
| 128 | + dxy = torch.empty_like(xy) |
| 129 | + else: |
| 130 | + dxy = dxy.reshape(-1, dxy.shape[-1]) |
| 131 | + assert dxy.shape == xy.shape |
| 132 | + dx, dy = dxy.chunk(2, dim=-1) |
| 133 | + assert dx.stride(-1) == 1 |
| 134 | + assert dy.stride(-1) == 1 |
| 135 | + if recompute_output: |
| 136 | + if out is None: |
| 137 | + out = torch.empty_like(x) |
| 138 | + else: |
| 139 | + out = out.reshape(-1, out.shape[-1]) |
| 140 | + assert out.shape == x.shape |
| 141 | + assert out.stride(-1) == 1 |
| 142 | + M, N = x.shape |
| 143 | + grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) |
| 144 | + with torch.cuda.device(x.device.index): |
| 145 | + _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy, |
| 146 | + x.stride(0), y.stride(0), dout.stride(0), |
| 147 | + out.stride(0) if recompute_output else 0, |
| 148 | + dx.stride(0), dy.stride(0), |
| 149 | + N) |
| 150 | + if not recompute_output: |
| 151 | + return dxy.reshape(*batch_shape, dxy.shape[-1]) |
| 152 | + else: |
| 153 | + return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1]) |
0 commit comments