File tree 2 files changed +17
-7
lines changed
2 files changed +17
-7
lines changed Original file line number Diff line number Diff line change @@ -31,8 +31,8 @@ class SaveOp : public framework::OperatorWithKernel {
31
31
protected:
32
32
framework::OpKernelType GetExpectedKernelType (
33
33
const framework::ExecutionContext &ctx) const override {
34
- return framework::OpKernelType (ctx.Input <framework::LoDTensor> (" X" )-> type (),
35
- ctx.GetPlace ());
34
+ auto data_type = framework::GetDataTypeOfVar (ctx.InputVar (" X" ));
35
+ return framework::OpKernelType (data_type, ctx.device_context ());
36
36
}
37
37
};
38
38
Original file line number Diff line number Diff line change @@ -103,12 +103,22 @@ class SaveOpKernel : public framework::OpKernel<T> {
103
103
const platform::Place &place,
104
104
const framework::Variable *var) const {
105
105
framework::Variable *out_put_var = ctx.OutputVar (LOOKUP_TABLE_PATH);
106
- PADDLE_ENFORCE (
107
- out_put_var != nullptr ,
108
- " Can not find variable kLookupTablePath for SaveSelectedRows" );
109
- auto *lt_var = out_put_var->GetMutable <std::string>();
110
106
111
- std::string filename = lt_var->data ();
107
+ auto file_path = ctx.Attr <std::string>(" file_path" );
108
+ auto overwrite = ctx.Attr <bool >(" overwrite" );
109
+
110
+ std::string filename = file_path;
111
+
112
+ if (out_put_var != nullptr ) {
113
+ auto *lt_var = out_put_var->GetMutable <std::string>();
114
+ filename = *lt_var;
115
+ }
116
+
117
+ if (FileExists (filename) && !overwrite) {
118
+ PADDLE_THROW (" %s is existed, cannot save to it when overwrite=false" ,
119
+ filename, overwrite);
120
+ }
121
+
112
122
VLOG (4 ) << " SaveSelectedRows get File name: " << filename;
113
123
114
124
MkDirRecursively (DirName (filename).c_str ());
You can’t perform that action at this time.
0 commit comments