Skip to content

Commit ab06f63

Browse files
Steffy-zxfnepeplwu
authored andcommitted
Add inference model (#327)
* add-inference-model
1 parent 0197f9a commit ab06f63

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

paddlehub/finetune/checkpoint.py

-4
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,6 @@ def save_checkpoint(checkpoint_dir,
7474
ckpt = checkpoint_pb2.CheckPoint()
7575

7676
model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
77-
logger.info("Saving model checkpoint to {}".format(model_saved_dir))
78-
fluid.io.save_persistables(
79-
exe, dirname=model_saved_dir, main_program=main_program)
80-
8177
ckpt.current_epoch = current_epoch
8278
ckpt.global_step = global_step
8379
ckpt.latest_model_dir = model_saved_dir

paddlehub/finetune/task/base_task.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -662,11 +662,7 @@ def _default_eval_end_event(self, run_states):
662662
"best_model")
663663
logger.eval("best model saved to %s [best %s=%.5f]" %
664664
(model_saved_dir, main_metric, main_value))
665-
666-
save_result = fluid.io.save_persistables(
667-
executor=self.exe,
668-
dirname=model_saved_dir,
669-
main_program=self.main_program)
665+
self.save_inference_model(dirname=model_saved_dir)
670666

671667
def _default_log_interval_event(self, run_states):
672668
scores, avg_loss, run_speed = self._calculate_metrics(run_states)
@@ -717,6 +713,10 @@ def _calculate_metrics(self, run_states):
717713
# NOTE: current saved checkpoint machanism is not completed,
718714
# it can't restore dataset training status
719715
def save_checkpoint(self):
716+
model_saved_dir = os.path.join(self.config.checkpoint_dir,
717+
"step_%d" % self.current_step)
718+
logger.info("Saving model checkpoint to {}".format(model_saved_dir))
719+
self.save_inference_model(dirname=model_saved_dir)
720720
save_checkpoint(
721721
checkpoint_dir=self.config.checkpoint_dir,
722722
current_epoch=self.current_epoch,

paddlehub/finetune/task/classifier_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _calculate_metrics(self, run_states):
317317
def fetch_list(self):
318318
if self.is_train_phase or self.is_test_phase:
319319
return [metric.name for metric in self.metrics] + [self.loss.name]
320-
return self.outputs
320+
return [output.name for output in self.outputs]
321321

322322
def _postprocessing(self, run_states):
323323
results = []

0 commit comments

Comments
 (0)