From 1a519cd2dc65bc038d400ebd01c878e98b713fa0 Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 8 Aug 2025 07:55:10 +0000 Subject: [PATCH 1/6] refactor softmax --- ggml/src/ggml-cann/aclnn_ops.cpp | 371 +++++++++++-------------------- ggml/src/ggml-cann/ggml-cann.cpp | 4 +- tests/test-backend-ops.cpp | 2 + 3 files changed, 131 insertions(+), 246 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 07d6b8b67d47c..c0fc505ef3b47 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1329,158 +1329,119 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp); } -/** - * @brief Applies the Alibi (Attention with Linear Biases) mechanism to the - * @details This function implements the Alibi mechanism, which introduces - * learnable biases into the attention scores to simulate relative - * position encoding without the need for explicit positional - * embeddings. - * - * @param ctx The backend CANN context for executing operations. - * @param acl_src The source tensor representing the query or key. - * @param acl_position The position tensor containing relative positions. - * @param acl_dst The destination tensor where the result will be stored. - * @param n_head The number of attention heads. - * @param src_ne The dimensions of the source tensor. - * @param src_nb0 The byte size of the first dimension of the source - tensor. - * @param max_bias The maximum bias value used in the Alibi mechanism. - * @param dst The destination tensor object for additional metadata. - * - * The function performs the following steps: - * 1. Calculates the logarithm floor of the number of heads to determine the - base for bias calculation. - * 2. Initializes arrays with arithmetic sequences and fills them with bias - values. - * 3. Computes the bias tensor based on the calculated biases and arithmetic - sequences. - * 4. Reshapes the bias tensor to match the dimensions of the input tensors. - * 5. Multiplies the position tensor by the bias tensor. - * 6. Adds the result of the multiplication to the source tensor to produce the - final output. 
- */ -static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_position, aclTensor* acl_dst, - const int n_head, int64_t* src_ne, const size_t src_nb0, - float max_bias, ggml_tensor* dst) { - const int64_t ne2_ne3 = src_ne[2] * src_ne[3]; - GGML_ASSERT(src_nb0 == sizeof(float)); - GGML_ASSERT(n_head == src_ne[2]); - - const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); - - float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - // init arange - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - ne2_ne3 * ggml_type_size(dst->type)); - void* tmp_arange_buffer = arange_allocator.get(); - // arange1: [1, ..., n_heads_log2_floor+1) - float start = 1; - float stop = n_heads_log2_floor + 1; - float step = 1; - int64_t n_elements_arange = n_heads_log2_floor; +static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, float m, int64_t size, float start, float stop, float step){ + int64_t ne[] = {size}; + size_t nb[] = {sizeof(float)}; - int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; - size_t tmp_arange1_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + ggml_cann_pool_alloc arange_allocator(ctx.pool(),size * sizeof(float)); + void* arange_buffer = arange_allocator.get(); + + aclTensor* arange_tensor = ggml_cann_create_tensor( + arange_buffer, ACL_FLOAT, + sizeof(float), ne, nb, 1); + aclnn_arange(ctx, arange_tensor, start, stop, step, size); - aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); - - aclTensor* tmp_arange2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) - start = 1; - stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; - step = 2; - n_elements_arange = ne2_ne3 - n_heads_log2_floor; - int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_arange2_nb[] = {sizeof(dst->type)}; - - aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( - (char*)tmp_arange_buffer + - n_heads_log2_floor * ggml_type_size(dst->type), - ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), - tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, - n_elements_arange); - } + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + float* arange_host = new float[size]; + aclrtMemcpy(arange_host, size * 4, arange_buffer, size* 4, ACL_MEMCPY_DEVICE_TO_HOST); - // init mk_base - ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), - ne2_ne3 * ggml_type_size(dst->type)); - void* tmp_mk_base_buffer = mk_base_allocator.get(); - int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; - size_t tmp_mk_base1_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclTensor* slope_tensor = ggml_cann_create_tensor( + slope_buffer, ACL_FLOAT, + sizeof(float), ne, nb, 1); + + aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor); + ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor); +} + +static void 
aclnn_get_slope(ggml_backend_cann_context& ctx, int64_t n_head, void* slope_buffer, float max_bias) { + const int n_head_log2 = 1u << (uint32_t)floor(log2(n_head)); - aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); - - aclTensor* tmp_mk_base2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; - size_t tmp_mk_base2_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( - (char*)tmp_mk_base_buffer + - n_heads_log2_floor * ggml_type_size(dst->type), - ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), - tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + // arange1 + float start = 0 + 1; + float end = (n_head_log2 - 1) + 1; + float step = 1; + float count = n_head_log2; + // end needs to be +1 because aclnn uses a left-closed, right-open interval. + aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step); + if (n_head_log2 < n_head) { + // arange2 + start = 2*(n_head_log2 - n_head_log2) + 1; + end = 2*((n_head - 1) - n_head_log2) + 1; + step = 2; + count = n_head - n_head_log2; + aclnn_get_slope_inner(ctx, (char*)slope_buffer + n_head_log2* sizeof(float), m1, count, start, end + 1, step); } +} - // init mk - int64_t tmp_mk_base_ne[] = {ne2_ne3}; - size_t tmp_mk_base_nb[] = {sizeof(dst->type)}; - aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); +static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, ggml_tensor* dst, void* dst_ptr, float max_bias) { + void* slope_buffer = nullptr; + void* bias_buffer = nullptr; + if (max_bias > 0.0f) { + int64_t n_heads = dst->ne[2]; + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + slope_buffer = slope_allocator.get(); + ggml_cann_pool_alloc bias_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); + bias_buffer = bias_allocator.get(); + aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); + } - // reshape mk - int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]}; - size_t tmp_mk_nb[GGML_MAX_DIMS]; - tmp_mk_nb[0] = ggml_type_size(dst->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + // broadcast for mask, slop and dst; + GGML_ASSERT(ggml_is_contiguous(mask)); + int64_t nr2 = dst->ne[2] / mask->ne[2]; + int64_t nr3 = dst->ne[3] / mask->ne[3]; + + // broadcast the mask across rows + int64_t mask_ne[] = {mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1}; + size_t mask_nb[GGML_MAX_DIMS + 2]; + mask_nb[0] = ggml_element_size(mask); + for(int i = 1;itype), - ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - // acl_position * mk - int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]}; - size_t 
tmp_output_nb[GGML_MAX_DIMS]; - tmp_output_nb[0] = ggml_type_size(dst->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1]; + // ne2 and ne3 may be integer multiples of the mask. + int64_t dst_ne[] = {dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3}; + size_t dst_nb[GGML_MAX_DIMS + 2]; + dst_nb[0] = ggml_element_size(dst); + for(int i = 1;ine[2], nr2, 1, 1}; + size_t slope_nb[GGML_MAX_DIMS + 2]; + slope_nb[0] = sizeof(float); + for(int i = 1;itype), - ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor); - // add - aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst); - ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, - tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, - tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor); + aclTensor* acl_slope = ggml_cann_create_tensor(slope_buffer, ACL_FLOAT, sizeof(float), slope_ne, slope_nb, GGML_MAX_DIMS + 2); + aclTensor* acl_mask = ggml_cann_create_tensor(mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); + aclTensor* acl_dst = ggml_cann_create_tensor(dst_ptr, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst_ne, dst_nb, GGML_MAX_DIMS + 2); + + if (max_bias > 0.0f) { + int64_t bias_ne[] = {mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1}; + size_t bias_nb[GGML_MAX_DIMS + 2]; + bias_nb[0] = sizeof(float); + for(int i = 1;itype), src0->ne, + ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0)); + void* src_tensor_buffer = src_tensor_allocator.get(); + aclTensor* softmax_tensor = ggml_cann_create_tensor( + src_tensor_buffer, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), src0->ne, src0->nb, GGML_MAX_DIMS); - bool inplace = false; - aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace); + aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false); // mask - aclTensor* acl_src1_fp32_tensor = nullptr; - aclTensor* tmp_mask_tensor = nullptr; - ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool()); if (src1) { - const bool use_f16 = src1->type == GGML_TYPE_F16; - if (use_f16) { - // cast to fp32 - size_t n_bytes = ggml_nelements(src1) * sizeof(float_t); - size_t src1_fp32_nb[GGML_MAX_DIMS]; - src1_fp32_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1]; - } - src1_fp32_allocator.alloc(n_bytes); - void* src1_fp32_buffer = src1_fp32_allocator.get(); - acl_src1_fp32_tensor = ggml_cann_create_tensor( - src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne, - src1_fp32_nb, GGML_MAX_DIMS); - aclTensor* acl_src1 = ggml_cann_create_tensor(src1); - aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT); - ggml_cann_release_resources(ctx, acl_src1); - } else { - acl_src1_fp32_tensor = ggml_cann_create_tensor(src1); - } - - // broadcast the mask across rows, only use ne11 of ne01 in mask - if (src1->ne[1] != src0->ne[1]) { - // mask shape: [1,1,ne11,ne10] - int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1}; - size_t tmp_mask_nb[GGML_MAX_DIMS]; - tmp_mask_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1]; - } - tmp_mask_tensor = ggml_cann_create_tensor( - src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - } - - // alibi - const int n_head = src0->ne[2]; - const size_t src_nb0 = 
src0->nb[0]; - - n_bytes = ggml_nbytes(dst); - ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes); - void* output_buffer = output_allocator.get(); - aclTensor* alibi_output_tensor = ggml_cann_create_tensor( - output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne, - dst->nb, GGML_MAX_DIMS); - if (max_bias <= 0.0f) { - // slope = 1.0 - if (tmp_mask_tensor) { - aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor, - alibi_output_tensor); - } else { - aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor, - alibi_output_tensor); - } - } else { - // slope != 1.0 - if (tmp_mask_tensor) { - aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor, - alibi_output_tensor, n_head, src0->ne, src_nb0, - max_bias, dst); - } else { - aclnn_alibi(ctx, acl_input_mul_scale_tensor, - acl_src1_fp32_tensor, alibi_output_tensor, n_head, - src0->ne, src_nb0, max_bias, dst); - } - } - - // softmax - aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst); - ggml_cann_release_resources(ctx, alibi_output_tensor); - } else { - aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst); + aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias); } - - ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst, - acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor); + // softmax + aclnn_softmax(ctx, softmax_tensor, 3, acl_dst); + ggml_cann_release_resources(ctx, acl_src0, acl_dst, + acl_scale, softmax_tensor); } /** @@ -3210,21 +3095,21 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ // alibi const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; const int64_t n_head = src0->ne[2]; - const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); - float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); - float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); + const int n_head_log2 = 1u << (uint32_t)floor(log2(n_head)); + float m0 = powf(2.0f, -(maxBias) / n_head_log2); + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_head_log2); // init arange ggml_cann_pool_alloc arange_allocator(ctx.pool(), ne2_ne3 * faElemSize); void* tmp_arange_buffer = arange_allocator.get(); - // arange1: [1, ..., n_heads_log2_floor+1) + // arange1: [1, ..., n_head_log2+1) float start = 1; - float stop = n_heads_log2_floor + 1; + float stop = n_head_log2 + 1; float step = 1; - int64_t n_elements_arange = n_heads_log2_floor; + int64_t n_elements_arange = n_head_log2; - int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; + int64_t tmp_arange1_ne[] = {n_head_log2}; size_t tmp_arange1_nb[] = {faElemSize}; aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( tmp_arange_buffer, faDataType, faElemSize, @@ -3234,18 +3119,18 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); aclTensor* tmp_arange2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) + if (n_head_log2 < ne2_ne3) { + // arange2: [1, ..., 2 * (k - n_head_log2) + 1) start = 1; - stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; + stop = 2 * (ne2_ne3 - n_head_log2) + 1; step = 2; - n_elements_arange = ne2_ne3 - n_heads_log2_floor; - int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + n_elements_arange = ne2_ne3 - n_head_log2; + int64_t tmp_arange2_ne[] = {ne2_ne3 - n_head_log2}; size_t tmp_arange2_nb[] = {faElemSize}; aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( (char*)tmp_arange_buffer + - n_heads_log2_floor * 
faElemSize, + n_head_log2 * faElemSize, faDataType, faElemSize, tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, @@ -3256,7 +3141,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), ne2_ne3 * faElemSize); void* tmp_mk_base_buffer = mk_base_allocator.get(); - int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; + int64_t tmp_mk_base1_ne[] = {n_head_log2}; size_t tmp_mk_base1_nb[] = {faElemSize}; aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( tmp_mk_base_buffer, faDataType, faElemSize, @@ -3266,12 +3151,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); aclTensor* tmp_mk_base2_tensor = nullptr; - if (n_heads_log2_floor < ne2_ne3) { - int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + if (n_head_log2 < ne2_ne3) { + int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_head_log2}; size_t tmp_mk_base2_nb[] = {faElemSize}; aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( (char*)tmp_mk_base_buffer + - n_heads_log2_floor * faElemSize, + n_head_log2 * faElemSize, faDataType, faElemSize, tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cf575b367500a..2eada54d1be44 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2503,9 +2503,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if (op->src[2]) { return false; } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); + return true; case GGML_OP_FLASH_ATTN_EXT:{ // derived from [ggml-cuda.cu] if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d29779cd12b22..e72aa4b99d996 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5877,6 +5877,8 @@ static std::vector> make_test_cases_eval() { exponent <<= 1; } #endif + // SOFT_MAX(type=f32,ne=[16,16,1,3],mask=1,sinks=0,m_prec=f32,nr23=[3,1],scale=1.000000,max_bias=8.000000) + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 16, 1, 3}, 1, 0, GGML_TYPE_F32, {3, 1}, 1, 8)); for (bool mask : {false, true}) { for (bool sinks : {false, true}) { for (float max_bias : {0.0f, 8.0f}) { From 66a5b82f423c7bee2302c7330476a57f80bd5ebe Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 8 Aug 2025 08:18:42 +0000 Subject: [PATCH 2/6] fix fa --- ggml/src/ggml-cann/aclnn_ops.cpp | 110 +++++-------------------------- ggml/src/ggml-cann/ggml-cann.cpp | 5 -- 2 files changed, 15 insertions(+), 100 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index c0fc505ef3b47..1e812b68b64a9 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3093,104 +3093,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ // Compute the slope if needed. Derived from ggml_cann_softmax(). 
if(maxBias != 0.0f){ // alibi - const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; - const int64_t n_head = src0->ne[2]; - const int n_head_log2 = 1u << (uint32_t)floor(log2(n_head)); - float m0 = powf(2.0f, -(maxBias) / n_head_log2); - float m1 = powf(2.0f, -(maxBias / 2.0f) / n_head_log2); - // init arange - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - ne2_ne3 * faElemSize); - void* tmp_arange_buffer = arange_allocator.get(); - - // arange1: [1, ..., n_head_log2+1) - float start = 1; - float stop = n_head_log2 + 1; - float step = 1; - int64_t n_elements_arange = n_head_log2; - - int64_t tmp_arange1_ne[] = {n_head_log2}; - size_t tmp_arange1_nb[] = {faElemSize}; - aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, faDataType, faElemSize, - tmp_arange1_ne, tmp_arange1_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - - aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); - - aclTensor* tmp_arange2_tensor = nullptr; - if (n_head_log2 < ne2_ne3) { - // arange2: [1, ..., 2 * (k - n_head_log2) + 1) - start = 1; - stop = 2 * (ne2_ne3 - n_head_log2) + 1; - step = 2; - n_elements_arange = ne2_ne3 - n_head_log2; - int64_t tmp_arange2_ne[] = {ne2_ne3 - n_head_log2}; - size_t tmp_arange2_nb[] = {faElemSize}; - - aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( - (char*)tmp_arange_buffer + - n_head_log2 * faElemSize, - faDataType, faElemSize, - tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); - aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, - n_elements_arange); + const int64_t n_heads = src0->ne[2]; + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + void* slope_buffer = slope_allocator.get(); + aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias); + + int64_t slope_ne[] = {1, 1, n_heads, 1}; + size_t slope_nb[GGML_MAX_DIMS]; + slope_nb[0] = sizeof(float); + for(int i = 1;ine[2], src0->ne[3]}; - size_t tmp_mk_nb[GGML_MAX_DIMS]; - tmp_mk_nb[0] = faElemSize; - for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; - } - aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( - tmp_mk_base_buffer, faDataType, faElemSize, - tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); - - ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, - tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, - tmp_arange_tensor, tmp_mk_tensor); + ggml_cann_release_resources(ctx, slope_tensor); } } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 2eada54d1be44..9bb6b0a87c59b 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2530,11 +2530,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // DeepSeek MLA return false; } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - if (op->src[0]->ne[3] != 1) { - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); if(logitSoftcap != 0.0f) { From 44572327da6cd1d69f0c38f6eb4c8a9b2e328f25 Mon Sep 17 00:00:00 2001 From: hipudding Date: Mon, 11 Aug 2025 06:04:33 +0000 Subject: [PATCH 3/6] fix mask shape --- ggml/src/ggml-cann/aclnn_ops.cpp | 28 +++++++++++----------- ggml/src/ggml-cann/ggml-cann.cpp | 5 ++-- tests/test-backend-ops.cpp | 40 ++++++++++++++++---------------- 3 files changed, 37 insertions(+), 36 deletions(-) diff --git 
a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 1e812b68b64a9..dde7ead3b241c 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -812,7 +812,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_release_resources(ctx, src_trans_tensor); return; } else { - GGML_ABORT("Unsupport dst is not tontiguous."); + GGML_ABORT("Unsupport dst is not contiguous."); } } ggml_cann_release_resources(ctx, acl_src, acl_dst); @@ -1342,10 +1342,6 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu sizeof(float), ne, nb, 1); aclnn_arange(ctx, arange_tensor, start, stop, step, size); - ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); - float* arange_host = new float[size]; - aclrtMemcpy(arange_host, size * 4, arange_buffer, size* 4, ACL_MEMCPY_DEVICE_TO_HOST); - aclTensor* slope_tensor = ggml_cann_create_tensor( slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); @@ -1383,12 +1379,14 @@ static void aclnn_get_slope(ggml_backend_cann_context& ctx, int64_t n_head, void static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, ggml_tensor* dst, void* dst_ptr, float max_bias) { void* slope_buffer = nullptr; void* bias_buffer = nullptr; + + int64_t n_heads = dst->ne[2]; + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + slope_buffer = slope_allocator.get(); + ggml_cann_pool_alloc bias_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); + bias_buffer = bias_allocator.get(); + if (max_bias > 0.0f) { - int64_t n_heads = dst->ne[2]; - ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); - slope_buffer = slope_allocator.get(); - ggml_cann_pool_alloc bias_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); - bias_buffer = bias_allocator.get(); aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); } @@ -1400,10 +1398,12 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, g // broadcast the mask across rows int64_t mask_ne[] = {mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1}; size_t mask_nb[GGML_MAX_DIMS + 2]; - mask_nb[0] = ggml_element_size(mask); - for(int i = 1;inb[0]; + mask_nb[1] = mask->nb[1]; + mask_nb[2] = mask->nb[2]; + mask_nb[3] = mask->nb[2]; + mask_nb[4] = mask->nb[3]; + mask_nb[5] = mask->nb[3]; // ne2 and ne3 may be integer multiples of the mask. int64_t dst_ne[] = {dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3}; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 9bb6b0a87c59b..3d3520f195951 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // only support F32 and F16. 
return false; } - return true; + return ggml_is_contiguous(op); } break; case GGML_OP_CONT: { // TODO: support GGML_TYPE_BF16 @@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // value of paddingW should be at most half of kernelW return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); } - case GGML_OP_SUM: case GGML_OP_DUP: + return ggml_is_contiguous(op); + case GGML_OP_SUM: case GGML_OP_IM2COL: case GGML_OP_CONCAT: case GGML_OP_REPEAT: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e72aa4b99d996..82827668ee375 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5461,26 +5461,26 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true)); - for(uint32_t Cout : {1, 9}){ - for(uint32_t Cin : {1, 7}){ - for(uint32_t K : {1, 3, 1337}){ - for(uint32_t L : {1, 2, 13}){ - for(uint32_t s0: {1, 2, 3}){ - test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1)); - } - } - } - } - } - - test_cases.emplace_back(new test_conv_transpose_1d()); - test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); - test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); - test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1)); - test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1)); - test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1)); - test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); - test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); + // for(uint32_t Cout : {1, 9}){ + // for(uint32_t Cin : {1, 7}){ + // for(uint32_t K : {1, 3, 1337}){ + // for(uint32_t L : {1, 2, 13}){ + // for(uint32_t s0: {1, 2, 3}){ + // test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1)); + // } + // } + // } + // } + // } + + // test_cases.emplace_back(new test_conv_transpose_1d()); + // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); + // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); + // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1)); + // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1)); + // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1)); + // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); + // test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1)); test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2)); From 927987baa7ff54d824cc8a060c9ab394bb70c403 Mon Sep 17 00:00:00 2001 From: hipudding Date: Mon, 11 Aug 2025 06:24:06 +0000 Subject: [PATCH 4/6] format --- ggml/src/ggml-cann/aclnn_ops.cpp | 145 ++++++++++++++++--------------- tests/test-backend-ops.cpp | 40 ++++----- 2 files changed, 97 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index dde7ead3b241c..8428aad2d772b 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1330,60 +1330,68 @@ static void 
aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, } -static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, float m, int64_t size, float start, float stop, float step){ +static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, + float m, int64_t size, float start, float stop, float step){ int64_t ne[] = {size}; size_t nb[] = {sizeof(float)}; - ggml_cann_pool_alloc arange_allocator(ctx.pool(),size * sizeof(float)); - void* arange_buffer = arange_allocator.get(); - - aclTensor* arange_tensor = ggml_cann_create_tensor( - arange_buffer, ACL_FLOAT, - sizeof(float), ne, nb, 1); + ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float)); + void * arange_buffer = arange_allocator.get(); + + aclTensor * arange_tensor = ggml_cann_create_tensor( + arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); aclnn_arange(ctx, arange_tensor, start, stop, step, size); - aclTensor* slope_tensor = ggml_cann_create_tensor( - slope_buffer, ACL_FLOAT, - sizeof(float), ne, nb, 1); + aclTensor * slope_tensor = ggml_cann_create_tensor( + slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); - aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); + aclScalar * sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor); ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor); } -static void aclnn_get_slope(ggml_backend_cann_context& ctx, int64_t n_head, void* slope_buffer, float max_bias) { - const int n_head_log2 = 1u << (uint32_t)floor(log2(n_head)); +static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, + void * slope_buffer, float max_bias) { + const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - // const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + // const float slope = (max_bias > 0.0f) ? + // h < n_head_log2 ? + // powf(m0, h + 1) : + // powf(m1, 2*(h - n_head_log2) + 1) : + // 1.0f; // arange1 float start = 0 + 1; - float end = (n_head_log2 - 1) + 1; - float step = 1; + float end = (n_head_log2 - 1) + 1; + float step = 1; float count = n_head_log2; // end needs to be +1 because aclnn uses a left-closed, right-open interval. 
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step); if (n_head_log2 < n_head) { // arange2 - start = 2*(n_head_log2 - n_head_log2) + 1; - end = 2*((n_head - 1) - n_head_log2) + 1; - step = 2; + start = 2 * (n_head_log2 - n_head_log2) + 1; + end = 2 * ((n_head - 1) - n_head_log2) + 1; + step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner(ctx, (char*)slope_buffer + n_head_log2* sizeof(float), m1, count, start, end + 1, step); + aclnn_get_slope_inner( + ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), + m1, count, start, end + 1, step); } } -static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, ggml_tensor* dst, void* dst_ptr, float max_bias) { +static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, + ggml_tensor* dst, void* dst_ptr, float max_bias) { void* slope_buffer = nullptr; void* bias_buffer = nullptr; int64_t n_heads = dst->ne[2]; ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); slope_buffer = slope_allocator.get(); - ggml_cann_pool_alloc bias_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); + ggml_cann_pool_alloc bias_allocator( + ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); bias_buffer = bias_allocator.get(); if (max_bias > 0.0f) { @@ -1396,44 +1404,46 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, g int64_t nr3 = dst->ne[3] / mask->ne[3]; // broadcast the mask across rows - int64_t mask_ne[] = {mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1}; - size_t mask_nb[GGML_MAX_DIMS + 2]; - mask_nb[0] = mask->nb[0]; - mask_nb[1] = mask->nb[1]; - mask_nb[2] = mask->nb[2]; - mask_nb[3] = mask->nb[2]; - mask_nb[4] = mask->nb[3]; - mask_nb[5] = mask->nb[3]; - - // ne2 and ne3 may be integer multiples of the mask. 
- int64_t dst_ne[] = {dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3}; - size_t dst_nb[GGML_MAX_DIMS + 2]; - dst_nb[0] = ggml_element_size(dst); - for(int i = 1;ine[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 }; + size_t mask_nb[] = { + mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2], + mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] + }; + + int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 }; + size_t dst_nb[] = { + dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2], + dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] + }; // slope is a 1 dim tensor, slope.ne2 == dst.ne2 - int64_t slope_ne[] = {1, 1, mask->ne[2], nr2, 1, 1}; - size_t slope_nb[GGML_MAX_DIMS + 2]; + int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 }; + size_t slope_nb[GGML_MAX_DIMS + 2]; slope_nb[0] = sizeof(float); - for(int i = 1;itype), - ggml_type_size(dst->type), dst_ne, dst_nb, GGML_MAX_DIMS + 2); - + aclTensor * acl_slope = ggml_cann_create_tensor( + slope_buffer, ACL_FLOAT, sizeof(float), + slope_ne, slope_nb, GGML_MAX_DIMS + 2); + aclTensor * acl_mask = ggml_cann_create_tensor( + mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); + aclTensor * acl_dst = ggml_cann_create_tensor( + dst_ptr, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst_ne, dst_nb, + GGML_MAX_DIMS + 2); + if (max_bias > 0.0f) { - int64_t bias_ne[] = {mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1}; - size_t bias_nb[GGML_MAX_DIMS + 2]; + int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 }; + size_t bias_nb[GGML_MAX_DIMS + 2]; bias_nb[0] = sizeof(float); - for(int i = 1;isrc[0]; - ggml_tensor* src1 = dst->src[1]; // mask +void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; // mask - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src0 = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - float scale = 1.0f; + float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, (float*)dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // input mul scale - aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); + aclScalar * acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0)); void* src_tensor_buffer = src_tensor_allocator.get(); aclTensor* softmax_tensor = ggml_cann_create_tensor( - src_tensor_buffer, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), src0->ne, - src0->nb, GGML_MAX_DIMS); + src_tensor_buffer, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS); aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false); @@ -1496,8 +1506,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } // softmax aclnn_softmax(ctx, softmax_tensor, 3, acl_dst); - ggml_cann_release_resources(ctx, acl_src0, acl_dst, - acl_scale, softmax_tensor); + ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor); } /** diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 
82827668ee375..e72aa4b99d996 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5461,26 +5461,26 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true)); - // for(uint32_t Cout : {1, 9}){ - // for(uint32_t Cin : {1, 7}){ - // for(uint32_t K : {1, 3, 1337}){ - // for(uint32_t L : {1, 2, 13}){ - // for(uint32_t s0: {1, 2, 3}){ - // test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1)); - // } - // } - // } - // } - // } - - // test_cases.emplace_back(new test_conv_transpose_1d()); - // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); - // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); - // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1)); - // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1)); - // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1)); - // test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); - // test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); + for(uint32_t Cout : {1, 9}){ + for(uint32_t Cin : {1, 7}){ + for(uint32_t K : {1, 3, 1337}){ + for(uint32_t L : {1, 2, 13}){ + for(uint32_t s0: {1, 2, 3}){ + test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1)); + } + } + } + } + } + + test_cases.emplace_back(new test_conv_transpose_1d()); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1)); test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2)); From 50074898a0845192655c6c1b552554de4e8f5967 Mon Sep 17 00:00:00 2001 From: hipudding Date: Mon, 11 Aug 2025 06:38:05 +0000 Subject: [PATCH 5/6] add comments --- ggml/src/ggml-cann/aclnn_ops.cpp | 113 ++++++++++++++++++++++++------- tests/test-backend-ops.cpp | 2 - 2 files changed, 89 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 8428aad2d772b..8bbba255cb0db 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1329,30 +1329,74 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp); } - +/** + * @brief Generate a range of values and apply a scalar base exponentiation. + * + * This function creates an evenly spaced sequence from `start` to `stop` (exclusive), + * with step size `step`, stores it in a temporary buffer, and then computes: + * + * @f[ + * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size + * @f] + * + * The results are written to the provided @p slope_buffer. 
+ * + * @param ctx CANN backend context for memory allocation and operator execution. + * @param slope_buffer Pointer to the output buffer (float array) for the computed slope values. + * @param m Scalar base for the exponentiation. + * @param size Number of elements in the generated sequence. + * @param start Starting exponent offset. + * @param stop Stopping exponent offset (exclusive). + * @param step Step size for the exponent increment. + */ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, float m, int64_t size, float start, float stop, float step){ int64_t ne[] = {size}; size_t nb[] = {sizeof(float)}; ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float)); - void * arange_buffer = arange_allocator.get(); + void* arange_buffer = arange_allocator.get(); - aclTensor * arange_tensor = ggml_cann_create_tensor( + aclTensor* arange_tensor = ggml_cann_create_tensor( arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); aclnn_arange(ctx, arange_tensor, start, stop, step, size); - aclTensor * slope_tensor = ggml_cann_create_tensor( + aclTensor* slope_tensor = ggml_cann_create_tensor( slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); - aclScalar * sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); + aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor); ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor); } +/** + * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters. + * + * This function generates slope values for each attention head according to the ALiBi + * (Attention with Linear Biases) method. It splits the computation into two ranges depending + * on whether the head index is less than @p n_head_log2 or not, and uses different base values + * (`m0` and `m1`) for the exponentiation. + * + * @f[ + * slope[h] = + * \begin{cases} + * m_0^{(h + 1)}, & h < n\_head\_log2 \\ + * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2 + * \end{cases} + * \quad , \quad \text{if } max\_bias > 0 + * @f] + * + * If @p max_bias <= 0, all slope values are set to 1.0. + * + * @param ctx CANN backend context for memory allocation and operator execution. + * @param n_head Total number of attention heads. + * @param slope_buffer Pointer to the output buffer (float array) for storing slopes. + * @param max_bias Maximum bias value for slope computation. + * +*/ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, - void * slope_buffer, float max_bias) { + void* slope_buffer, float max_bias) { const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); float m0 = powf(2.0f, -(max_bias) / n_head_log2); @@ -1382,24 +1426,43 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, } } +/** + * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask. + * + * This function computes the ALiBi slopes for each attention head (if max_bias > 0), + * multiplies them with the attention mask to produce bias tensors, and adds these biases + * to the destination tensor (@p dst). + * + * The function performs necessary broadcasting of the mask and slope tensors to match + * the shape of the destination tensor, then applies element-wise multiplication and addition + * using CANN operators. + * + * @param ctx CANN backend context for memory management and operator execution. 
+ * @param mask Input attention mask tensor, assumed to be contiguous. + * @param dst Destination tensor to which ALiBi biases will be added. + * @param dst_ptr Pointer to the memory of the destination tensor. + * @param max_bias Maximum bias value controlling the slope scaling. + * + * @note + * - Write data into dst_ptr using only the shape information of the dst tensor. + * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting. + */ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, ggml_tensor* dst, void* dst_ptr, float max_bias) { void* slope_buffer = nullptr; void* bias_buffer = nullptr; - int64_t n_heads = dst->ne[2]; - ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); - slope_buffer = slope_allocator.get(); - ggml_cann_pool_alloc bias_allocator( - ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); - bias_buffer = bias_allocator.get(); - if (max_bias > 0.0f) { + int64_t n_heads = dst->ne[2]; + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + slope_buffer = slope_allocator.get(); + ggml_cann_pool_alloc bias_allocator( + ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); + bias_buffer = bias_allocator.get(); aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); } // broadcast for mask, slop and dst; - GGML_ASSERT(ggml_is_contiguous(mask)); int64_t nr2 = dst->ne[2] / mask->ne[2]; int64_t nr3 = dst->ne[3] / mask->ne[3]; @@ -1424,12 +1487,14 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1]; } - aclTensor * acl_slope = ggml_cann_create_tensor( + aclTensor* acl_slope = ggml_cann_create_tensor( slope_buffer, ACL_FLOAT, sizeof(float), slope_ne, slope_nb, GGML_MAX_DIMS + 2); - aclTensor * acl_mask = ggml_cann_create_tensor( + aclTensor* acl_mask = ggml_cann_create_tensor( mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); - aclTensor * acl_dst = ggml_cann_create_tensor( + + // write data into dst_ptr using only the shape information of the dst tensor. + aclTensor* acl_dst = ggml_cann_create_tensor( dst_ptr, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), dst_ne, dst_nb, GGML_MAX_DIMS + 2); @@ -1441,7 +1506,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, for (int i = 1; i < GGML_MAX_DIMS + 2; i++) { bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1]; } - aclTensor * bias_tensor = ggml_cann_create_tensor( + aclTensor* bias_tensor = ggml_cann_create_tensor( bias_buffer, ACL_FLOAT, sizeof(float), bias_ne, bias_nb, GGML_MAX_DIMS + 2); @@ -1473,16 +1538,16 @@ void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) { * stored. 
*/ static void aclnn_softmax(ggml_backend_cann_context & ctx, - aclTensor * acl_src, int64_t dim, aclTensor * acl_dst) { + aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) { GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); } void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; - ggml_tensor * src1 = dst->src[1]; // mask + ggml_tensor* src0 = dst->src[0]; + ggml_tensor* src1 = dst->src[1]; // mask - aclTensor * acl_src0 = ggml_cann_create_tensor(src0); - aclTensor * acl_dst = ggml_cann_create_tensor(dst); + aclTensor* acl_src0 = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); float scale = 1.0f; float max_bias = 0.0f; @@ -1491,7 +1556,7 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // input mul scale - aclScalar * acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); + aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0)); void* src_tensor_buffer = src_tensor_allocator.get(); aclTensor* softmax_tensor = ggml_cann_create_tensor( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e72aa4b99d996..d29779cd12b22 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5877,8 +5877,6 @@ static std::vector> make_test_cases_eval() { exponent <<= 1; } #endif - // SOFT_MAX(type=f32,ne=[16,16,1,3],mask=1,sinks=0,m_prec=f32,nr23=[3,1],scale=1.000000,max_bias=8.000000) - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 16, 1, 3}, 1, 0, GGML_TYPE_F32, {3, 1}, 1, 8)); for (bool mask : {false, true}) { for (bool sinks : {false, true}) { for (float max_bias : {0.0f, 8.0f}) { From e4abd0795e603ceffb5f24fc66010e796e728fc0 Mon Sep 17 00:00:00 2001 From: hipudding Date: Mon, 11 Aug 2025 12:18:10 +0000 Subject: [PATCH 6/6] Remove whitespace --- ggml/src/ggml-cann/aclnn_ops.cpp | 36 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 8bbba255cb0db..0b409ce87d2ab 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1349,7 +1349,7 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, * @param stop Stopping exponent offset (exclusive). * @param step Step size for the exponent increment. */ -static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, +static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, float m, int64_t size, float start, float stop, float step){ int64_t ne[] = {size}; size_t nb[] = {sizeof(float)}; @@ -1395,17 +1395,17 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu * @param max_bias Maximum bias value for slope computation. * */ -static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, +static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, void* slope_buffer, float max_bias) { const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - // const float slope = (max_bias > 0.0f) ? - // h < n_head_log2 ? - // powf(m0, h + 1) : - // powf(m1, 2*(h - n_head_log2) + 1) : + // const float slope = (max_bias > 0.0f) ? 
+ // h < n_head_log2 ? + // powf(m0, h + 1) : + // powf(m1, 2*(h - n_head_log2) + 1) : // 1.0f; // arange1 float start = 0 + 1; @@ -1421,7 +1421,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, step = 2; count = n_head - n_head_log2; aclnn_get_slope_inner( - ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), + ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step); } } @@ -1447,7 +1447,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, * - Write data into dst_ptr using only the shape information of the dst tensor. * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting. */ -static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, +static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, ggml_tensor* dst, void* dst_ptr, float max_bias) { void* slope_buffer = nullptr; void* bias_buffer = nullptr; @@ -1468,15 +1468,15 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, // broadcast the mask across rows int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 }; - size_t mask_nb[] = { + size_t mask_nb[] = { mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2], - mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] + mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] }; int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 }; - size_t dst_nb[] = { + size_t dst_nb[] = { dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2], - dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] + dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] }; // slope is a 1 dim tensor, slope.ne2 == dst.ne2 @@ -1488,15 +1488,15 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, } aclTensor* acl_slope = ggml_cann_create_tensor( - slope_buffer, ACL_FLOAT, sizeof(float), + slope_buffer, ACL_FLOAT, sizeof(float), slope_ne, slope_nb, GGML_MAX_DIMS + 2); aclTensor* acl_mask = ggml_cann_create_tensor( mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); - + // write data into dst_ptr using only the shape information of the dst tensor. aclTensor* acl_dst = ggml_cann_create_tensor( - dst_ptr, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst_ne, dst_nb, + dst_ptr, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst_ne, dst_nb, GGML_MAX_DIMS + 2); if (max_bias > 0.0f) { @@ -1507,7 +1507,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1]; } aclTensor* bias_tensor = ggml_cann_create_tensor( - bias_buffer, ACL_FLOAT, sizeof(float), + bias_buffer, ACL_FLOAT, sizeof(float), bias_ne, bias_nb, GGML_MAX_DIMS + 2); aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor); @@ -1537,7 +1537,7 @@ void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) { * @param acl_dst The destination tensor where the softmax results will be * stored. */ -static void aclnn_softmax(ggml_backend_cann_context & ctx, +static void aclnn_softmax(ggml_backend_cann_context & ctx, aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) { GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); }
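For reference while reviewing the series: after these patches the CANN softmax path computes dst = softmax(src0 * scale + mask * slope) over each row (ne[0]), with per-head ALiBi slopes built by aclnn_get_slope() according to the formula quoted in its comments. Below is a minimal host-side C++ sketch of that reference behaviour; it is not part of the patches, and the function names ref_alibi_slopes and ref_soft_max_row are hypothetical, introduced only for this illustration.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Per-head ALiBi slopes (host-side reference, not part of the patch series):
// heads below n_head_log2 use base m0 with exponent h + 1, the remaining heads
// use base m1 with odd exponents 1, 3, 5, ...; all slopes are 1.0 when
// max_bias <= 0 (ALiBi disabled).
std::vector<float> ref_alibi_slopes(int64_t n_head, float max_bias) {
    std::vector<float> slopes((size_t) n_head, 1.0f);
    if (max_bias <= 0.0f) {
        return slopes;
    }
    const int64_t n_head_log2 = (int64_t) 1 << (int64_t) std::floor(std::log2((double) n_head));
    const float m0 = std::pow(2.0f, -max_bias / (float) n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / (float) n_head_log2);
    for (int64_t h = 0; h < n_head; ++h) {
        slopes[(size_t) h] = (h < n_head_log2)
            ? std::pow(m0, (float) (h + 1))
            : std::pow(m1, (float) (2 * (h - n_head_log2) + 1));
    }
    return slopes;
}

// One output row: out[i] = softmax(scale * x[i] + slope * mask[i]), computed with
// the usual max-subtraction for numerical stability; mask may be null (no bias).
void ref_soft_max_row(const float * x, const float * mask, float * out,
                      int64_t n, float scale, float slope) {
    float maxv = -INFINITY;
    for (int64_t i = 0; i < n; ++i) {
        out[i] = scale * x[i] + (mask ? slope * mask[i] : 0.0f);
        maxv   = std::max(maxv, out[i]);
    }
    float sum = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        out[i] = std::exp(out[i] - maxv);
        sum   += out[i];
    }
    for (int64_t i = 0; i < n; ++i) {
        out[i] /= sum;
    }
}

Comparing each row of the device output against ref_soft_max_row, with slopes taken from ref_alibi_slopes for the corresponding head, gives a quick sanity check for the broadcast cases the updated SOFT_MAX tests exercise.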