@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
     backend_ctx->program_rms_norm =
         build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-    CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+    CL_CHECK((backend_ctx->kernel_rms_norm     = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+    CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
     GGML_LOG_CONT(".");
 }
 
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
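+    // Allow turning fusion off at run time (e.g. GGML_OPENCL_DISABLE_FUSION=1) for debugging or A/B comparison.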
+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }
 
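+// Returns true when the ops starting at cgraph->nodes[node_idx] can be replaced by a single fused kernel;
+// for the RMS_NORM + MUL pair, additional type, shape, and layout constraints are checked.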
+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
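+        // Fuse RMS_NORM + MUL into one kernel launch and skip the MUL node that the fused kernel already produced.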
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
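+    // Subgroup size expected by kernel_rms_norm_mul; only Adreno and Intel GPUs are handled here.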
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
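+    // Grow the work-group size from one subgroup up to the row width ne00, capped by the kernel's max work-group size.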
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
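+    // One work-group per row: ne01 work-groups of nth threads along dim 0, with ne02 and ne03 covering the outer dims.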
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[]  = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &eps));
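+    // Arg 24: local memory scratch, one float per subgroup, used to combine per-subgroup partial sums.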
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);