
Commit f86f49e

sfvaroglu authored and tensor-tang committed
[NGraph] add increment op to ngraph engine (#16929)
* add increment op to ngraph engine test=develop
* fix style errors test=develop
1 parent 8923612 commit f86f49e

File tree

2 files changed: +94 -0 lines changed

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/elementwise_node.h"
#include "paddle/fluid/platform/ngraph_helper.h"

namespace paddle {
namespace operators {
namespace ngraphs {

void BuildIncrementNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
  float step = op_attrs.Get<float>("step");
  auto step_op = std::make_shared<ngraph::op::Constant>(
      x->get_element_type(), x->get_shape(), std::vector<float>{step});
  std::shared_ptr<ngraph::Node> out =
      std::make_shared<ngraph::op::Add>(x, step_op);
  paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}

}  // namespace ngraphs
}  // namespace operators
}  // namespace paddle

REGISTER_NG_OP(increment, BuildIncrementNode);
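The builder above lowers Paddle's increment op to an nGraph Constant node holding step plus an Add node, so the graph computes Out = X + step on a single-element tensor. Below is a minimal sketch of how such an op is typically reached from the Python side; fluid.layers.increment and the exact call pattern are assumptions based on the fluid API of this era, not part of this commit, and the nGraph engine must additionally be enabled at runtime for this builder to be exercised.

import paddle.fluid as fluid

# Build a tiny program whose graph contains one increment op (attr step = 2.0).
# fluid.layers.increment is assumed from the fluid API; not shown in this commit.
x = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.0)
out = fluid.layers.increment(x, value=2.0, in_place=False)

exe = fluid.Executor(fluid.CPUPlace())
result, = exe.run(fluid.default_main_program(), fetch_list=[out])
print(result)  # expected: [3.]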
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest


class TestNGRAPHIncrementOp(OpTest):
    def setUp(self):
        self.op_type = "increment"
        self.dtype = np.float32
        self.init_dtype_type()
        self.inputs = {'X': np.random.random(1).astype(self.dtype)}
        self.attrs = {'step': 2.0}
        self.outputs = {
            'Out': self.inputs['X'] + self.dtype(self.attrs['step'])
        }
        self._cpu_only = True

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


if __name__ == "__main__":
    unittest.main()
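For reference, the expected output the test builds in setUp is the same X + step arithmetic that the nGraph Constant + Add pair implements; a standalone NumPy restatement of that reference computation (an illustration only, not code from this commit):

import numpy as np

# Mirror of the test's reference computation: Out = X + step, step = 2.0,
# with X a single random float32 value of shape (1,).
x = np.random.random(1).astype(np.float32)
step = np.float32(2.0)
out = x + step
# check_grad verifies dOut/dX numerically; here it is trivially 1,
# since step is a constant.
print(out.dtype, out.shape)  # float32 (1,)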
