Skip to content

Commit 5c981b0

Browse files
committed
polish doc of paddle.save/load
1 parent 946849b commit 5c981b0

File tree

1 file changed

+4
-43
lines changed
  • python/paddle/framework

1 file changed

+4
-43
lines changed

python/paddle/framework/io.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -555,26 +555,7 @@ def save(obj, path, protocol=2, **configs):
555555
paddle.save(obj, path)
556556
557557
558-
# example 3: Save layer
559-
import paddle
560-
from paddle import nn
561-
562-
class LinearNet(nn.Layer):
563-
def __init__(self):
564-
super(LinearNet, self).__init__()
565-
self._linear = nn.Linear(224, 10)
566-
567-
def forward(self, x):
568-
return self._linear(x)
569-
570-
inps = paddle.randn([1, 224], dtype='float32')
571-
layer = LinearNet()
572-
layer.eval()
573-
path = "example/layer.pdmodel"
574-
paddle.save(layer,path)
575-
576-
577-
# example 4: static graph
558+
# example 3: static graph
578559
import paddle
579560
import paddle.static as static
580561
@@ -601,7 +582,7 @@ def forward(self, x):
601582
path_state_dict = 'temp/model.pdparams'
602583
paddle.save(prog.state_dict("param"), path_tensor)
603584
604-
# example 5: save program
585+
# example 4: save program
605586
import paddle
606587
607588
paddle.enable_static()
@@ -796,27 +777,7 @@ def load(path, **configs):
796777
obj_load = paddle.load(path)
797778
798779
799-
# example 3: Load layer
800-
import paddle
801-
from paddle import nn
802-
803-
class LinearNet(nn.Layer):
804-
def __init__(self):
805-
super(LinearNet, self).__init__()
806-
self._linear = nn.Linear(224, 10)
807-
808-
def forward(self, x):
809-
return self._linear(x)
810-
811-
inps = paddle.randn([1, 224], dtype='float32')
812-
layer = LinearNet()
813-
layer.eval()
814-
path = "example/layer.pdmodel"
815-
paddle.save(layer,path)
816-
layer_load=paddle.load(path)
817-
818-
819-
# example 4: static graph
780+
# example 3: static graph
820781
import paddle
821782
import paddle.static as static
822783
@@ -846,7 +807,7 @@ def forward(self, x):
846807
load_state_dict = paddle.load(path_tensor)
847808
848809
849-
# example 5: load program
810+
# example 4: load program
850811
import paddle
851812
852813
paddle.enable_static()

0 commit comments

Comments
 (0)