Optimize the kernel implementation of layernorm with openmp #20895

Merged

Changes from all commits
218 changes: 114 additions & 104 deletions paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
@@ -26,131 +26,141 @@ namespace intrinsic {
Updated implementation (after all commits):

void LayerNorm(float* x, float* out, float* mean, float* var,
               const float* scale, const float* bias, int height,
               const float epsilon, int right) {
  int block = YMM_FLOAT_BLOCK;
  const int rest = right % block;
  const int end = right - rest;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel
  {
#endif
    __m256 sum;
    __m256 mean_vec, var_vec;
    __m128 hi, lo;
    __m256 tmp;
    size_t offset;
    size_t j;
    __m256 reverse_num_vec =
        _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
    __m256 epsilon_vec = _mm256_set1_ps(epsilon);
    int rest_mask =
        ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
    __m256i mask_vec = _mm256_set_epi32(
        rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
        rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
        rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
        rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
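    // Editor's note (comment added for clarity, not in the original source):
    // rest_mask enables the top `rest` lanes of an 8-float YMM block. For
    // example, with block = 8 and right = 10: rest = 2, (~0U) >> (32 - 6)
    // is 0x3f, whose complement masked to a byte is 0xc0 -- bits 0x80 and
    // 0x40, i.e. lanes 7 and 6 of mask_vec. The tail block is loaded at
    // j = offset + right - block, so its top `rest` lanes are exactly the
    // elements the main blocked loop has not yet processed.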
Review thread on the per-thread setup:

Contributor: Could this be split into a #pragma omp parallel part and a #pragma omp for part, so that __m256i mask_vec only needs to be set up once?

Contributor (author): Done, this has been fixed.

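The structure the reviewer asks for keeps a single parallel region and shares the row loop inside it, so per-thread setup runs once per thread rather than once per row. A minimal standalone sketch of that pattern (illustrative only; the function name and body are mine, not code from this PR; compile with OpenMP enabled, e.g. -fopenmp):

void row_scale(float* data, int height, int width, float factor) {
#pragma omp parallel
  {
    // Per-thread setup: runs once per thread, not once per row. In the
    // patch this is where reverse_num_vec, epsilon_vec and mask_vec live.
    const float f = factor;
#pragma omp for
    for (int i = 0; i < height; ++i) {  // rows are divided among threads
      for (int j = 0; j < width; ++j) {
        data[i * width + j] *= f;
      }
    }
  }  // threads join here
}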

#ifdef PADDLE_WITH_MKLML
#pragma omp for
#endif
    for (int i = 0; i < height; ++i) {
      offset = i * right;

      /* get mean */
      sum = _mm256_setzero_ps();
      for (j = offset; j < end + offset; j += block) {
        sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
      }
      if (rest != 0) {
        j = offset + right - block;
        tmp = _mm256_loadu_ps((const float*)x + j);
        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
                               *(__m256*)&mask_vec);  // NOLINT
        sum = _mm256_add_ps(sum, tmp);
      }
      hi = _mm256_extractf128_ps(sum, 1);
      lo = _mm256_extractf128_ps(sum, 0);
      sum = _mm256_add_ps(
          sum, _mm256_insertf128_ps(
                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
      sum = _mm256_hadd_ps(sum, sum);
      sum = _mm256_hadd_ps(sum, sum);
      mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
      mean[i] = *reinterpret_cast<float*>(&mean_vec);
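      // Editor's note (comment added for clarity, not in the original):
      // the extract/insert/hadd sequence above is a horizontal sum. Adding
      // the half-swapped copy makes both 128-bit halves of sum hold
      // lo + hi, and the two hadd passes fold the remaining four lanes,
      // so every lane ends up holding the full row total; multiplying by
      // reverse_num_vec (a broadcast of 1.0f / right) yields the mean,
      // which is read back from lane 0.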

      /* get variance */
      sum = _mm256_setzero_ps();
      for (j = offset; j < end + offset; j += block) {
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
        tmp = _mm256_mul_ps(tmp, tmp);
        sum = _mm256_add_ps(sum, tmp);
      }
      if (rest != 0) {
        j = offset + right - block;
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
        tmp = _mm256_mul_ps(tmp, tmp);
        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
                               *(__m256*)&mask_vec);  // NOLINT
        sum = _mm256_add_ps(sum, tmp);
      }
      hi = _mm256_extractf128_ps(sum, 1);
      lo = _mm256_extractf128_ps(sum, 0);
      sum = _mm256_add_ps(
          sum, _mm256_insertf128_ps(
                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
      sum = _mm256_hadd_ps(sum, sum);
      sum = _mm256_hadd_ps(sum, sum);
      var_vec = _mm256_mul_ps(sum, reverse_num_vec);
      var[i] = *reinterpret_cast<float*>(&var_vec);

      /* get x_norm and calculate output */
      for (j = offset; j < end + offset; j += block) {
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
        tmp = _mm256_div_ps(
            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
      }
      if (rest != 0) {
        j = offset + right - block;
        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
        tmp = _mm256_div_ps(
            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
      }

      if (scale) {
        if (rest != 0) {
          j = offset + right - block;
          tmp = _mm256_loadu_ps((const float*)out + j);
        }
        for (j = offset; j < end + offset; j += block) {
          _mm256_storeu_ps(
              reinterpret_cast<float*>(out) + j,
              _mm256_mul_ps(
                  _mm256_loadu_ps((const float*)out + j),
                  _mm256_loadu_ps((const float*)scale + j - offset)));
        }
        if (rest != 0) {
          j = offset + right - block;
          _mm256_storeu_ps(
              reinterpret_cast<float*>(out) + j,
              _mm256_mul_ps(
                  tmp, _mm256_loadu_ps((const float*)scale + j - offset)));
        }
      }

      if (bias) {
        if (rest != 0) {
          j = offset + right - block;
          tmp = _mm256_loadu_ps((const float*)out + j);
        }
        for (j = offset; j < end + offset; j += block) {
          _mm256_storeu_ps(
              reinterpret_cast<float*>(out) + j,
              _mm256_add_ps(
                  _mm256_loadu_ps((const float*)out + j),
                  _mm256_loadu_ps((const float*)bias + j - offset)));
        }
        if (rest != 0) {
          j = offset + right - block;
          _mm256_storeu_ps(
              reinterpret_cast<float*>(out) + j,
              _mm256_add_ps(
                  tmp, _mm256_loadu_ps((const float*)bias + j - offset)));
        }
      }
    }
#ifdef PADDLE_WITH_MKLML
  }
#endif
}
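Two details worth noting. First, in the scale and bias blocks the tail block at out + offset + right - block overlaps the last full block written by the main loop, so the kernel snapshots it into tmp before the loop and applies scale/bias to that snapshot, ensuring the overlapping lanes are scaled or shifted exactly once. Second, for readers who want the semantics without the intrinsics, here is a scalar sketch of what the kernel computes for one row (my own reconstruction from the code above, not code from the PR; the function name is hypothetical):

#include <cmath>

// One row of length `right`; mean and var receive the row statistics.
void LayerNormRowRef(const float* x, float* out, float* mean, float* var,
                     const float* scale, const float* bias, float epsilon,
                     int right) {
  float m = 0.f;
  for (int j = 0; j < right; ++j) m += x[j];
  m /= right;  // row mean
  float v = 0.f;
  for (int j = 0; j < right; ++j) v += (x[j] - m) * (x[j] - m);
  v /= right;  // biased variance, matching reverse_num_vec = 1 / right
  *mean = m;
  *var = v;
  const float inv_std = 1.f / std::sqrt(v + epsilon);
  for (int j = 0; j < right; ++j) {
    float y = (x[j] - m) * inv_std;  // x_norm
    if (scale) y *= scale[j];        // optional elementwise scale
    if (bias) y += bias[j];          // optional elementwise bias
    out[j] = y;
  }
}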

bool LayerNormKernel::CanBeUsed(const int& d) const {