This repository was archived by the owner on May 12, 2024. It is now read-only.

Commit 7b55f61

Added Unity Barracuda support option for GatherND (gather_nd)
1 parent 37297b9 commit 7b55f61

File tree

4 files changed (+121 lines, -16 lines)

README.md

Lines changed: 4 additions & 0 deletions
@@ -350,6 +350,7 @@ usage: tflite2tensorflow
        [--optimizing_for_edgetpu]
        [--replace_prelu_and_minmax]
        [--disable_experimental_new_quantizer]
+       [--optimizing_barracuda]
        [--locationids_of_the_terminating_output]

 optional arguments:
@@ -455,6 +456,9 @@ optional arguments:
   --disable_experimental_new_quantizer
         Disable MLIR's new quantization feature during INT8 quantization
         in TensorFlowLite.
+  --optimizing_barracuda
+        Generates ONNX by replacing Barracuda's unsupported layers
+        with standard layers. For example, GatherND.
   --locationids_of_the_terminating_output
         A comma-separated list of LocationIDs to be used as output layers.
         e.g. --locationids_of_the_terminating_output 100,201,560
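
For readers unfamiliar with the replacement, the idea behind --optimizing_barracuda can be sketched with plain TF ops. This is a minimal, hedged illustration only; the shapes are made up and the code is not taken from this commit, but it mirrors the flat-index trick that the commit's barracuda_gather_nd helper implements with Reshape, Mul, ReduceSum, and Gather.

    import tensorflow as tf

    # Illustrative shapes only: a [48, 48, 32] feature map and [8, 8, 2] integer (Y, X) indices.
    params = tf.random.normal([48, 48, 32])
    indices = tf.random.uniform([8, 8, 2], minval=0, maxval=48, dtype=tf.int32)

    # GatherND rewritten with simpler ops: flatten the two leading axes of params,
    # turn each (Y, X) pair into the flat offset Y*48 + X, then use a plain Gather.
    params_flat = tf.reshape(params, [-1, 32])         # [2304, 32]
    flat_idx = indices[..., 0] * 48 + indices[..., 1]  # [8, 8]
    replaced = tf.gather(params_flat, flat_idx)        # [8, 8, 32]

    reference = tf.gather_nd(params, indices)          # [8, 8, 32], same values as `replaced`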

setup.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 setup(
     name="tflite2tensorflow",
     scripts=scripts,
-    version="1.20.6",
+    version="1.20.7",
     description="Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite, ONNX, OpenVINO, Myriad Inference Engine blob and .pb from .tflite.",
     long_description=long_description,
     long_description_content_type="text/markdown",

tflite2tensorflow/mediapipeCustomOp.py

Lines changed: 72 additions & 7 deletions
@@ -24,9 +24,35 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

+import sys
 import tensorflow.compat.v1 as tf
 import numpy as np

+class Color:
+    BLACK = '\033[30m'
+    RED = '\033[31m'
+    GREEN = '\033[32m'
+    YELLOW = '\033[33m'
+    BLUE = '\033[34m'
+    MAGENTA = '\033[35m'
+    CYAN = '\033[36m'
+    WHITE = '\033[37m'
+    COLOR_DEFAULT = '\033[39m'
+    BOLD = '\033[1m'
+    UNDERLINE = '\033[4m'
+    INVISIBLE = '\033[08m'
+    REVERCE = '\033[07m'
+    BG_BLACK = '\033[40m'
+    BG_RED = '\033[41m'
+    BG_GREEN = '\033[42m'
+    BG_YELLOW = '\033[43m'
+    BG_BLUE = '\033[44m'
+    BG_MAGENTA = '\033[45m'
+    BG_CYAN = '\033[46m'
+    BG_WHITE = '\033[47m'
+    BG_DEFAULT = '\033[49m'
+    RESET = '\033[0m'
+
 #Affine transform points
 def TransformLandmarks(operator, custom_options, tensors, interpreter, landmarks2d=None, mat=None):
     if landmarks2d is None:
@@ -46,7 +72,7 @@ def TransformLandmarks(operator, custom_options, tensors, interpreter, landmarks
     return landmarks2d_transformed

 #Affine transform images using bilinear interpolation
-def TransformTensorBilinear(operator, custom_options, tensors, interpreter, features=None, mat=None):
+def TransformTensorBilinear(operator, custom_options, tensors, interpreter, optimizing_barracuda, features=None, mat=None):
     if features is None:
         features = tensors[operator['inputs'][0]] #float32 [b,48,48,32] feature maps
     if mat is None:
@@ -102,11 +128,46 @@ def TransformTensorBilinear(operator, custom_options, tensors, interpreter, feat
     in_coord_floor = tf.concat([in_coord_floor[:,:,:,1:2], in_coord_floor[:,:,:,0:1]], axis=3) #[b,h,w,YX]
     in_coord_ceil_ = tf.concat([in_coord_ceil_[:,:,:,1:2], in_coord_ceil_[:,:,:,0:1]], axis=3) #[b,h,w,YX]

+    def barracuda_gather_nd(params, indices):
+        if len(indices.shape) == 4 and indices.shape[0] == 1:
+            indices = indices[0]
+        elif len(indices.shape) == 3:
+            pass
+        else:
+            print(f'{Color.RED}ERROR:{Color.RESET} gather_nd when optimizing_barracuda is enabled must have 4 dimensions and batch size = 1 or 3 dimensions.')
+            print(f'{Color.RED}ERROR:{Color.RESET} params.shape: {params.shape}, indices.shape: {indices.shape}')
+            sys.exit(-1)
+        if len(params.shape) == 4 and params.shape[0] == 1:
+            params = params[0]
+        elif len(params.shape) == 3:
+            pass
+        else:
+            print(f'{Color.RED}ERROR:{Color.RESET} gather_nd when optimizing_barracuda is enabled must have 4 dimensions and batch size = 1 or 3 dimensions.')
+            print(f'{Color.RED}ERROR:{Color.RESET} params.shape: {params.shape}, indices.shape: {indices.shape}')
+            sys.exit(-1)
+        idx_shape = indices.shape
+        params_shape = params.shape
+        idx_dims = idx_shape[-1]
+        gather_shape = params_shape[idx_dims:]
+        params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
+        axis_step = tf.math.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
+        mul = tf.math.multiply(indices, axis_step)
+        indices_flat = tf.reduce_sum(mul, axis=-1)
+        result_flat = tf.gather(params_flat, indices_flat)
+        return tf.expand_dims(tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0)), axis=0)
+
     # calc final pixel value
-    value_floor = tf.gather_nd(params=features, indices=in_coord_floor, batch_dims=1) #[b,h,w,32]
-    value_ceilX = tf.gather_nd(params=features, indices=in_coord_ceilX, batch_dims=1) #[b,h,w,32]
-    value_ceilY = tf.gather_nd(params=features, indices=in_coord_ceilY, batch_dims=1) #[b,h,w,32]
-    value_ceil_ = tf.gather_nd(params=features, indices=in_coord_ceil_, batch_dims=1) #[b,h,w,32]
+    if not optimizing_barracuda:
+        value_floor = tf.gather_nd(params=features, indices=in_coord_floor, batch_dims=1) #[b,h,w,32]
+        value_ceilX = tf.gather_nd(params=features, indices=in_coord_ceilX, batch_dims=1) #[b,h,w,32]
+        value_ceilY = tf.gather_nd(params=features, indices=in_coord_ceilY, batch_dims=1) #[b,h,w,32]
+        value_ceil_ = tf.gather_nd(params=features, indices=in_coord_ceil_, batch_dims=1) #[b,h,w,32]
+    else:
+        value_floor = barracuda_gather_nd(params=features, indices=in_coord_floor) #[b,h,w,32]
+        value_ceilX = barracuda_gather_nd(params=features, indices=in_coord_ceilX) #[b,h,w,32]
+        value_ceilY = barracuda_gather_nd(params=features, indices=in_coord_ceilY) #[b,h,w,32]
+        value_ceil_ = barracuda_gather_nd(params=features, indices=in_coord_ceil_) #[b,h,w,32]
+
     value_floor_fraction = tf.multiply(value_floor, weight_floor)
     value_ceil__fraction = tf.multiply(value_ceil_, weight_ceil_)
     value_ceilX_fraction = tf.multiply(value_ceilX, weight_ceilX)
@@ -132,8 +193,12 @@ def Landmarks2TransformMatrix(operator, custom_options, tensors, interpreter, la
     ######################################
     # calc rotation
     ######################################
-    rot90_t = tf.constant([[ 0.0, 1.0],
-                           [ -1.0, 0.0]]) #[2,2], already transposed
+    rot90_t = tf.constant(
+        [
+            [ 0.0, 1.0],
+            [ -1.0, 0.0]
+        ]
+    ) #[2,2], already transposed

     idx_rot_l = custom_options['left_rotation_idx']
     idx_rot_r = custom_options['right_rotation_idx']
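
The least obvious step in the barracuda_gather_nd helper above is the stride computation. A small sketch (values are illustrative, run eagerly with TF 2.x) of what the exclusive, reversed cumprod produces for a [48, 48, 32] params tensor and 2-component (Y, X) indices:

    import tensorflow as tf

    # Leading dims that get flattened when idx_dims == 2 and params is [48, 48, 32].
    leading = tf.constant([48, 48])

    # exclusive=True, reverse=True gives the row-major strides [48, 1], so the flat
    # offset is Y*48 + X*1, i.e. an index into tf.reshape(params, [-1, 32]).
    strides = tf.math.cumprod(leading, exclusive=True, reverse=True)
    print(strides.numpy())                               # [48  1]

    yx = tf.constant([[3, 7]])                           # a single (Y, X) pair
    print(tf.reduce_sum(yx * strides, axis=-1).numpy())  # [151] == 3*48 + 7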

tflite2tensorflow/tflite2tensorflow.py

Lines changed: 44 additions & 8 deletions
@@ -418,6 +418,7 @@ def make_graph(
     optimizing_for_openvino_and_myriad,
     rigorous_optimization_for_myriad,
     optimizing_for_coreml,
+    optimizing_barracuda,
 ):

     import tensorflow.compat.v1 as tf
@@ -3455,11 +3456,44 @@ def pad_v2(x, paddings, constant_values):
             except:
                 input_tensor2 = interpreter.get_tensor(positions_detail['index'])
             output_detail = interpreter._get_tensor_details(op['outputs'][0])
-            output_tensor = tf.gather_nd(
-                input_tensor1,
-                input_tensor2,
-                name=get_op_name(output_detail['name'])
-            )
+
+            def barracuda_gather_nd(params, indices):
+                if len(indices.shape) == 4 and indices.shape[0] == 1:
+                    indices = indices[0]
+                elif len(indices.shape) == 3:
+                    pass
+                else:
+                    print(f'{Color.RED}ERROR:{Color.RESET} gather_nd when optimizing_barracuda is enabled must have 4 dimensions and batch size = 1 or 3 dimensions.')
+                    print(f'{Color.RED}ERROR:{Color.RESET} params.shape: {params.shape}, indices.shape: {indices.shape}')
+                    sys.exit(-1)
+                if len(params.shape) == 4 and params.shape[0] == 1:
+                    params = params[0]
+                elif len(params.shape) == 3:
+                    pass
+                else:
+                    print(f'{Color.RED}ERROR:{Color.RESET} gather_nd when optimizing_barracuda is enabled must have 4 dimensions and batch size = 1 or 3 dimensions.')
+                    print(f'{Color.RED}ERROR:{Color.RESET} params.shape: {params.shape}, indices.shape: {indices.shape}')
+                    sys.exit(-1)
+                idx_shape = indices.shape
+                params_shape = params.shape
+                idx_dims = idx_shape[-1]
+                gather_shape = params_shape[idx_dims:]
+                params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
+                axis_step = tf.math.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
+                mul = tf.math.multiply(indices, axis_step)
+                indices_flat = tf.reduce_sum(mul, axis=-1)
+                result_flat = tf.gather(params_flat, indices_flat)
+                return tf.expand_dims(tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0)), axis=0)
+
+            if not optimizing_barracuda:
+                output_tensor = tf.gather_nd(
+                    input_tensor1,
+                    input_tensor2,
+                    name=get_op_name(output_detail['name'])
+                )
+            else:
+                output_tensor = barracuda_gather_nd(input_tensor1, input_tensor2)
+
             tensors[output_detail['index']] = output_tensor

         elif op_type == 'COS':
@@ -5081,7 +5115,6 @@ def complexabs_(x, tout):
             )
             tensors[output_detail['index']] = output_tensor

-
         # MediaPipe v0.8.9
         elif custom_op_type == 'Landmarks2TransformMatrix':
             options = op['custom_options']
@@ -5093,7 +5126,7 @@ def complexabs_(x, tout):
             options = op['custom_options']
             custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
             output_detail = interpreter._get_tensor_details(op['outputs'][0])
-            tensors[output_detail['index']] = TransformTensorBilinear(op, custom_options, tensors, interpreter)
+            tensors[output_detail['index']] = TransformTensorBilinear(op, custom_options, tensors, interpreter, optimizing_barracuda)

         elif custom_op_type == 'TransformLandmarks':
             custom_options = None
@@ -5642,6 +5675,7 @@ def main():
     parser.add_argument('--optimizing_for_edgetpu', action='store_true', help='Optimizing for edgetpu')
     parser.add_argument('--replace_prelu_and_minmax', action='store_true', help='Replace prelu and minimum/maximum with each other')
     parser.add_argument('--disable_experimental_new_quantizer', action='store_true', help='Disable MLIR\'s new quantization feature during INT8 quantization in TensorFlowLite.')
+    parser.add_argument('--optimizing_barracuda', action='store_true', help='Generates ONNX by replacing Barracuda\'s unsupported layers with standard layers.')
     parser.add_argument('--locationids_of_the_terminating_output', type=str, default='', help='A comma-separated list of location IDs to be used as output layers. Default: \'\'')
     args = parser.parse_args()

@@ -5691,6 +5725,7 @@ def main():
     optimizing_for_edgetpu = args.optimizing_for_edgetpu
     replace_prelu_and_minmax = args.replace_prelu_and_minmax
     use_experimental_new_quantizer = not args.disable_experimental_new_quantizer
+    optimizing_barracuda = args.optimizing_barracuda
     locationids_of_the_terminating_output_tmp = args.locationids_of_the_terminating_output
     locationids_of_the_terminating_output = None
     if locationids_of_the_terminating_output_tmp:
@@ -5844,7 +5879,8 @@ def main():
         optimizing_for_edgetpu_flg,
         optimizing_for_openvino_and_myriad,
         rigorous_optimization_for_myriad,
-        optimizing_for_coreml
+        optimizing_for_coreml,
+        optimizing_barracuda
     )
     print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
     print('outputs:')
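
As a quick sanity check of the replacement logic used for GATHER_ND, the steps of barracuda_gather_nd can be restated in a few self-contained lines and compared against tf.gather_nd with batch_dims=1. This is a condensed, hedged sketch with made-up shapes (batch size 1), not code from the repository:

    import numpy as np
    import tensorflow as tf

    feat = tf.random.normal([1, 48, 48, 32])                                    # [b, h, w, c], b == 1
    idx = tf.random.uniform([1, 8, 8, 2], minval=0, maxval=48, dtype=tf.int32)  # [b, h, w, (Y, X)]

    # Condensed restatement of the barracuda_gather_nd steps shown in the diff above.
    p, i = feat[0], idx[0]                                                # drop the unit batch dim
    params_flat = tf.reshape(p, [-1, p.shape[-1]])                        # [48*48, 32]
    strides = tf.math.cumprod(p.shape[:2], exclusive=True, reverse=True)  # [48, 1]
    flat_idx = tf.reduce_sum(i * strides, axis=-1)                        # [8, 8]
    alt = tf.expand_dims(tf.gather(params_flat, flat_idx), axis=0)        # [1, 8, 8, 32]

    ref = tf.gather_nd(params=feat, indices=idx, batch_dims=1)            # [1, 8, 8, 32]
    print(np.allclose(ref.numpy(), alt.numpy()))                          # expected: True

The Gather/Reshape/Mul/ReduceSum form avoids the GatherND layer that the commit describes as unsupported by Barracuda, at the cost of assuming a unit batch dimension (which the helper's error messages enforce).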

0 commit comments
