Skip to content
This repository was archived by the owner on May 12, 2024. It is now read-only.

Commit 37297b9

Browse files
committed
transform_landmarks, transform_tensor_bilinear, landmarks_to_transform_matrix
1 parent d337c54 commit 37297b9

File tree

6 files changed

+21
-14
lines changed

6 files changed

+21
-14
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite, ONNX, Ope
142142
|125|Densify|(const)||
143143
|126|SPACE_TO_BATCH_ND|tf.space_to_batch_nd||
144144
|127|BATCH_TO_SPACE_ND|tf.compat.v1.batch_to_space_nd||
145+
|128|TransformLandmarks|tf.reshape, tf.linalg.matmul, tf.math.add|CUSTOM, MediaPipe|
146+
|129|TransformTensorBilinear|tf.reshape, tf.linalg.matmul, tf.math.add, tf.tile, tf.math.floor, tf.math.subtract, tf.math.multiply, tf.math.reduce_prod, tf.cast, tf.math.maximum, tf.math.maximum, tf.concat, tf.gather_nd|CUSTOM, MediaPipe|
147+
|130|Landmarks2TransformMatrix|tf.constant, tf.math.subtract, tf.math.norm, tf.math.divide, tf.linalg.matmul, tf.concat, tf.transpose, tf.gather, tf.math.reduce_min, tf.math.reduce_max, tf.math.multiply, tf.zeros, tf.math.add, tf.tile|CUSTOM, MediaPipe|
145148

146149
## 2. Environment
147150
- Python3.8+
@@ -549,7 +552,9 @@ $ tflite2tensorflow \
549552
$ view_npy --npy_file_path calibration_data_img_sample.npy
550553
```
551554
Press the **`Q`** button to display the next image. **`calibration_data_img_sample.npy`** contains 20 images extracted from the MS-COCO data set.
552-
![ezgif com-gif-maker](https://user-images.githubusercontent.com/33194443/109318923-aba15480-7891-11eb-84aa-034f77125f34.gif)
555+
556+
![image](https://user-images.githubusercontent.com/33194443/160409583-66c45d47-636b-442c-94d6-51ad4170cc9b.png)
557+
553558
## 5. Sample image
554559
This is the result of converting MediaPipe's Meet Segmentation model (segm_full_v679.tflite / Float16 / Google Meet) to **`saved_model`** and then reconverting it to Float32 tflite. Replace the GPU-optimized **`Convolution2DTransposeBias`** layer with the standard **`TransposeConv`** and **`BiasAdd`** layers in a fully automatic manner. The weights and biases of the Float16 **`Dequantize`** layer are automatically back-quantized to Float32 precision. The generated **`saved_model`** in Float32 precision can be easily converted to **`Float16`**, **`INT8`**, **`EdgeTPU`**, **`TFJS`**, **`TF-TRT`**, **`CoreML`**, **`ONNX`**, **`OpenVINO`**, **`Myriad Inference Engine blob`**.
555560

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
setup(
1212
name="tflite2tensorflow",
1313
scripts=scripts,
14-
version="1.20.5",
14+
version="1.20.6",
1515
description="Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite, ONNX, OpenVINO, Myriad Inference Engine blob and .pb from .tflite.",
1616
long_description=long_description,
1717
long_description_content_type="text/markdown",

tflite2tensorflow/__init__.py

Whitespace-only changes.

tflite2tensorflow/mediapipeCustomOp.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
###############################################################################
44
#
55
# MIT License
6-
#
6+
#
77
# Copyright (c) 2022 Akiya Research Institute, Inc.
8-
#
8+
#
99
# Permission is hereby granted, free of charge, to any person obtaining a copy
1010
# of this software and associated documentation files (the "Software"), to deal
1111
# in the Software without restriction, including without limitation the rights
1212
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1313
# copies of the Software, and to permit persons to whom the Software is
1414
# furnished to do so, subject to the following conditions:
15-
#
15+
#
1616
# The above copyright notice and this permission notice shall be included in all
1717
# copies or substantial portions of the Software.
18-
#
18+
#
1919
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
2020
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
2121
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -87,7 +87,7 @@ def TransformTensorBilinear(operator, custom_options, tensors, interpreter, feat
8787
weight_ceil_ = tf.reduce_prod(weight_ceil_, axis=3, keepdims=True) #[b,h,w,1]
8888
weight_floor = tf.reduce_prod(weight_floor, axis=3, keepdims=True) #[b,h,w,1]
8989

90-
# Find nearest 4 points.
90+
# Find nearest 4 points.
9191
# Make sure they are in the input image
9292
in_coord_floor = tf.cast(in_coord_floor, dtype=tf.int32) #[b,h,w,XY]
9393
in_coord_floor = tf.maximum(in_coord_floor, tf.zeros(2, dtype=tf.int32)) #[b,h,w,XY]
@@ -97,7 +97,7 @@ def TransformTensorBilinear(operator, custom_options, tensors, interpreter, feat
9797
# in_coord_ceil_ = tf.maximum(in_coord_ceil_, tf.zeros(2, dtype=tf.int32)) #[b,h,w,XY]
9898
in_coord_ceil_ = tf.minimum(in_coord_ceil_, [input_w, input_h]) #[b,h,w,XY]
9999

100-
in_coord_ceilX = tf.concat([in_coord_floor[:,:,:,1:2], in_coord_ceil_[:,:,:,0:1]], axis=3) #[b,h,w,YX] YX for BHWC
100+
in_coord_ceilX = tf.concat([in_coord_floor[:,:,:,1:2], in_coord_ceil_[:,:,:,0:1]], axis=3) #[b,h,w,YX] YX for BHWC
101101
in_coord_ceilY = tf.concat([in_coord_ceil_[:,:,:,1:2], in_coord_floor[:,:,:,0:1]], axis=3) #[b,h,w,YX]
102102
in_coord_floor = tf.concat([in_coord_floor[:,:,:,1:2], in_coord_floor[:,:,:,0:1]], axis=3) #[b,h,w,YX]
103103
in_coord_ceil_ = tf.concat([in_coord_ceil_[:,:,:,1:2], in_coord_ceil_[:,:,:,0:1]], axis=3) #[b,h,w,YX]
@@ -122,7 +122,7 @@ def TransformTensorBilinear(operator, custom_options, tensors, interpreter, feat
122122

123123
# Left indexとRight indexで指定されたLandmarkを結ぶ線が水平になり、Subset indicesで指定されたLandmrakをちょうど含むような範囲をcropするように、元の画像をAffine変換する行列
124124
# の逆行列を求める。なぜ、逆行列かといういうと、後の計算で使うのが逆行列だから。
125-
# Calc inverse of the matrix which represetns the affine transform which crops the area which covers all the landmarks specified by "subset indices" and rotates so that the landmarks specified by "Left index" and "Right index" are horizontally aligned.
125+
# Calc inverse of the matrix which represetns the affine transform which crops the area which covers all the landmarks specified by "subset indices" and rotates so that the landmarks specified by "Left index" and "Right index" are horizontally aligned.
126126
def Landmarks2TransformMatrix(operator, custom_options, tensors, interpreter, landmarks3d=None):
127127
if landmarks3d is None:
128128
landmarks3d = tensors[operator['inputs'][0]] #float32 [b,468,3] landmarks
@@ -189,7 +189,7 @@ def Landmarks2TransformMatrix(operator, custom_options, tensors, interpreter, la
189189
#
190190
# t = center - shift
191191
#
192-
# shift = -0.5 * output_w * sx * u
192+
# shift = -0.5 * output_w * sx * u
193193
# + -0.5 * output_h * sy * v
194194
sxu = tf.multiply(u, scale[:,:,0:1]) #[b,1,2]
195195
syv = tf.multiply(v, scale[:,:,1:2]) #[b,1,2]

tflite2tensorflow/tflite2tensorflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import struct
2222
import itertools
2323
import pandas as pd
24-
from mediapipeCustomOp import Landmarks2TransformMatrix, TransformTensorBilinear, TransformLandmarks
24+
from tflite2tensorflow.mediapipeCustomOp import Landmarks2TransformMatrix, TransformTensorBilinear, TransformLandmarks
2525

2626
class Color:
2727
BLACK = '\033[30m'
@@ -5085,18 +5085,18 @@ def complexabs_(x, tout):
50855085
# MediaPipe v0.8.9
50865086
elif custom_op_type == 'Landmarks2TransformMatrix':
50875087
options = op['custom_options']
5088-
custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
5088+
custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
50895089
output_detail = interpreter._get_tensor_details(op['outputs'][0])
50905090
tensors[output_detail['index']] = Landmarks2TransformMatrix(op, custom_options, tensors, interpreter)
50915091

50925092
elif custom_op_type == 'TransformTensorBilinear':
50935093
options = op['custom_options']
5094-
custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
5094+
custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
50955095
output_detail = interpreter._get_tensor_details(op['outputs'][0])
50965096
tensors[output_detail['index']] = TransformTensorBilinear(op, custom_options, tensors, interpreter)
50975097

50985098
elif custom_op_type == 'TransformLandmarks':
5099-
custom_options = None
5099+
custom_options = None
51005100
output_detail = interpreter._get_tensor_details(op['outputs'][0])
51015101
tensors[output_detail['index']] = TransformLandmarks(op, custom_options, tensors, interpreter)
51025102

tflite2tensorflow/view_npy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#! /usr/bin/env python
2+
13
import numpy as np
24
from matplotlib import pyplot as plt
35
import argparse

0 commit comments

Comments
 (0)