File tree 1 file changed +42
-0
lines changed
1 file changed +42
-0
lines changed Original file line number Diff line number Diff line change
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import textwrap
16
+
17
+ from apibase import APIBase
18
+
19
+ obj = APIBase ("torch.optim.Optimizer.zero_grad" )
20
+
21
+
22
+ def test_case_1 ():
23
+ pytorch_code = textwrap .dedent (
24
+ """
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ x = torch.randn(1, 2, 10)
29
+ model = nn.Linear(10, 20)
30
+ optimizer = torch.optim.Optimizer(params=model.parameters(), defaults={"learning_rate": 1.0})
31
+ out = model(x)
32
+ out.backward()
33
+ optimizer.step()
34
+ result = optimizer.zero_grad()
35
+ """
36
+ )
37
+ obj .run (
38
+ pytorch_code ,
39
+ ["result" ],
40
+ unsupport = True ,
41
+ reason = "paddle does not support this function temporarily" ,
42
+ )
You can’t perform that action at this time.
0 commit comments