@@ -49,8 +49,7 @@ set_lstm_params_region(LstmParam<OpTensor>& param, int word_size) {
49
49
50
50
for (int i = 0 ; i < _cudnn_lstm_weights_layernum; i++) {
51
51
ParamsRegion& region = _inner_weight_region[i];
52
- // get_sub_tensor(cudnnW[i], (Op_dtype*) region._offset, region._size/hidden_size/4, hidden_size, 4*hidden_size, cuda_stream);
53
- get_sub_tensor<Op_dtype>(cudnnW[i], (Op_dtype*) region._offset , region._size /hidden_size/4 , hidden_size, 4 *hidden_size);
52
+ get_sub_tensor<Op_dtype>(cudnnW[i], (Op_dtype*) region._offset , region._size /(sizeof (Op_dtype) * hidden_size), hidden_size, 4 *hidden_size, cuda_stream);
54
53
}
55
54
56
55
for (int i = 0 ; i < _cudnn_lstm_weights_layernum; i++) {
@@ -63,18 +62,7 @@ set_lstm_params_region(LstmParam<OpTensor>& param, int word_size) {
63
62
CUDA_CHECK (cudaMemsetAsync ((void *)(region_b._offset ), 0 , region_b._size , cuda_stream));
64
63
}
65
64
}
66
- cudaDeviceSynchronize ();
67
65
}
68
- int region_id = 0 ;
69
- for (auto region : _inner_weight_region) {
70
- char buf[100 ];
71
- sprintf (buf, " ./lstm_%d.txt" , region_id);
72
- record_dev_tensorfile<NV>((Op_dtype*)region._offset , region._size /4 , buf);
73
- region_id++;
74
- }
75
- cudaDeviceSynchronize ();
76
- record_dev_tensorfile<NV>(param.weight ()->data (), param.weight ()->valid_size (), " lstm_param_weight.txt" );
77
- cudaDeviceSynchronize ();
78
66
}
79
67
80
68
template <>
@@ -183,7 +171,7 @@ create(const std::vector<DataTensor*>& inputs,
183
171
_y_desc.reset (new cudnn::TensorDescriptors<DataDtype>(
184
172
offset_after_sort,
185
173
{batch_size, _hidden_size * lstm_param.num_direction , 1 },
186
- {_hidden_size * lstm_param.num_direction , 1 , 1 }));
174
+ {_hiden_size * lstm_param.num_direction , 1 , 1 }));
187
175
188
176
Shape in_dim = inputs[0 ]->valid_shape ();
189
177
Shape in_stride = inputs[0 ]->get_stride ();
@@ -227,7 +215,7 @@ dispatch(const std::vector<DataTensor*>& inputs,
227
215
if (inputs.size () == 2 ) {
228
216
in_hidden_data = inputs[1 ]->data ();
229
217
}
230
- bool isHW2Seq = inputs[0 ]->get_seq_offset ().size () > 2 ;
218
+ bool isHW2Seq = inputs[0 ]->get_seq_offset ().size () > 2 || param. is_reverse ;
231
219
232
220
if (isHW2Seq) {
233
221
_temp_tensor_in.reshape (inputs[0 ]->valid_shape ());
0 commit comments