Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit dda804d

Browse files
authored
Merge pull request #268 from qq332982511/developing
compile developing
2 parents 7ad4046 + 3c87f43 commit dda804d

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

saber/funcs/impl/cuda/vender_lstm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ create(const std::vector<DataTensor*>& inputs,
171171
_y_desc.reset(new cudnn::TensorDescriptors<DataDtype>(
172172
offset_after_sort,
173173
{batch_size, _hidden_size * lstm_param.num_direction, 1},
174-
{_hiden_size * lstm_param.num_direction, 1, 1}));
174+
{_hidden_size * lstm_param.num_direction, 1, 1}));
175175

176176
Shape in_dim = inputs[0]->valid_shape();
177177
Shape in_stride = inputs[0]->get_stride();

test/saber/cuda/test_saber_func_attension_lstm.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
#include "saber_types.h"
99
#include "saber/funcs/timer.h"
1010
#include "saber/funcs/impl/cuda/saber_attension_lstm.h"
11+
#ifdef USE_X86_PLACE
1112
#include "saber/funcs/impl/x86/saber_attension_lstm.h"
13+
#endif
1214
#include "saber/funcs/attension_lstm.h"
1315
#include "stdio.h"
1416

@@ -153,7 +155,7 @@ void test_saber_attension_lstm(int sequence_size = 2, int batch_size = 1, int wo
153155
//t1.end(ctx_dev);
154156
//LOG(INFO) << "!!cudnn lstm :" << test_iter << " cudnn test, total time: "
155157
// << t1.get_average_ms()/test_iter;
156-
#ifdef TEST_X86
158+
#if defined(TEST_X86)&&defined(USE_X86_PLACE)
157159
LstmParam<TensorHf4> h_lstm_param(&h_weight,
158160
&h_bias,
159161
nullptr,
@@ -222,8 +224,9 @@ TEST(TestSaberFuncNV, test_func_saber_lstm) {
222224
int main(int argc, const char** argv) {
223225
// initial logger
224226
//logger::init(argv[0]);
225-
//#ifdef TEST_X86
227+
#if defined(TEST_X86)&&defined(USE_X86_PLACE)
226228
Env<X86>::env_init();
229+
#endif
227230
//#else
228231
Env<NV>::env_init();
229232
//#endif

test/saber/cuda/test_saber_func_lstm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ void test_saber_lstm(int sequence_size = 2, int batch_size = 1, int word_size =
114114
t1.end(ctx_dev);
115115
LOG(INFO) << "!!cudnn lstm :" << test_iter << " cudnn test, total time: "
116116
<< t1.get_average_ms();
117-
#ifdef TEST_X86
117+
#if defined(TEST_X86) &&defined(USE_X86_PLACE)
118118
Lstm<X86, AK_FLOAT, AK_FLOAT, AK_FLOAT, NCHW, NCHW, NCHW> x86_lstm;
119119
LstmParam<TensorHf4> h_lstm_param(&host_weight,
120120
&host_bias,
@@ -181,9 +181,9 @@ TEST(TestSaberFuncNV, test_func_saber_lstm) {
181181
int main(int argc, const char** argv) {
182182
// initial logger
183183
//logger::init(argv[0]);
184-
//#ifdef TEST_X86
184+
#if defined(TEST_X86) &&defined(USE_X86_PLACE)
185185
Env<X86>::env_init();
186-
//#else
186+
#endif
187187
Env<NV>::env_init();
188188
//#endif
189189

0 commit comments

Comments
 (0)