Skip to content

Commit 93c85c9

Browse files
zhaifyluotao1
authored andcommitted
Implement FusedEmbeddingSeqPoolGradKernel with cblas_saxpy (#19770)
* Implement the operator with sprase matrix multiply * Update the URL of mklml library. test=develop * Disable MKLML implematation when using no-linux. test=develop * optimize bp with mkl sparse matrix test=develop * tmp add fused_emb_seq layer * Add the support of padding_idx attribute. test=develop * add padding_idx support test=develop * implement grad refer lego test=develop
1 parent 2729c17 commit 93c85c9

File tree

3 files changed

+77
-37
lines changed

3 files changed

+77
-37
lines changed

paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc

+6
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
7878
"are supported, sum computes the weighted sum of the "
7979
"embedding results for each row.")
8080
.SetDefault("sum");
81+
AddAttr<int64_t>("padding_idx",
82+
"(int64, default -1) "
83+
"If the value is -1, it makes no effect to lookup. "
84+
"Otherwise the given value indicates padding the output "
85+
"with zeros whenever lookup encounters it in Ids.")
86+
.SetDefault(kNoPadding);
8187
// NOTE(minqiyang): grad_inplace is an temporal attribute,
8288
// please do NOT set this attribute in python layer.
8389
AddAttr<bool>("grad_inplace",

paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h

+31-26
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@ using LoDTensor = framework::LoDTensor;
3333
using SelectedRows = framework::SelectedRows;
3434
using DDim = framework::DDim;
3535

36+
constexpr int64_t kNoPadding = -1;
37+
3638
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
37-
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
39+
!defined(__OSX__)
3840
template <typename T>
3941
void prepare_csr_data(const std::vector<uint64_t> &offset,
4042
const int64_t *ids_data, const size_t idx_width,
41-
T *csr_vals, int *csr_colmuns, int *csr_row_idx) {
43+
T *csr_vals, int *csr_colmuns, int *csr_row_idx,
44+
int64_t padding_idx = kNoPadding) {
4245
int val_idx = 0;
4346
int row_idx = 0;
4447
csr_row_idx[0] = 0;
@@ -52,9 +55,11 @@ void prepare_csr_data(const std::vector<uint64_t> &offset,
5255

5356
// construct a map for creating csr
5457
for (size_t j = offset[i]; j < offset[i + 1]; ++j) {
55-
unsigned int word_idx =
56-
static_cast<unsigned int>(ids_data[idx + j * idx_width]);
57-
++ids_map[word_idx];
58+
auto ids_value = ids_data[idx + j * idx_width];
59+
if (ids_value != padding_idx) {
60+
unsigned int word_idx = static_cast<unsigned int>(ids_value);
61+
++ids_map[word_idx];
62+
}
5863
}
5964

6065
VLOG(4) << "====sequence %d====" << i;
@@ -124,16 +129,17 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
124129
FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
125130
const auto &ids_lod = ids_t->lod();
126131
// in run time, the LoD of ids must be 1
127-
PADDLE_ENFORCE(ids_lod.size(), 1UL,
128-
"The LoD level of Input(Ids) must be 1");
132+
PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
133+
"The LoD level of Input(Ids) must be 1");
129134
int64_t batch_size = ids_lod[0].size() - 1;
130135
// in run time, the shape from Ids -> output
131136
// should be [seq_length, 1] -> [batch_size, last_dim]
132137
output_t->Resize({batch_size, last_dim});
133138

134139
if (combiner_type == "sum") {
135140
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
136-
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
141+
!defined(__OSX__)
142+
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
137143
auto output = output_t->mutable_data<T>(context.GetPlace());
138144
int64_t table_height = table_var->dims()[0];
139145
int64_t table_width = table_var->dims()[1];
@@ -151,7 +157,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
151157
auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
152158
auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
153159
prepare_csr_data<T>(offset, ids_t->data<int64_t>(), idx_width, csr_vals,
154-
csr_colmuns, csr_row_idx);
160+
csr_colmuns, csr_row_idx, padding_idx);
155161

156162
const char transa = 'N';
157163
const T alpha = 1.0;
@@ -226,18 +232,19 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
226232
}
227233
} else {
228234
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
229-
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
235+
!defined(__OSX__)
230236
auto *ids = context.Input<LoDTensor>("Ids");
231237
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
232238
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
239+
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
233240

234241
d_table->Resize(table_dim);
235242
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
236243
memset(d_table_data, 0, d_table->numel() * sizeof(T));
237244

238245
const auto &ids_lod = ids->lod();
239-
PADDLE_ENFORCE(ids_lod.size(), 1UL,
240-
"The LoD level of Input(Ids) must be 1");
246+
PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
247+
"The LoD level of Input(Ids) must be 1");
241248
const std::vector<uint64_t> offset = ids_lod[0];
242249
auto len = ids->numel();
243250
int idx_width = len / offset.back();
@@ -251,23 +258,21 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
251258
auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
252259
auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
253260
prepare_csr_data<T>(offset, ids->data<int64_t>(), idx_width, csr_vals,
254-
csr_colmuns, csr_row_idx);
261+
csr_colmuns, csr_row_idx, padding_idx);
255262

256263
auto *d_output_data = d_output->data<T>();
257-
const char transa = 'T';
258-
const T alpha = 1.0;
259-
const T beta = 0.0;
260-
const char matdescra[] = {'G', 'L', 'N', 'C'};
261-
262-
const int m = batch_size * idx_width;
263-
const int n = table_dim[1];
264-
const int k = table_dim[1];
265-
266264
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
267-
blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals,
268-
(const int *)csr_colmuns, (const int *)csr_row_idx,
269-
(const int *)csr_row_idx + 1, d_output_data, &n, &beta,
270-
d_table_data, &n);
265+
int width = static_cast<int>(table_dim[1]);
266+
int num_seq = batch_size * idx_width;
267+
LOG(INFO) << "num seq = " << num_seq << " width = " << width;
268+
for (int i = 0; i < num_seq; ++i) {
269+
for (int j = csr_row_idx[i]; j < csr_row_idx[i + 1]; ++j) {
270+
unsigned int word_idx = csr_colmuns[j];
271+
T val = csr_vals[j];
272+
blas.AXPY(width, val, d_output_data + i * width,
273+
d_table_data + word_idx * width);
274+
}
275+
}
271276
#else
272277
LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
273278
#endif

python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py

+40-11
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,67 @@
2222
import paddle.fluid as fluid
2323
from paddle.fluid.op import Operator
2424
import paddle.compat as cpt
25+
import paddle.version as ver
2526

2627

2728
class TestFusedEmbeddingSeqPoolOp(OpTest):
2829
def setUp(self):
2930
self.op_type = "fused_embedding_seq_pool"
3031
self.emb_size = 2
31-
table = np.random.random((17, self.emb_size)).astype("float32")
32-
ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]],
33-
[[16], [1]]]).astype("int64")
34-
merged_ids = np.array([4, 2, 16]).astype("int64")
35-
ids_expand = np.expand_dims(ids, axis=1)
32+
self.table = np.random.random((17, self.emb_size)).astype("float32")
33+
self.ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]],
34+
[[16], [1]]]).astype("int64")
35+
ids_expand = np.expand_dims(self.ids, axis=1)
3636
self.lod = [[3, 1]]
3737
self.attrs = {'is_sparse': True}
38-
self.inputs = {'W': table, 'Ids': (ids_expand, self.lod)}
38+
self.inputs = {'W': self.table, 'Ids': (ids_expand, self.lod)}
3939
self.outputs = {
4040
'Out': np.reshape(
4141
np.array([
42-
table[[4, 3]] + table[[4, 3]] + table[[2, 1]],
43-
table[[16, 1]]
42+
self.table[[4, 3]] + self.table[[4, 3]] +
43+
self.table[[2, 1]], self.table[[16, 1]]
4444
]), [len(self.lod[0]), 2 * self.emb_size])
4545
}
4646

4747
def test_check_output(self):
4848
self.check_output()
4949

5050
def test_check_grad(self):
51-
if fluid.core.is_compiled_with_mkldnn(
52-
) and not fluid.core.is_compiled_with_cuda(
53-
) and 'Linux' in platform.platform():
51+
if ver.mkl() == "ON" and 'Linux' in platform.platform():
5452
self.attrs = {'is_sparse': False}
5553
self.check_grad(['W'], 'Out', no_grad_set=('Ids'))
5654

5755

56+
class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):
57+
def test_check_output(self):
58+
if ver.mkl() == "ON" and 'Linux' in platform.platform():
59+
ids = np.squeeze(self.ids, axis=2)
60+
padding_idx = np.random.choice(ids.flatten(), 1)[0]
61+
output = list()
62+
index = 0
63+
for count in self.lod[0]:
64+
arr = ids[index:count + index]
65+
out = np.reshape(self.table[arr.flatten()],
66+
[arr.shape[0], arr.shape[1], self.emb_size])
67+
idx = np.argwhere(arr == padding_idx)
68+
for item in idx:
69+
out[item[0], item[1], :] = np.zeros(self.emb_size)
70+
output.append(np.sum(out, 0))
71+
index += count
72+
self.outputs = {
73+
'Out': np.reshape(
74+
np.array(output), [len(self.lod[0]), 2 * self.emb_size])
75+
}
76+
self.attrs = {'padding_idx': int(padding_idx)}
77+
self.check_output()
78+
79+
def test_check_grad(self):
80+
if ver.mkl() == "ON" and 'Linux' in platform.platform():
81+
ids = np.squeeze(self.ids, axis=2)
82+
padding_idx = np.random.choice(ids.flatten(), 1)[0]
83+
self.attrs = {'padding_idx': int(padding_idx), 'is_sparse': False}
84+
self.check_grad(['W'], 'Out', no_grad_set=('Ids'))
85+
86+
5887
if __name__ == "__main__":
5988
unittest.main()

0 commit comments

Comments
 (0)