|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 | 16 |
|
17 |
| -import paddle.fluid.core as core |
18 |
| -import paddle.fluid.proto.framework_pb2 as framework_pb2 |
| 17 | +from paddle.fluid import core |
| 18 | +from paddle.fluid.proto import framework_pb2 |
| 19 | + |
19 | 20 |
|
20 | 21 | # NOTE: this is added to support creating a Scalar message
|
21 | 22 | # from a python number
|
@@ -256,13 +257,13 @@ def __impl__(*args, **kwargs):
|
256 | 257 | inputs=[(var.name, var.duplicable) for var in op_proto.inputs],
|
257 | 258 | outputs=[(var.name, var.duplicable) for var in op_proto.outputs],
|
258 | 259 | attrs=[attr.name for attr in op_proto.attrs],
|
259 |
| - extra_attrs=[item for item in extra_attrs_map.keys()], |
| 260 | + extra_attrs=list(extra_attrs_map.keys()), |
260 | 261 | )
|
261 | 262 |
|
262 | 263 |
|
263 | 264 | class OperatorFactory:
|
264 | 265 | def __init__(self):
|
265 |
| - self.op_methods = dict() |
| 266 | + self.op_methods = {} |
266 | 267 |
|
267 | 268 | for op_proto in get_all_op_protos():
|
268 | 269 | method = create_op_creation_method(op_proto)
|
@@ -313,70 +314,4 @@ def get_op_extra_attr_names(self, type):
|
313 | 314 | return self.get_op_info(type).extra_attrs
|
314 | 315 |
|
315 | 316 |
|
316 |
class __RecurrentOp__:
    """Callable singleton that builds a C++ ``RecurrentOp`` from op-desc args.

    Calling the instance forwards ``*args, **kwargs`` to
    :class:`OpDescCreationMethod` (defaulting ``type`` to ``"recurrent"``)
    and returns ``core.RecurrentOp.create(...)`` for the serialized proto.
    """

    # Class-level cache of the "recurrent" OpProto, filled on first __init__.
    __proto__ = None
    type = "recurrent"

    def __init__(self):
        # Cache recurrent_op's proto.
        # NOTE: write through the class (not `self`) so the cache is actually
        # shared — an instance assignment would shadow the class attribute and
        # force every new instance to re-scan all op protos.
        if self.__proto__ is None:
            for op_proto in get_all_op_protos():
                if op_proto.type == self.type:
                    type(self).__proto__ = op_proto
                    break  # found the proto; no need to scan the rest

    def __call__(self, *args, **kwargs):
        # Default the op type unless the caller supplied one.
        if self.type not in args and "type" not in kwargs:
            kwargs["type"] = self.type
        # create proto
        create_method = OpDescCreationMethod(self.__proto__)
        proto = create_method(*args, **kwargs)
        # create rnnop
        return core.RecurrentOp.create(proto.SerializeToString())
335 |
| - |
336 |
| - |
337 |
class __DynamicRecurrentOp__:
    """Callable singleton that builds a C++ ``DynamicRecurrentOp``.

    Calling the instance forwards ``*args, **kwargs`` to
    :class:`OpDescCreationMethod` (defaulting ``type`` to
    ``"dynamic_recurrent"``) and returns
    ``core.DynamicRecurrentOp.create(...)`` for the serialized proto.
    """

    # Class-level cache of the "dynamic_recurrent" OpProto.
    __proto__ = None
    type = "dynamic_recurrent"

    def __init__(self):
        # Cache recurrent_op's proto.
        # NOTE: write through the class (not `self`) so the cache is actually
        # shared — an instance assignment would shadow the class attribute and
        # force every new instance to re-scan all op protos.
        if self.__proto__ is None:
            for op_proto in get_all_op_protos():
                if op_proto.type == self.type:
                    type(self).__proto__ = op_proto
                    break  # found the proto; no need to scan the rest

    def __call__(self, *args, **kwargs):
        # Default the op type unless the caller supplied one.
        if self.type not in args and "type" not in kwargs:
            kwargs["type"] = self.type
        # create proto
        create_method = OpDescCreationMethod(self.__proto__)
        proto = create_method(*args, **kwargs)
        # create rnnop
        return core.DynamicRecurrentOp.create(proto.SerializeToString())
356 |
| - |
357 |
| - |
358 |
class __CondOp__:
    """Callable singleton that builds a C++ ``CondOp`` from op-desc args.

    Calling the instance forwards ``*args, **kwargs`` to
    :class:`OpDescCreationMethod` (defaulting ``type`` to ``"cond"``)
    and returns ``core.CondOp.create(...)`` for the serialized proto.
    """

    # Class-level cache of the "cond" OpProto, filled on first __init__.
    __proto__ = None
    type = "cond"

    def __init__(self):
        # Cache cond_op's proto.
        # NOTE: write through the class (not `self`) so the cache is actually
        # shared — an instance assignment would shadow the class attribute and
        # force every new instance to re-scan all op protos.
        if self.__proto__ is None:
            for op_proto in get_all_op_protos():
                if op_proto.type == self.type:
                    type(self).__proto__ = op_proto
                    break  # found the proto; no need to scan the rest

    def __call__(self, *args, **kwargs):
        # Default the op type unless the caller supplied one.
        if self.type not in args and "type" not in kwargs:
            kwargs["type"] = self.type
        # create proto
        create_method = OpDescCreationMethod(self.__proto__)
        proto = create_method(*args, **kwargs)
        # create condop
        return core.CondOp.create(proto.SerializeToString())
377 |
| - |
378 |
| - |
# Module-level singletons: the general op factory plus the three special-case
# op builders that need hand-rolled proto handling.
Operator = OperatorFactory()  # The default global factory
RecurrentOp = __RecurrentOp__()
DynamicRecurrentOp = __DynamicRecurrentOp__()
CondOp = __CondOp__()
0 commit comments