Skip to content

Commit e727240

Browse files
authored
Add einsum. (#1594)
1 parent 293bf89 commit e727240

File tree

8 files changed

+257
-41
lines changed

8 files changed

+257
-41
lines changed

api/run_op_benchmark.sh

+13-41
Original file line numberDiff line numberDiff line change
@@ -59,57 +59,36 @@ install_package() {
5959
run_op_benchmark() {
6060
local testing_mode=$1
6161

62-
OUTPUT_ROOT=${OP_BENCHMARK_ROOT}/logs
62+
if [ "${op_type}" = "all" ]; then
63+
OUTPUT_ROOT=${OP_BENCHMARK_ROOT}/logs
64+
op_type="None"
65+
else
66+
OUTPUT_ROOT=${OP_BENCHMARK_ROOT}/logs/${op_type}
67+
fi
6368
if [ ! -d ${OUTPUT_ROOT} ]; then
6469
mkdir -p ${OUTPUT_ROOT}
6570
fi
6671

6772
timestamp=`date '+%Y%m%d-%H%M%S'`
6873
if [ ${test_module_name} == "tests" ]; then
69-
output_dir=${OUTPUT_ROOT}/${test_module_name}_${testing_mode}/${timestamp}
74+
subdir=${test_module_name}_${testing_mode}
7075
else
71-
output_dir=${OUTPUT_ROOT}/${test_module_name}/${timestamp}
76+
subdir=${test_module_name}
7277
fi
78+
output_dir=${OUTPUT_ROOT}/${subdir}/${timestamp}
7379
if [ ! -d ${output_dir} ]; then
7480
mkdir -p ${output_dir}
7581
fi
7682
echo "-- output_dir: ${output_dir}"
7783

7884
config_dir=${OP_BENCHMARK_ROOT}/tests_v2/configs
85+
#config_dir=${OP_BENCHMARK_ROOT}/tests_v2/op_configs
7986
echo "-- config_dir: ${config_dir}"
8087

8188
tests_dir=${OP_BENCHMARK_ROOT}/${test_module_name}
8289
echo "-- tests_dir: ${tests_dir}"
83-
log_path=${OUTPUT_ROOT}/log_${test_module_name}_${timestamp}.txt
84-
bash ${OP_BENCHMARK_ROOT}/deploy/main_control.sh ${tests_dir} ${config_dir} ${output_dir} ${gpu_ids} ${device_type} "both" "none" "both" "${testing_mode}" "None" "${precision}" > ${log_path} 2>&1 &
85-
}
86-
87-
run_specified_op() {
88-
local testing_mode=$1
89-
90-
OUTPUT_ROOT=${OP_BENCHMARK_ROOT}/logs/${op_type}
91-
if [ ! -d ${OUTPUT_ROOT} ]; then
92-
mkdir -p ${OUTPUT_ROOT}
93-
fi
94-
95-
timestamp=`date '+%Y%m%d-%H%M%S'`
96-
if [ ${test_module_name} == "tests" ]; then
97-
output_dir=${OUTPUT_ROOT}/${test_module_name}_${testing_mode}/${timestamp}
98-
else
99-
output_dir=${OUTPUT_ROOT}/${test_module_name}/${timestamp}
100-
fi
101-
if [ ! -d ${output_dir} ]; then
102-
mkdir -p ${output_dir}
103-
fi
104-
echo "-- output_dir: ${output_dir}"
105-
106-
config_dir=${OP_BENCHMARK_ROOT}/tests_v2/op_configs
107-
echo "-- config_dir: ${config_dir}"
108-
109-
tests_dir=${OP_BENCHMARK_ROOT}/${test_module_name}
110-
echo "-- tests_dir: ${tests_dir}"
111-
log_path=${OUTPUT_ROOT}/log_${test_module_name}_${timestamp}.txt
112-
bash ${OP_BENCHMARK_ROOT}/deploy/main_control.sh ${tests_dir} ${config_dir} ${output_dir} "${gpu_ids}" "gpu" "${task}" "none" "${framework}" "${testing_mode}" "${op_type}" "${precision}" > ${log_path} 2>&1 &
90+
log_path=${OUTPUT_ROOT}/log_${subdir}_${timestamp}.txt
91+
bash ${OP_BENCHMARK_ROOT}/deploy/main_control.sh ${tests_dir} ${config_dir} ${output_dir} ${gpu_ids} ${device_type} ${task} "none" ${framework} ${testing_mode} ${op_type} ${precision} > ${log_path} 2>&1 &
11392
}
11493

11594
main() {
@@ -124,14 +103,7 @@ main() {
124103
install_package "tensorflow" "2.3.1"
125104
fi
126105

127-
case ${op_type} in
128-
all)
129-
run_op_benchmark ${testing_mode}
130-
;;
131-
*)
132-
run_specified_op ${testing_mode}
133-
;;
134-
esac
106+
run_op_benchmark ${testing_mode}
135107
}
136108

137109
main

api/tests/einsum.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from common_import import *
16+
17+
18+
@benchmark_registry.register("einsum")
19+
class EinsumConfig(APIConfig):
20+
def __init__(self):
21+
super(EinsumConfig, self).__init__("einsum")
22+
23+
def init_from_json(self, filename, config_id=0, unknown_dim=16):
24+
super(EinsumConfig, self).init_from_json(filename, config_id,
25+
unknown_dim)
26+
27+
self.num_operands = 0
28+
for name, value in vars(self).items():
29+
if name.endswith("_dtype"):
30+
self.num_operands += 1
31+
32+
33+
@benchmark_registry.register("einsum")
34+
class PaddleEinsum(PaddleOpBenchmarkBase):
35+
def build_graph(self, config):
36+
if config.num_operands == 2:
37+
x = self.variable(
38+
name="x", shape=config.x_shape, dtype=config.x_dtype)
39+
y = self.variable(
40+
name="y", shape=config.y_shape, dtype=config.y_dtype)
41+
result = paddle.einsum(config.equation, x, y)
42+
43+
self.feed_list = [x, y]
44+
self.fetch_list = [result]
45+
if config.backward:
46+
self.append_gradients(result, [x, y])
47+
48+
49+
@benchmark_registry.register("einsum")
50+
class TorchEinsum(PytorchOpBenchmarkBase):
51+
def build_graph(self, config):
52+
if config.num_operands == 2:
53+
x = self.variable(
54+
name="x", shape=config.x_shape, dtype=config.x_dtype)
55+
y = self.variable(
56+
name="y", shape=config.y_shape, dtype=config.y_dtype)
57+
result = torch.einsum(config.equation, x, y).contiguous()
58+
59+
self.feed_list = [x, y]
60+
self.fetch_list = [result]
61+
if config.backward:
62+
self.append_gradients(result, [x, y])
File renamed without changes.

api/tests_v2/alphafold/einsum.json

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
[{
2+
"op": "einsum",
3+
"param_info": {
4+
"equation": {
5+
"type": "string",
6+
"value": "nqkc,ch->nhqk"
7+
},
8+
"x": {
9+
"dtype": "float32",
10+
"shape": "[1L, 256L, 256L, 128L]",
11+
"type": "Variable"
12+
},
13+
"y": {
14+
"dtype": "float32",
15+
"shape": "[128L, 8L]",
16+
"type": "Variable"
17+
}
18+
}
19+
}, {
20+
"op": "einsum",
21+
"param_info": {
22+
"equation": {
23+
"type": "string",
24+
"value": "nabc,nadc->nbdc"
25+
},
26+
"x": {
27+
"dtype": "float32",
28+
"shape": "[1L, 128L, 256L, 1L]",
29+
"type": "Variable"
30+
},
31+
"y": {
32+
"dtype": "float32",
33+
"shape": "[1L, 128L, 256L, 1L]",
34+
"type": "Variable"
35+
}
36+
}
37+
}, {
38+
"op": "einsum",
39+
"param_info": {
40+
"equation": {
41+
"type": "string",
42+
"value": "nacb,nade->ndceb"
43+
},
44+
"x": {
45+
"dtype": "float32",
46+
"shape": "[1L, 128L, 32L, 256L]",
47+
"type": "Variable"
48+
},
49+
"y": {
50+
"dtype": "float32",
51+
"shape": "[1L, 128L, 256L, 32L]",
52+
"type": "Variable"
53+
}
54+
}
55+
}, {
56+
"op": "einsum",
57+
"param_info": {
58+
"equation": {
59+
"type": "string",
60+
"value": "ndceb,cef->ndbf"
61+
},
62+
"x": {
63+
"dtype": "float32",
64+
"shape": "[1L, 256L, 32L, 32L, 256L]",
65+
"type": "Variable"
66+
},
67+
"y": {
68+
"dtype": "float32",
69+
"shape": "[32L, 32L, 128L]",
70+
"type": "Variable"
71+
}
72+
}
73+
}, {
74+
"op": "einsum",
75+
"param_info": {
76+
"equation": {
77+
"type": "string",
78+
"value": "bikc,bjkc->bijc"
79+
},
80+
"x": {
81+
"dtype": "float32",
82+
"shape": "[1L, 256L, 256L, 128L]",
83+
"type": "Variable"
84+
},
85+
"y": {
86+
"dtype": "float32",
87+
"shape": "[1L, 256L, 256L, 128L]",
88+
"type": "Variable"
89+
}
90+
}
91+
}, {
92+
"op": "einsum",
93+
"param_info": {
94+
"equation": {
95+
"type": "string",
96+
"value": "bkjc,bkic->bijc"
97+
},
98+
"x": {
99+
"dtype": "float32",
100+
"shape": "[1L, 256L, 256L, 128L]",
101+
"type": "Variable"
102+
},
103+
"y": {
104+
"dtype": "float32",
105+
"shape": "[1L, 256L, 256L, 128L]",
106+
"type": "Variable"
107+
}
108+
}
109+
}, {
110+
"op": "einsum",
111+
"param_info": {
112+
"equation": {
113+
"type": "string",
114+
"value": "bqkc,ch->bhqk"
115+
},
116+
"x": {
117+
"dtype": "float32",
118+
"shape": "[1L, 256L, 256L, 128L]",
119+
"type": "Variable"
120+
},
121+
"y": {
122+
"dtype": "float32",
123+
"shape": "[128L, 4L]",
124+
"type": "Variable"
125+
}
126+
}
127+
}]

api/tests_v2/configs/einsum.json

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
[{
2+
"op": "einsum",
3+
"param_info": {
4+
"equation": {
5+
"type": "string",
6+
"value": "nabc,nadc->nbdc"
7+
},
8+
"x": {
9+
"dtype": "float32",
10+
"shape": "[1L, 128L, 256L, 1L]",
11+
"type": "Variable"
12+
},
13+
"y": {
14+
"dtype": "float32",
15+
"shape": "[1L, 128L, 256L, 1L]",
16+
"type": "Variable"
17+
}
18+
}
19+
}, {
20+
"op": "einsum",
21+
"param_info": {
22+
"equation": {
23+
"type": "string",
24+
"value": "nacb,nade->ndceb"
25+
},
26+
"x": {
27+
"dtype": "float32",
28+
"shape": "[1L, 128L, 32L, 256L]",
29+
"type": "Variable"
30+
},
31+
"y": {
32+
"dtype": "float32",
33+
"shape": "[1L, 128L, 256L, 32L]",
34+
"type": "Variable"
35+
}
36+
}
37+
}, {
38+
"op": "einsum",
39+
"param_info": {
40+
"equation": {
41+
"type": "string",
42+
"value": "ndceb,cef->ndbf"
43+
},
44+
"x": {
45+
"dtype": "float32",
46+
"shape": "[1L, 256L, 32L, 32L, 256L]",
47+
"type": "Variable"
48+
},
49+
"y": {
50+
"dtype": "float32",
51+
"shape": "[32L, 32L, 128L]",
52+
"type": "Variable"
53+
}
54+
}
55+
}]

0 commit comments

Comments
 (0)