Skip to content

Commit d414c3f

Browse files
lhez and ggerganov
authored and committed
opencl: add fused rms_norm_mul (llama/14841)
* opencl: add fused `rms_norm` + `mul`
* opencl: improve workgroup size for `rms_norm_mul`
1 parent bbf2389 commit d414c3f

File tree

2 files changed

+240
-2
lines changed

2 files changed

+240
-2
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 161 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
333333
size_t max_alloc_size;
334334
bool fp16_support;
335335
bool has_vector_subgroup_broadcast;
336+
bool disable_fusion;
336337
ggml_cl_compiler_version adreno_cl_compiler_version;
337338

338339
int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
411412
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
412413
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
413414
cl_kernel kernel_norm;
414-
cl_kernel kernel_rms_norm;
415+
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
415416
cl_kernel kernel_group_norm;
416417
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
417418
cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
11001101
backend_ctx->program_rms_norm =
11011102
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
11021103

1103-
CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
1104+
CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
1105+
CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
11041106
GGML_LOG_CONT(".");
11051107
}
11061108

@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
21102112
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
21112113
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
21122114

2115+
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
2116+
21132117
dev_ctx->backend_ctx = backend_ctx.release();
21142118
return dev_ctx->backend_ctx;
21152119
}
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
22792283
sync_with_other_backends(backend_ctx);
22802284
}
22812285

2286+
static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
2287+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2288+
return false;
2289+
}
2290+
2291+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
2292+
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
2293+
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
2294+
2295+
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
2296+
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
2297+
2298+
// rms_norm only supports f32
2299+
if (mul->src[0]->type != GGML_TYPE_F32 ||
2300+
mul->src[1]->type != GGML_TYPE_F32 ||
2301+
mul->type != GGML_TYPE_F32) {
2302+
return false;
2303+
}
2304+
2305+
// if rms_norm is the B operand, then we don't handle broadcast
2306+
if (rms_norm == mul->src[1] &&
2307+
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2308+
return false;
2309+
}
2310+
2311+
// rms_norm assumes contiguous rows
2312+
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
2313+
return false;
2314+
}
2315+
}
2316+
2317+
return true;
2318+
}
2319+
2320+
// Forward declaration: the fused RMS_NORM + MUL launcher is defined further
// down in this file but is called from ggml_backend_opencl_graph_compute().
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
2321+
22822322
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2323+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
2324+
22832325
for (int i = 0; i < cgraph->n_nodes; i++) {
22842326
ggml_tensor * node = cgraph->nodes[i];
22852327

@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
22922334
continue;
22932335
}
22942336

2337+
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2338+
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
2339+
i++;
2340+
continue;
2341+
}
2342+
22952343
bool ok = ggml_cl_compute_forward(backend, node);
22962344
if (!ok) {
22972345
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
44554503
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
44564504
}
44574505

4506+
// Enqueues the fused `kernel_rms_norm_mul` kernel, which computes
// mul_tensor = rms_norm(rms_norm_tensor->src[0]) * other, where `other` is the
// MUL operand that is not the RMS_NORM result.
//
// backend         - OpenCL backend; its context supplies the kernel, the GPU
//                   family and the enqueue helper.
// rms_norm_tensor - the RMS_NORM node; its src[0] is the kernel input and its
//                   op_params hold `eps`.
// mul_tensor      - the MUL node consuming rms_norm_tensor; it is also the
//                   destination of the fused kernel.
//
// Aborts via GGML_ASSERT if rms_norm_tensor is not an operand of mul_tensor,
// if any involved tensor has no backend `extra`, if ne00 is not a multiple
// of 4, or if the GPU family is neither ADRENO nor INTEL.
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
    GGML_ASSERT(mul_tensor);
    GGML_ASSERT(rms_norm_tensor);

    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
    const ggml_tensor * src0 = rms_norm_tensor->src[0];
    const ggml_tensor * src1 = nullptr;
    if (mul_tensor->src[0] == rms_norm_tensor) {
        src1 = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == rms_norm_tensor) {
        src1 = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
    }
    const ggml_tensor * dst = mul_tensor;

    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(src1);
    GGML_ASSERT(src1->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    // Each byte offset combines the buffer offset with the tensor's own view
    // offset; src1 must use src1's view offset, not src0's.
    cl_ulong offset0 = extra0->offset + src0->view_offs;
    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    float eps;
    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));

    const int ne00 = src0->ne[0];
    const int ne01 = src0->ne[1];
    const int ne02 = src0->ne[2];
    const int ne03 = src0->ne[3];

    const cl_ulong nb01 = src0->nb[1];
    const cl_ulong nb02 = src0->nb[2];
    const cl_ulong nb03 = src0->nb[3];

    const int ne10 = src1->ne[0];
    const int ne11 = src1->ne[1];
    const int ne12 = src1->ne[2];
    const int ne13 = src1->ne[3];

    const cl_ulong nb11 = src1->nb[1];
    const cl_ulong nb12 = src1->nb[2];
    const cl_ulong nb13 = src1->nb[3];

    const cl_ulong nb1 = dst->nb[1];
    const cl_ulong nb2 = dst->nb[2];
    const cl_ulong nb3 = dst->nb[3];

    // The kernel reads and writes rows as float4.
    GGML_ASSERT(ne00 % 4 == 0);

    // Fixed subgroup size per GPU family; it matches the REQD_SUBGROUP_SIZE_*
    // attribute the kernel is compiled with.
    size_t sgs = 0;
    if (backend_ctx->gpu_family == ADRENO) {
        sgs = 64;
    } else if (backend_ctx->gpu_family == INTEL) {
        sgs = 32;
    } else {
        GGML_ASSERT(false && "Unsupported GPU");
    }

    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;

    // Grow the work-group from one subgroup until it covers the row or hits
    // the kernel's work-group limit, then clamp to both.
    int nth = sgs;
    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
    while (nth < ne00 && nth < max_workgroup_size) {
        nth *= 2;
    }
    nth = MIN(nth, max_workgroup_size);
    nth = MIN(nth, ne00);

    // One work-group of nth work-items per row.
    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
    size_t local_work_size[] = {(size_t)nth, 1, 1};

    // The kernel stores one partial sum per subgroup in local memory. Because
    // of the clamp to ne00 above, nth is not necessarily a multiple of sgs, so
    // the subgroup count has to be rounded up; truncating would under-allocate
    // the scratch buffer (down to less than one float when nth < sgs).
    const size_t n_subgroups = ((size_t)nth + sgs - 1) / sgs;

    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne13));
    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &eps));
    // Local scratch buffer (`sum` in the kernel): one float per subgroup.
    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*n_subgroups, NULL));

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
4616+
44584617
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
44594618
GGML_ASSERT(src0);
44604619
GGML_ASSERT(src0->extra);

ggml/src/ggml-opencl/kernels/rms_norm.cl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
9494
}
9595
}
9696
}
97+
98+
//------------------------------------------------------------------------------
99+
// rms_norm_mul
100+
//------------------------------------------------------------------------------
101+
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
// Fused RMS normalization followed by an element-wise multiply:
//   dst[row] = (src0[row] / sqrt(mean(src0[row]^2) + eps)) * src1[row']
// where row' is the src0 row index taken modulo src1's extents, so src1 may
// be broadcast along any dimension.
//
// Each work-group handles one row; the row is selected by the group ids
// (dim 0 -> i01, dim 1 -> i02, dim 2 -> i03). Rows are accessed as float4,
// so only the first ne00/4 vectors of a row are processed.
//
//   src0, offset0        - input buffer and byte offset of the tensor in it
//   src1, offset1        - multiplier buffer and byte offset
//   dst,  offsetd        - output buffer and byte offset
//   ne00..ne03           - src0 extents (ne01..ne03 are not read here)
//   nb01..nb03           - src0 byte strides of dims 1..3
//   ne10..ne13           - src1 extents, used for the broadcast modulo
//   nb11..nb13           - src1 byte strides of dims 1..3
//   nb1..nb3             - dst byte strides of dims 1..3
//   eps                  - added to the mean of squares before the sqrt
//   sum                  - local scratch; needs one float per subgroup of
//                          the work-group
kernel void kernel_rms_norm_mul(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float eps,
        local float * sum
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);

    float sumf = 0;

    // parallel sum: each work-item accumulates a strided slice of the row,
    // then the slices are combined within each subgroup.
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = sub_group_reduce_add(sumf);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = sumf;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Combine the per-subgroup partials. Every work-item adds up all of them
    // itself: `sum` is read-only after the barrier above, so no further
    // synchronization is needed, and this is correct for any number of
    // subgroups (a halving reduction would drop partials when the count is
    // not a power of two and would need a barrier between steps).
    float total = 0.0f;
    for (uint i = 0; i < get_num_sub_groups(); i++) {
        total += sum[i];
    }

    float mean  = total / ne00;
    float scale = 1.0f/sqrt(mean + eps);

    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        // src1 is also broadcast along dim 0 via the modulo on ne10/4.
        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
    }
}

0 commit comments

Comments
 (0)