Skip to content

Commit a4811c7

Browse files
authored
Modify the srcipt files of case and switch (#504)
* Modify the srcipt files of case and switch test=develop
1 parent 8245f60 commit a4811c7

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

api/tests/case.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717

1818
class PDCase(PaddleAPIBenchmarkBase):
1919
def build_program(self, config):
20-
zero_var = fluid.layers.zeros(
21-
shape=config.input_shape, dtype=config.input_dtype)
2220
five_var = fluid.layers.fill_constant(
2321
shape=config.input_shape, dtype=config.input_dtype, value=5)
22+
ten_var = fluid.layers.fill_constant(
23+
shape=config.input_shape, dtype=config.input_dtype, value=10)
24+
one_var = fluid.layers.ones(
25+
shape=config.input_shape, dtype=config.input_dtype)
2426

2527
x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
2628
y = self.variable(name='y', shape=config.y_shape, dtype=config.y_dtype)
27-
input = self.variable(
28-
name='input', shape=config.input_shape, dtype=config.input_dtype)
2929

3030
def f1():
3131
return fluid.layers.elementwise_add(x=x, y=y)
@@ -36,28 +36,27 @@ def f2():
3636
def f3():
3737
return fluid.layers.elementwise_mul(x=x, y=y)
3838

39-
pred_1 = fluid.layers.less_than(input, zero_var)
40-
pred_2 = fluid.layers.greater_than(input, five_var)
39+
pred_1 = fluid.layers.less_than(one_var, five_var)
40+
pred_2 = fluid.layers.greater_than(one_var, ten_var)
4141

4242
result = fluid.layers.case(
4343
pred_fn_pairs=[(pred_1, f1), (pred_2, f2)], default=f3)
44-
self.feed_vars = [x, y, input]
44+
self.feed_vars = [x, y]
4545
self.fetch_vars = [result]
4646
if config.backward:
4747
self.append_gradients(result, [x, y])
4848

4949

5050
class TFCase(TensorflowAPIBenchmarkBase):
5151
def build_graph(self, config):
52-
zero_var = tf.constant(
53-
0, shape=config.input_shape, dtype=config.input_dtype)
5452
five_var = tf.constant(
5553
5, shape=config.input_shape, dtype=config.input_dtype)
54+
ten_var = tf.constant(
55+
10, shape=config.input_shape, dtype=config.input_dtype)
56+
one_var = tf.ones(shape=config.input_shape, dtype=config.input_dtype)
5657

5758
x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
5859
y = self.variable(name='y', shape=config.y_shape, dtype=config.y_dtype)
59-
input = self.variable(
60-
name='input', shape=config.input_shape, dtype=config.input_dtype)
6160

6261
def f1():
6362
return tf.add(x, y)
@@ -68,13 +67,13 @@ def f2():
6867
def f3():
6968
return tf.multiply(x, y)
7069

71-
pred_1 = tf.less(input, zero_var)
72-
pred_2 = tf.greater(input, five_var)
70+
pred_1 = tf.less(one_var, five_var)
71+
pred_2 = tf.greater(one_var, ten_var)
7372

7473
result = tf.case(
7574
[(tf.reshape(pred_1, []), f1), (tf.reshape(pred_2, []), f2)],
7675
default=f3)
77-
self.feed_list = [x, y, input]
76+
self.feed_list = [x, y]
7877
self.fetch_list = [result]
7978
if config.backward:
8079
self.append_gradients(result, [x, y])

api/tests/switch_case.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class PDSwitchCase(PaddleAPIBenchmarkBase):
1919
def build_program(self, config):
2020
x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
2121
y = self.variable(name='y', shape=config.y_shape, dtype=config.y_dtype)
22-
input = self.variable(
23-
name='input', shape=config.input_shape, dtype=config.input_dtype)
22+
zero = fluid.layers.zeros(
23+
shape=config.input_shape, dtype=config.input_dtype)
2424

2525
def f1():
2626
return fluid.layers.elementwise_add(x=x, y=y)
@@ -32,9 +32,9 @@ def f3():
3232
return fluid.layers.elementwise_mul(x=x, y=y)
3333

3434
result = fluid.layers.switch_case(
35-
branch_index=input, branch_fns={0: f1,
36-
1: f2}, default=f3)
37-
self.feed_vars = [x, y, input]
35+
branch_index=zero, branch_fns={0: f1,
36+
1: f2}, default=f3)
37+
self.feed_vars = [x, y]
3838
self.fetch_vars = [result]
3939
if config.backward:
4040
self.append_gradients(result, [x, y])
@@ -44,8 +44,7 @@ class TFSwitchCase(TensorflowAPIBenchmarkBase):
4444
def build_graph(self, config):
4545
x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
4646
y = self.variable(name='y', shape=config.y_shape, dtype=config.y_dtype)
47-
input = self.variable(
48-
name='input', shape=config.input_shape, dtype=config.input_dtype)
47+
zero = tf.zeros(shape=config.input_shape, dtype=config.input_dtype)
4948

5049
def f1():
5150
return tf.add(x, y)
@@ -57,11 +56,11 @@ def f3():
5756
return tf.multiply(x, y)
5857

5958
result = tf.switch_case(
60-
branch_index=tf.reshape(input, []),
59+
branch_index=tf.reshape(zero, []),
6160
branch_fns={0: f1,
6261
1: f2},
6362
default=f3)
64-
self.feed_list = [x, y, input]
63+
self.feed_list = [x, y]
6564
self.fetch_list = [result]
6665
if config.backward:
6766
self.append_gradients(result, [x, y])

0 commit comments

Comments
 (0)