Skip to content

Commit afa6f12

Browse files
authored
paddlespeech/audiotools/ml/basemodel.py (#3994)
1 parent 793a89d commit afa6f12

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

paddlespeech/audiotools/ml/basemodel.py

+14-23
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def save(
110110
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
111111
paddle.save(state_dict, str(path))
112112
else:
113-
self._save_package(path, intern=intern, extern=extern, mock=mock)
113+
raise NotImplementedError(
114+
"Currently Paddle does not support packaging")
114115

115116
return path
116117

@@ -151,31 +152,21 @@ def load(
151152
BaseModel
152153
A model that inherits from BaseModel.
153154
"""
154-
try:
155-
model = cls._load_package(location, package_name=package_name)
156-
except:
157-
model_dict = paddle.load(location)
158-
metadata = model_dict["metadata"]
159-
metadata["kwargs"].update(kwargs)
160-
161-
sig = inspect.signature(cls)
162-
class_keys = list(sig.parameters.keys())
163-
for k in list(metadata["kwargs"].keys()):
164-
if k not in class_keys:
165-
metadata["kwargs"].pop(k)
166-
167-
model = cls(*args, **metadata["kwargs"])
168-
model.set_state_dict(model_dict["state_dict"])
169-
model.metadata = metadata
155+
model_dict = paddle.load(location)
156+
metadata = model_dict["metadata"]
157+
metadata["kwargs"].update(kwargs)
170158

171-
return model
159+
sig = inspect.signature(cls)
160+
class_keys = list(sig.parameters.keys())
161+
for k in list(metadata["kwargs"].keys()):
162+
if k not in class_keys:
163+
metadata["kwargs"].pop(k)
172164

173-
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
174-
raise NotImplementedError("Currently Paddle does not support packaging")
165+
model = cls(*args, **metadata["kwargs"])
166+
model.set_state_dict(model_dict["state_dict"])
167+
model.metadata = metadata
175168

176-
@classmethod
177-
def _load_package(cls, path, package_name=None):
178-
raise NotImplementedError("Currently Paddle does not support packaging")
169+
return model
179170

180171
def save_to_folder(
181172
self,

0 commit comments

Comments
 (0)