@@ -191,32 +191,36 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
191
191
SetConvMathType (ctx, dtype, args.cdesc );
192
192
193
193
if (deterministic) {
194
- result = FindAlgoDeterministic ();
194
+ result = FindAlgoDeterministic (args );
195
195
} else {
196
196
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
197
197
// 2. Once turning on auto-tune, runn heuristic search(default) before
198
198
// auto-tune process, run exhaustive_search during mentioned process.
199
199
// 3. After auto-tune process, run cached algorithm if cached, run
200
200
// default mode for the rest.
201
- size_t key = args.GetCacheKey <T>();
201
+ auto key = args.Convert2ConvCacheKey <T>();
202
202
auto & cache = phi::autotune::AutoTuneCache::Instance ().GetConvForward ();
203
203
if (cache.Find (key)) {
204
- result.algo = static_cast <AlgoT>(cache.Get (key));
204
+ auto t = cache.Get (key);
205
+ result.algo = static_cast <AlgoT>(t.algo );
206
+ result.workspace_size = t.workspace_size ;
205
207
} else {
206
208
bool use_autotune =
207
209
phi::autotune::AutoTuneStatus::Instance ().UseAutoTune ();
208
210
if (exhaustive_search || use_autotune) {
209
211
result = FindAlgoExhaustiveSearch<T>(args, ctx);
210
- cache.Set (key, static_cast <int64_t >(result.algo ));
211
212
} else {
212
213
result = FindAlgoHeuristic (args, ctx);
213
214
}
215
+ phi::autotune::DnnNode node (static_cast <int64_t >(result.algo ),
216
+ result.workspace_size );
217
+ cache.Set (key, node);
214
218
}
215
219
}
216
220
VLOG (3 ) << " [cuDNN Convoltion] exhaustive_search=" << exhaustive_search
217
221
<< " , deterministic=" << deterministic
218
- << " , choose algo=" << result.algo << " , workspace= "
219
- << ToMegaBytes ( GetWorkspaceSize (args, result.algo ) ) << " MB" ;
222
+ << " , choose algo=" << result.algo
223
+ << " , workspace= " << ToMegaBytes ( result.workspace_size ) << " MB" ;
220
224
return result;
221
225
}
222
226
@@ -236,8 +240,9 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
236
240
}
237
241
238
242
private:
239
- static SearchResult<AlgoT> FindAlgoDeterministic () {
240
- return SearchResult<AlgoT>(static_cast <AlgoT>(1 ));
243
+ static SearchResult<AlgoT> FindAlgoDeterministic (const ConvArgs& args) {
244
+ auto workspace_size = GetWorkspaceSize (args, static_cast <AlgoT>(1 ));
245
+ return SearchResult<AlgoT>(static_cast <AlgoT>(1 ), -1.0 , workspace_size);
241
246
}
242
247
243
248
// Heuristic search mode, calling the cudnnGetXxxAlgorithm.
@@ -298,6 +303,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
298
303
workspace_size_limit,
299
304
&(result.algo )));
300
305
#endif
306
+ result.workspace_size = GetWorkspaceSize (args, result.algo );
301
307
return result;
302
308
}
303
309
@@ -343,6 +349,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
343
349
ChooseAlgoByWorkspace<PerfT, AlgoT>(
344
350
perf_results, workspace_size_limit, &result);
345
351
352
+ result.workspace_size = GetWorkspaceSize (args, result.algo );
346
353
return result;
347
354
}
348
355
@@ -394,33 +401,37 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
394
401
SetConvMathType (ctx, dtype, args.cdesc );
395
402
396
403
if (deterministic) {
397
- result = FindAlgoDeterministic ();
404
+ result = FindAlgoDeterministic (args );
398
405
} else {
399
406
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
400
407
// 2. Once turning on auto-tune, runn heuristic search(default) before
401
408
// auto-tune process, run exhaustive_search during mentioned process.
402
409
// 3. After auto-tune process, run cached algorithm if cached, run
403
410
// default mode for the rest.
404
- size_t key = args.GetCacheKey <T>();
411
+ auto key = args.Convert2ConvCacheKey <T>();
405
412
auto & cache =
406
413
phi::autotune::AutoTuneCache::Instance ().GetConvBackwardData ();
407
414
if (cache.Find (key)) {
408
- result.algo = static_cast <AlgoT>(cache.Get (key));
415
+ auto t = cache.Get (key);
416
+ result.algo = static_cast <AlgoT>(t.algo );
417
+ result.workspace_size = t.workspace_size ;
409
418
} else {
410
419
bool use_autotune =
411
420
phi::autotune::AutoTuneStatus::Instance ().UseAutoTune ();
412
421
if (exhaustive_search || use_autotune) {
413
422
result = FindAlgoExhaustiveSearch<T>(args, ctx);
414
- cache.Set (key, static_cast <int64_t >(result.algo ));
415
423
} else {
416
424
result = FindAlgoHeuristic (args, ctx);
417
425
}
426
+ phi::autotune::DnnNode node (static_cast <int64_t >(result.algo ),
427
+ result.workspace_size );
428
+ cache.Set (key, node);
418
429
}
419
430
}
420
431
VLOG (3 ) << " [cuDNN Convoltion] exhaustive_search=" << exhaustive_search
421
432
<< " , deterministic=" << deterministic
422
- << " , choose algo=" << result.algo << " , workspace= "
423
- << ToMegaBytes ( GetWorkspaceSize (args, result.algo ) ) << " MB" ;
433
+ << " , choose algo=" << result.algo
434
+ << " , workspace= " << ToMegaBytes ( result.workspace_size ) << " MB" ;
424
435
return result;
425
436
}
426
437
@@ -440,8 +451,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
440
451
}
441
452
442
453
private:
443
- static SearchResult<AlgoT> FindAlgoDeterministic () {
444
- return SearchResult<AlgoT>(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
454
+ static SearchResult<AlgoT> FindAlgoDeterministic (const ConvArgs& args) {
455
+ auto workspace_size =
456
+ GetWorkspaceSize (args, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
457
+ return SearchResult<AlgoT>(
458
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, -1.0 , workspace_size);
445
459
}
446
460
447
461
static SearchResult<AlgoT> FindAlgoHeuristic (const ConvArgs& args,
@@ -513,7 +527,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
513
527
workspace_size_limit,
514
528
&(result.algo )));
515
529
#endif
516
-
530
+ result. workspace_size = GetWorkspaceSize (args, result. algo );
517
531
return result;
518
532
}
519
533
@@ -559,6 +573,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
559
573
ChooseAlgoByWorkspace<PerfT, AlgoT>(
560
574
perf_results, workspace_size_limit, &result);
561
575
576
+ result.workspace_size = GetWorkspaceSize (args, result.algo );
562
577
return result;
563
578
}
564
579
@@ -609,33 +624,37 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
609
624
SetConvMathType (ctx, dtype, args.cdesc );
610
625
611
626
if (deterministic) {
612
- result = FindAlgoDeterministic ();
627
+ result = FindAlgoDeterministic (args );
613
628
} else {
614
629
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
615
630
// 2. Once turning on auto-tune, runn heuristic search(default) before
616
631
// auto-tune process, run exhaustive_search during mentioned process.
617
632
// 3. After auto-tune process, run cached algorithm if cached, run
618
633
// default mode for the rest.
619
- size_t key = args.GetCacheKey <T>();
634
+ auto key = args.Convert2ConvCacheKey <T>();
620
635
auto & cache =
621
636
phi::autotune::AutoTuneCache::Instance ().GetConvBackwardFilter ();
622
637
if (cache.Find (key)) {
623
- result.algo = static_cast <AlgoT>(cache.Get (key));
638
+ auto t = cache.Get (key);
639
+ result.algo = static_cast <AlgoT>(t.algo );
640
+ result.workspace_size = t.workspace_size ;
624
641
} else {
625
642
bool use_autotune =
626
643
phi::autotune::AutoTuneStatus::Instance ().UseAutoTune ();
627
644
if (exhaustive_search || use_autotune) {
628
645
result = FindAlgoExhaustiveSearch<T>(args, ctx);
629
- cache.Set (key, static_cast <int64_t >(result.algo ));
630
646
} else {
631
647
result = FindAlgoHeuristic (args, ctx);
632
648
}
649
+ phi::autotune::DnnNode node (static_cast <int64_t >(result.algo ),
650
+ result.workspace_size );
651
+ cache.Set (key, node);
633
652
}
634
653
}
635
654
VLOG (3 ) << " [cuDNN Convoltion] exhaustive_search=" << exhaustive_search
636
655
<< " , deterministic=" << deterministic
637
- << " , choose algo=" << result.algo << " , workspace= "
638
- << ToMegaBytes ( GetWorkspaceSize (args, result.algo ) ) << " MB" ;
656
+ << " , choose algo=" << result.algo
657
+ << " , workspace= " << ToMegaBytes ( result.workspace_size ) << " MB" ;
639
658
return result;
640
659
}
641
660
@@ -656,8 +675,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
656
675
}
657
676
658
677
private:
659
- static SearchResult<AlgoT> FindAlgoDeterministic () {
660
- return SearchResult<AlgoT>(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1);
678
+ static SearchResult<AlgoT> FindAlgoDeterministic (const ConvArgs& args) {
679
+ auto workspace_size =
680
+ GetWorkspaceSize (args, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1);
681
+ return SearchResult<AlgoT>(
682
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, -1.0 , workspace_size);
661
683
}
662
684
663
685
static SearchResult<AlgoT> FindAlgoHeuristic (const ConvArgs& args,
@@ -718,6 +740,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
718
740
&(result.algo )));
719
741
#endif
720
742
743
+ result.workspace_size = GetWorkspaceSize (args, result.algo );
721
744
return result;
722
745
}
723
746
@@ -786,6 +809,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
786
809
ChooseAlgo (perf_results, workspace_size_limit, &result);
787
810
}
788
811
812
+ result.workspace_size = GetWorkspaceSize (args, result.algo );
789
813
return result;
790
814
}
791
815
0 commit comments