|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import os |
| 16 | +import tempfile |
15 | 17 | import unittest
|
16 | 18 |
|
17 | 19 | import paddle
|
|
30 | 32 | 'should compile with cuda.',
|
31 | 33 | )
|
32 | 34 | class TestConvertToMixedPrecision(unittest.TestCase):
|
33 |
| - def test_convert_to_fp16(self): |
| 35 | + def setUp(self): |
| 36 | + self.temp_dir = tempfile.TemporaryDirectory() |
34 | 37 | model = resnet50(True)
|
35 | 38 | net = to_static(
|
36 | 39 | model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]
|
37 | 40 | )
|
38 |
| - paddle.jit.save(net, 'resnet50/inference') |
39 |
| - convert_to_mixed_precision( |
40 |
| - 'resnet50/inference.pdmodel', |
41 |
| - 'resnet50/inference.pdiparams', |
42 |
| - 'mixed/inference.pdmodel', |
43 |
| - 'mixed/inference.pdiparams', |
44 |
| - PrecisionType.Half, |
45 |
| - PlaceType.GPU, |
46 |
| - True, |
| 41 | + paddle.jit.save( |
| 42 | + net, os.path.join(self.temp_dir.name, 'resnet50/inference') |
47 | 43 | )
|
48 | 44 |
|
49 |
| - def test_convert_to_fp16_with_fp16_input(self): |
50 |
| - model = resnet50(True) |
51 |
| - net = to_static( |
52 |
| - model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] |
53 |
| - ) |
54 |
| - paddle.jit.save(net, 'resnet50/inference') |
55 |
| - convert_to_mixed_precision( |
56 |
| - 'resnet50/inference.pdmodel', |
57 |
| - 'resnet50/inference.pdiparams', |
58 |
| - 'mixed1/inference.pdmodel', |
59 |
| - 'mixed1/inference.pdiparams', |
60 |
| - PrecisionType.Half, |
61 |
| - PlaceType.GPU, |
62 |
| - False, |
63 |
| - ) |
| 45 | + def tearDown(self): |
| 46 | + self.temp_dir.cleanup() |
64 | 47 |
|
65 |
| - def test_convert_to_fp16_with_blacklist(self): |
66 |
| - model = resnet50(True) |
67 |
| - net = to_static( |
68 |
| - model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] |
69 |
| - ) |
70 |
| - paddle.jit.save(net, 'resnet50/inference') |
71 |
| - convert_to_mixed_precision( |
72 |
| - 'resnet50/inference.pdmodel', |
73 |
| - 'resnet50/inference.pdiparams', |
74 |
| - 'mixed2/inference.pdmodel', |
75 |
| - 'mixed2/inference.pdiparams', |
| 48 | + def test_convert_to_mixed_precision(self): |
| 49 | + mixed_precision_options = [ |
| 50 | + PrecisionType.Half, |
| 51 | + PrecisionType.Half, |
76 | 52 | PrecisionType.Half,
|
77 |
| - PlaceType.GPU, |
78 |
| - False, |
79 |
| - set('conv2d'), |
80 |
| - ) |
81 |
| - |
82 |
| - def test_convert_to_bf16(self): |
83 |
| - model = resnet50(True) |
84 |
| - net = to_static( |
85 |
| - model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')] |
86 |
| - ) |
87 |
| - paddle.jit.save(net, 'resnet50/inference') |
88 |
| - convert_to_mixed_precision( |
89 |
| - 'resnet50/inference.pdmodel', |
90 |
| - 'resnet50/inference.pdiparams', |
91 |
| - 'mixed3/inference.pdmodel', |
92 |
| - 'mixed3/inference.pdiparams', |
93 | 53 | PrecisionType.Bfloat16,
|
94 |
| - PlaceType.GPU, |
95 |
| - True, |
| 54 | + ] |
| 55 | + keep_io_types_options = [True, False, False, True] |
| 56 | + black_list_options = [set(), set(), set(['conv2d']), set()] |
| 57 | + |
| 58 | + test_configs = zip( |
| 59 | + mixed_precision_options, keep_io_types_options, black_list_options |
96 | 60 | )
|
| 61 | + for mixed_precision, keep_io_types, black_list in test_configs: |
| 62 | + config = f'mixed_precision={mixed_precision}-keep_io_types={keep_io_types}-black_list={black_list}' |
| 63 | + with self.subTest( |
| 64 | + mixed_precision=mixed_precision, |
| 65 | + keep_io_types=keep_io_types, |
| 66 | + black_list=black_list, |
| 67 | + ): |
| 68 | + convert_to_mixed_precision( |
| 69 | + os.path.join( |
| 70 | + self.temp_dir.name, 'resnet50/inference.pdmodel' |
| 71 | + ), |
| 72 | + os.path.join( |
| 73 | + self.temp_dir.name, 'resnet50/inference.pdiparams' |
| 74 | + ), |
| 75 | + os.path.join( |
| 76 | + self.temp_dir.name, f'{config}/inference.pdmodel' |
| 77 | + ), |
| 78 | + os.path.join( |
| 79 | + self.temp_dir.name, f'{config}/inference.pdiparams' |
| 80 | + ), |
| 81 | + backend=PlaceType.GPU, |
| 82 | + mixed_precision=mixed_precision, |
| 83 | + keep_io_types=keep_io_types, |
| 84 | + black_list=black_list, |
| 85 | + ) |
97 | 86 |
|
98 | 87 |
|
99 | 88 | if __name__ == '__main__':
|
|
0 commit comments