@@ -24,7 +24,7 @@ namespace backends {
 
 std::tuple<CUdeviceptr, CUdeviceptr, CUdeviceptr, std::vector<float>, std::vector<float>, std::vector<float>>
 CreateNVMemory(int M, int N) {
-  CUDA_CALL(cudaThreadSynchronize());
+  CUDA_CALL(cudaDeviceSynchronize());
 
   CUdeviceptr Ad, Bd, Cd;
   cuMemAlloc(&Ad, M * N * sizeof(float));
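Context for this rename: cudaThreadSynchronize() has been deprecated since CUDA 4.0 in favor of cudaDeviceSynchronize(), which has the same blocking semantics for the current device, so the substitution here and in the hunks below is mechanical. A minimal sketch of the launch/synchronize/read-back pattern these tests rely on (vec_add and kN are illustrative names, not from this patch):

    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void vec_add(const float* a, const float* b, float* c, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) c[i] = a[i] + b[i];
    }

    int main() {
      const int kN = 256;
      float *a, *b, *c;
      cudaMallocManaged(&a, kN * sizeof(float));
      cudaMallocManaged(&b, kN * sizeof(float));
      cudaMallocManaged(&c, kN * sizeof(float));
      for (int i = 0; i < kN; ++i) { a[i] = 1.f; b[i] = 2.f; }

      vec_add<<<1, kN>>>(a, b, c, kN);
      // The modern replacement for cudaThreadSynchronize(): block the host
      // until all previously issued work on this device has finished.
      cudaDeviceSynchronize();

      printf("c[0] = %f\n", c[0]);  // safe to read only after the sync
      cudaFree(a); cudaFree(b); cudaFree(c);
      return 0;
    }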
@@ -419,7 +419,7 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
   B_buf->host_memory = reinterpret_cast<uint8_t*>(Bd);
   C_buf->host_memory = reinterpret_cast<uint8_t*>(Cd);
 
-  CUDA_CALL(cudaThreadSynchronize());
+  CUDA_CALL(cudaDeviceSynchronize());
 
   // call the kernel
   auto comp = reinterpret_cast<void (*)(cinn_pod_value_t*, int)>(fn_ptr);
@@ -428,7 +428,7 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
 
   comp(args.data(), args.size());
 
-  CUDA_CALL(cudaThreadSynchronize());
+  CUDA_CALL(cudaDeviceSynchronize());
 
   CUDA_CALL(cudaMemcpy(host_data3.data(),
                        reinterpret_cast<void*>(Cd),
@@ -716,7 +716,7 @@ TEST(elementwise_add, share_local_cache) {
   B_buf->host_memory = reinterpret_cast<uint8_t*>(Bd);
   C_buf->host_memory = reinterpret_cast<uint8_t*>(Cd);
 
-  CUDA_CALL(cudaThreadSynchronize());
+  CUDA_CALL(cudaDeviceSynchronize());
 
   // call the kernel
   auto comp = reinterpret_cast<void (*)(cinn_pod_value_t*, int)>(fn_ptr);
@@ -725,7 +725,7 @@ TEST(elementwise_add, share_local_cache) {
 
   comp(args.data(), args.size());
 
-  CUDA_CALL(cudaThreadSynchronize());
+  CUDA_CALL(cudaDeviceSynchronize());
   }
 
   CUDA_CALL(cudaFree(reinterpret_cast<void*>(Ad)))
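An aside on the CUDA_CALL wrapper seen throughout: it presumably checks the cudaError_t returned by each runtime call and fails loudly on error. A common shape for such a macro, stated as an assumption rather than CINN's actual definition:

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    // Hypothetical CUDA_CALL: evaluate a runtime call once, abort with a
    // readable message if it did not return cudaSuccess.
    #define CUDA_CALL(expr)                                         \
      do {                                                          \
        cudaError_t status_ = (expr);                               \
        if (status_ != cudaSuccess) {                               \
          fprintf(stderr, "CUDA error %s at %s:%d\n",               \
                  cudaGetErrorString(status_), __FILE__, __LINE__); \
          abort();                                                  \
        }                                                           \
      } while (0)

    int main() {
      float* p = nullptr;
      CUDA_CALL(cudaMalloc(&p, 256 * sizeof(float)));
      CUDA_CALL(cudaDeviceSynchronize());
      CUDA_CALL(cudaFree(p));
      return 0;
    }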
@@ -883,6 +883,8 @@ TEST(Conv, optimize) {
 }
 
 TEST(ElementwiseAdd, cache_read) {
+  Context::Global().ResetNameId();
+
   Expr M(100);
   Expr N(200);
 
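The ResetNameId() call added above matters for the golden-string assertion that follows: the generated buffer names (A_read_cache_3 in this test, A_read_cache_6 in cache_read1 below) carry a suffix from a global fresh-name counter, so resetting it at the start of the test keeps the expected source independent of which tests ran first. A rough sketch of such a counter, as an assumption about how CINN's Context behaves rather than its real implementation:

    #include <cassert>
    #include <string>

    // Hypothetical global name-id counter in the style of cinn::Context.
    class Context {
     public:
      static Context& Global() {
        static Context ctx;
        return ctx;
      }
      // Each minted name gets an increasing suffix: A_read_cache_0, _1, ...
      std::string NewName(const std::string& prefix) {
        return prefix + "_" + std::to_string(name_id_++);
      }
      // Without a reset, the suffix depends on every name minted earlier in
      // the process, making golden strings order-dependent across tests.
      void ResetNameId() { name_id_ = 0; }

     private:
      int name_id_ = 0;
    };

    int main() {
      Context::Global().ResetNameId();
      assert(Context::Global().NewName("A_read_cache") == "A_read_cache_0");
      return 0;
    }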
@@ -933,14 +935,82 @@ void fn_kernel(const float* __restrict__ A, const float* __restrict__ B, float*
       };
     };
     for (int32_t i = 0; i < 10; i += 1) {
-      C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[((10 * blockIdx.x) + ((10 * threadIdx.x) + i))] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
+      C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
+    };
+  };
+}
+
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(source_target), source_code);
+
+  backends::NVRTC_Compiler compiler;
+
+  auto ptx = compiler(source_code);
+  CHECK(!ptx.empty()) << "Compile error!";
+}
+
+TEST(ElementwiseAdd, cache_read1) {
+  Expr M(100);
+  Expr N(200);
+
+  Placeholder<float> A("A", {M, N});
+  Placeholder<float> B("B", {M, N});
+
+  auto C = Compute(
+      {M - 2, N}, [&](Expr i, Expr j) { return A(i, j) + A(i + 1, j) + A(i + 2, j) + B(i, j); }, "C");
+  C->stage()->Split(1, 10);
+
+  auto AL = A->stage()->CacheRead("local", {C});
+  AL->stage()->Split(1, 10);
+
+  AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name);
+  C->stage()->Bind(0, "threadIdx.x");
+  C->stage()->Bind(1, "blockIdx.x");
+
+  Target target;
+  CodeGenCUDA_Dev codegen(target);
+
+  auto fn = Lower("fn", {A, B, C}, {}, {AL});
+
+  Module::Builder builder("module", target);
+  builder.AddFunction(fn);
+
+  auto source_code = codegen.Compile(builder.Build());
+  std::cout << "source:\n" << source_code << std::endl;
+
+  std::string source_target = R"ROC(
+extern "C" {
+
+#ifdef __CUDACC_RTC__
+typedef int int32_t;
+typedef char int8_t;
+#endif
+
+
+
+__global__
+void fn_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
+{
+  float _A_read_cache_6 [ 3 * 10 ];
+  float* A_read_cache_6 = _A_read_cache_6;
+  {
+    if (((((threadIdx.x >= 0) && (threadIdx.x <= 97)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) {
+      for (int32_t i = threadIdx.x; i < (3 + threadIdx.x); i += 1) {
+        for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
+          A_read_cache_6[((10 * i) + j_inner)] = A[((10 * blockIdx.x) + ((200 * i) + j_inner))];
+        };
+      };
+    };
+    for (int32_t i = 0; i < 10; i += 1) {
+      C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_6[i] + (A_read_cache_6[(10 + i)] + (A_read_cache_6[(20 + i)] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))])));
     };
   };
 }
 
 }
 )ROC";
-  // ASSERT_EQ(utils::Trim(source_target), source);
+  ASSERT_EQ(utils::Trim(source_target), source_code);
 
   backends::NVRTC_Compiler compiler;
 
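Both tests finish by feeding the generated source through backends::NVRTC_Compiler and checking that non-empty PTX comes back. That class presumably wraps the NVRTC C API; a self-contained sketch of the source-to-PTX round trip it would perform (the kernel string and architecture flag are illustrative):

    #include <nvrtc.h>
    #include <iostream>
    #include <string>

    int main() {
      const char* source = R"(
    extern "C" __global__ void fn_kernel(float* c) { c[threadIdx.x] = 1.0f; }
    )";

      nvrtcProgram prog;
      nvrtcCreateProgram(&prog, source, "fn.cu", 0, nullptr, nullptr);

      const char* opts[] = {"--gpu-architecture=compute_60"};
      if (nvrtcCompileProgram(prog, 1, opts) != NVRTC_SUCCESS) {
        // The program log explains the failure, mirroring the
        // CHECK(!ptx.empty()) << "Compile error!" guard in the tests.
        size_t log_size;
        nvrtcGetProgramLogSize(prog, &log_size);
        std::string log(log_size, '\0');
        nvrtcGetProgramLog(prog, &log[0]);
        std::cerr << log << std::endl;
        return 1;
      }

      size_t ptx_size;
      nvrtcGetPTXSize(prog, &ptx_size);
      std::string ptx(ptx_size, '\0');
      nvrtcGetPTX(prog, &ptx[0]);
      nvrtcDestroyProgram(&prog);

      std::cout << "PTX bytes: " << ptx.size() << std::endl;
      return 0;
    }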