Skip to content
This repository was archived by the owner on May 12, 2024. It is now read-only.

Commit d337c54

Browse files
authored
Merge pull request #28 from KenjiAsaba/mediapipeCustomOp_FaceLandMarkWithAttention
Thank you so much!
2 parents d0a2bd5 + c46f01c commit d337c54

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
###############################################################################
2+
# TensorFlow implementation of MediaPipe custom operators
3+
###############################################################################
4+
#
5+
# MIT License
6+
#
7+
# Copyright (c) 2022 Akiya Research Institute, Inc.
8+
#
9+
# Permission is hereby granted, free of charge, to any person obtaining a copy
10+
# of this software and associated documentation files (the "Software"), to deal
11+
# in the Software without restriction, including without limitation the rights
12+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
# copies of the Software, and to permit persons to whom the Software is
14+
# furnished to do so, subject to the following conditions:
15+
#
16+
# The above copyright notice and this permission notice shall be included in all
17+
# copies or substantial portions of the Software.
18+
#
19+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25+
# SOFTWARE.
26+
27+
import tensorflow.compat.v1 as tf
28+
import numpy as np
29+
30+
#Affine transform points
31+
def TransformLandmarks(operator, custom_options, tensors, interpreter, landmarks2d=None, mat=None):
    """Apply a 2D affine transform to a batch of landmark points.

    Every point ``p`` is mapped to ``R @ p + t`` where ``R`` is the upper-left
    2x2 block of ``mat`` and ``t`` is rows 0-1 of its fourth column.

    Args:
        operator: Operator description. ``operator['inputs'][0]`` and
            ``operator['inputs'][1]`` are used as keys into ``tensors`` when
            ``landmarks2d`` / ``mat`` are not supplied.
        custom_options: Not used by this op.
        tensors: Mapping from tensor index to tensor.
        interpreter: Not used by this op.
        landmarks2d: Optional landmark coordinates, float32 [b,n,2].
        mat: Optional affine transform matrix, float32 [b,4,4].

    Returns:
        The transformed landmarks, [b,n,2].
    """
    if landmarks2d is None:
        landmarks2d = tensors[operator['inputs'][0]]  # float32 [b,n,2] landmarks 2d
    if mat is None:
        mat = tensors[operator['inputs'][1]]  # float32 [b,4,4] affine transform matrix

    # Split the matrix into its linear part and its translation part.
    mat_rot = mat[:, 0:2, 0:2]  # [b,2,2]
    translation = mat[:, 0:2, 3:4]  # [b,2,1]
    # -1 lets TensorFlow infer the batch size, so this also works when the
    # static batch dimension of the input is unknown.
    translation = tf.reshape(translation, [-1, 1, 2])  # [b,1,2]

    # Row-vector form of R @ p + t: p_row @ R^T + t_row.
    landmarks2d_transformed = tf.matmul(landmarks2d, mat_rot, transpose_b=True)  # [b,n,2]
    landmarks2d_transformed = tf.add(landmarks2d_transformed, translation)  # [b,n,2]
    return landmarks2d_transformed
47+
48+
#Affine transform images using bilinear interpolation
49+
def TransformTensorBilinear(operator, custom_options, tensors, interpreter, features=None, mat=None):
    """Resample feature maps through an affine transform with bilinear interpolation.

    For every output pixel ``(x, y)`` the source position ``R @ (x, y) + t`` is
    computed (``R`` = upper-left 2x2 block of ``mat``, ``t`` = rows 0-1 of its
    fourth column) and the four surrounding input pixels are blended with
    bilinear weights. Sample indices are clamped to the valid index range of
    the input before gathering.

    Args:
        operator: Operator description. ``operator['inputs'][0]`` and
            ``operator['inputs'][1]`` are used as keys into ``tensors`` when
            ``features`` / ``mat`` are not supplied.
        custom_options: Mapping providing ``'output_width'`` and
            ``'output_height'``.
        tensors: Mapping from tensor index to tensor.
        interpreter: Not used by this op.
        features: Optional feature maps, float32 [b,in_h,in_w,c].
        mat: Optional affine transform matrix, float32 [b,4,4].

    Returns:
        The resampled feature maps, [b,output_height,output_width,c].
    """
    if features is None:
        features = tensors[operator['inputs'][0]]  # float32 [b,in_h,in_w,c] feature maps
    if mat is None:
        mat = tensors[operator['inputs'][1]]  # float32 [b,4,4] affine transform matrix
    w = custom_options['output_width']
    h = custom_options['output_height']
    input_h = features.shape[1]
    input_w = features.shape[2]
    # Largest valid (x, y) index. Clamping to the size itself (instead of
    # size - 1) would let gather_nd receive an out-of-range index.
    max_xy = [input_w - 1, input_h - 1]

    # Split the matrix into its linear part and its translation part.
    # The extra singleton axis makes the batch dimensions of the matmul below
    # broadcast correctly for any batch size; -1 infers the batch size.
    mat_rot = tf.reshape(mat[:, 0:2, 0:2], [-1, 1, 2, 2])  # [b,1,2,2]
    translation = tf.reshape(mat[:, 0:2, 3:4], [-1, 1, 1, 2])  # [b,1,1,2]

    # Output pixel coordinates, stored as (x, y) in the last axis:
    # out_coord[0, y, x] = [x, y]
    array_w = np.arange(w)
    array_h = np.arange(h)
    X, Y = np.meshgrid(array_w, array_h)  # each [h,w]
    out_coord = np.stack([X, Y], axis=2)  # [h,w,2]
    out_coord = np.expand_dims(out_coord, axis=0).astype(np.float32)  # [1,h,w,2]
    out_coord = tf.constant(out_coord)  # broadcast over the batch by matmul

    # Find the corresponding point in the input image.
    in_coord = tf.matmul(out_coord, mat_rot, transpose_b=True)  # [b,h,w,2]
    in_coord = tf.add(in_coord, translation)  # [b,h,w,2]

    # Bilinear weights of the 4 neighbours, from the fractional part.
    in_coord_floor = tf.floor(in_coord)  # [b,h,w,2]
    weight_ceil_ = tf.subtract(in_coord, in_coord_floor)  # [b,h,w,2] fractional part
    weight_floor = tf.subtract(tf.ones(2), weight_ceil_)  # [b,h,w,2] 1 - fractional part
    weight_ceilX = tf.multiply(weight_ceil_[:, :, :, 0:1], weight_floor[:, :, :, 1:2])  # [b,h,w,1]
    weight_ceilY = tf.multiply(weight_floor[:, :, :, 0:1], weight_ceil_[:, :, :, 1:2])  # [b,h,w,1]
    weight_ceil_ = tf.reduce_prod(weight_ceil_, axis=3, keepdims=True)  # [b,h,w,1]
    weight_floor = tf.reduce_prod(weight_floor, axis=3, keepdims=True)  # [b,h,w,1]

    # Integer indices of the 4 neighbours, clamped into the input image.
    in_coord_floor = tf.cast(in_coord_floor, dtype=tf.int32)  # [b,h,w,XY]
    in_coord_floor = tf.maximum(in_coord_floor, tf.zeros(2, dtype=tf.int32))  # [b,h,w,XY]
    in_coord_floor = tf.minimum(in_coord_floor, max_xy)  # [b,h,w,XY]

    in_coord_ceil_ = tf.add(in_coord_floor, tf.ones(2, dtype=tf.int32))  # [b,h,w,XY]
    in_coord_ceil_ = tf.minimum(in_coord_ceil_, max_xy)  # [b,h,w,XY]

    # Swap to (y, x) order because the features are indexed as [b, y, x, c].
    in_coord_ceilX = tf.concat([in_coord_floor[:, :, :, 1:2], in_coord_ceil_[:, :, :, 0:1]], axis=3)  # [b,h,w,YX]
    in_coord_ceilY = tf.concat([in_coord_ceil_[:, :, :, 1:2], in_coord_floor[:, :, :, 0:1]], axis=3)  # [b,h,w,YX]
    in_coord_floor = tf.concat([in_coord_floor[:, :, :, 1:2], in_coord_floor[:, :, :, 0:1]], axis=3)  # [b,h,w,YX]
    in_coord_ceil_ = tf.concat([in_coord_ceil_[:, :, :, 1:2], in_coord_ceil_[:, :, :, 0:1]], axis=3)  # [b,h,w,YX]

    # Gather the 4 neighbours and blend them.
    value_floor = tf.gather_nd(params=features, indices=in_coord_floor, batch_dims=1)  # [b,h,w,c]
    value_ceilX = tf.gather_nd(params=features, indices=in_coord_ceilX, batch_dims=1)  # [b,h,w,c]
    value_ceilY = tf.gather_nd(params=features, indices=in_coord_ceilY, batch_dims=1)  # [b,h,w,c]
    value_ceil_ = tf.gather_nd(params=features, indices=in_coord_ceil_, batch_dims=1)  # [b,h,w,c]
    value_floor_fraction = tf.multiply(value_floor, weight_floor)
    value_ceil__fraction = tf.multiply(value_ceil_, weight_ceil_)
    value_ceilX_fraction = tf.multiply(value_ceilX, weight_ceilX)
    value_ceilY_fraction = tf.multiply(value_ceilY, weight_ceilY)

    # [b,h,w,c]
    value = tf.add(
        tf.add(value_floor_fraction, value_ceil__fraction),
        tf.add(value_ceilX_fraction, value_ceilY_fraction)
    )

    return value
122+
123+
# Left indexとRight indexで指定されたLandmarkを結ぶ線が水平になり、Subset indicesで指定されたLandmrakをちょうど含むような範囲をcropするように、元の画像をAffine変換する行列
124+
# の逆行列を求める。なぜ、逆行列かといういうと、後の計算で使うのが逆行列だから。
125+
# Calculate the inverse of the matrix representing the affine transform that crops the area covering all the landmarks specified by "subset indices" and rotates it so that the landmarks specified by "Left index" and "Right index" are horizontally aligned. The inverse is computed because it is what the later calculations use.
126+
def Landmarks2TransformMatrix(operator, custom_options, tensors, interpreter, landmarks3d=None):
    """Build the 4x4 matrix mapping crop (output) coordinates to source coordinates.

    The crop is aligned with the line from the "left" to the "right" rotation
    landmark and covers the bounding box, in that rotated frame, of the
    landmarks selected by ``subset_idxs``. The returned matrix is the inverse
    of the cropping transform: it takes a pixel position in the cropped image
    to the matching position in the original image.

    Args:
        operator: Operator description. ``operator['inputs'][0]`` is used as
            the key into ``tensors`` when ``landmarks3d`` is not supplied.
        custom_options: Mapping providing ``'left_rotation_idx'``,
            ``'right_rotation_idx'``, ``'subset_idxs'``, ``'output_width'``,
            ``'output_height'``, ``'scale_x'`` and ``'scale_y'``.
        tensors: Mapping from tensor index to tensor.
        interpreter: Not used by this op.
        landmarks3d: Optional landmarks, float32 [b,468,3]. The batch size is
            read from the static shape and used to build constants.

    Returns:
        Tensor [b,4,4]:
            [[sx*dx, -sy*dy, 0, tx],
             [sx*dy,  sy*dx, 0, ty],
             [    0,      0, 1,  0],
             [    0,      0, 0,  1]]
    """
    if landmarks3d is None:
        landmarks3d = tensors[operator['inputs'][0]]  # float32 [b,468,3] landmarks

    # Read all options up front, in the order they are consumed below.
    left_idx = custom_options['left_rotation_idx']
    right_idx = custom_options['right_rotation_idx']
    subset_idxs = custom_options['subset_idxs']
    out_w = custom_options['output_width']
    out_h = custom_options['output_height']
    scale_x = custom_options['scale_x']
    scale_y = custom_options['scale_y']

    points_xy = landmarks3d[:, :, 0:2]  # [b,468,2]
    batch = landmarks3d.shape[0]

    # ---------------------------------------------------------------- rotation
    # Right-multiplying a row vector [dx, dy] by this gives [-dy, dx].
    quarter_turn_t = tf.constant([[0.0, 1.0],
                                  [-1.0, 0.0]])  # [2,2], already transposed

    left_pt = points_xy[:, left_idx:left_idx + 1, :]  # [b,1,2]
    right_pt = points_xy[:, right_idx:right_idx + 1, :]  # [b,1,2]

    direction = tf.subtract(right_pt, left_pt)  # [b,1,2]
    # NOTE(review): no guard for coinciding left/right landmarks (zero length).
    direction_len = tf.norm(direction, axis=2, keepdims=True)  # [b,1,1]

    axis_u = tf.divide(direction, direction_len)  # [b,1,2] = [[ dx, dy]]
    axis_v = tf.matmul(axis_u, quarter_turn_t)  # [b,1,2] = [[-dy, dx]]

    # crop_to_src = [[ dx, dy],      src_to_crop = [[ dx, -dy],
    #                [-dy, dx]]                     [ dy,  dx]]
    # crop_to_src rotates from cropped-image orientation to source-image
    # orientation; src_to_crop is its transpose and rotates the other way.
    crop_to_src = tf.concat([axis_u, axis_v], axis=1)  # [b,2,2]
    src_to_crop = tf.transpose(crop_to_src, perm=[0, 2, 1])  # [b,2,2]

    # ---------------------------------------------------- crop size and center
    subset_pts = tf.gather(points_xy, indices=subset_idxs, axis=1)  # [b,k,2]
    # Rotate the source-image landmarks into the orientation of the crop.
    subset_rot = tf.matmul(subset_pts, src_to_crop)  # [b,k,2]
    lo = tf.reduce_min(subset_rot, axis=1, keepdims=True)  # [b,1,2]
    hi = tf.reduce_max(subset_rot, axis=1, keepdims=True)  # [b,1,2]

    extent = tf.subtract(hi, lo)  # [b,1,2], max - min
    center = tf.multiply(tf.add(lo, hi), tf.constant(0.5))  # [b,1,2], (max + min) / 2
    # Rotate the center back from crop orientation to source orientation.
    center = tf.matmul(center, crop_to_src)  # [b,1,2]

    # ------------------------------------------------------------------- scale
    # scale = [scale_x * extent.x / out_w, scale_y * extent.y / out_h]
    per_pixel = tf.constant([[scale_x / out_w, scale_y / out_h]])  # [1,2]
    scale = tf.multiply(per_pixel, extent)  # [b,1,2]

    # ------------------------------------------- translation and final matrix
    # translation = center - (0.5 * out_w * sx * u + 0.5 * out_h * sy * v)
    col_x = tf.multiply(axis_u, scale[:, :, 0:1])  # [b,1,2] sx * u
    col_y = tf.multiply(axis_v, scale[:, :, 1:2])  # [b,1,2] sy * v
    col_z = tf.zeros([batch, 1, 2])

    half_w_shift = tf.multiply(col_x, out_w * 0.5)  # [b,1,2]
    half_h_shift = tf.multiply(col_y, out_h * 0.5)  # [b,1,2]
    total_shift = tf.add(half_w_shift, half_h_shift)  # [b,1,2]
    col_t = tf.subtract(center, total_shift)  # [b,1,2]

    # Stack the four columns as rows, then transpose into a [2,4] top half.
    top_rows = tf.concat([col_x, col_y, col_z, col_t], axis=1)  # [b,4,2]
    top_rows = tf.transpose(top_rows, perm=[0, 2, 1])  # [b,2,4]

    # Append the constant z and w rows to complete the homogeneous matrix.
    bottom_rows = tf.tile(tf.constant([[[0.0, 0.0, 1.0, 0.0],
                                        [0.0, 0.0, 0.0, 1.0]]]), [batch, 1, 1])  # [b,2,4]
    return tf.concat([top_rows, bottom_rows], axis=1)  # [b,4,4]

tflite2tensorflow/tflite2tensorflow.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import struct
2222
import itertools
2323
import pandas as pd
24+
from mediapipeCustomOp import Landmarks2TransformMatrix, TransformTensorBilinear, TransformLandmarks
2425

2526
class Color:
2627
BLACK = '\033[30m'
@@ -5080,6 +5081,25 @@ def complexabs_(x, tout):
50805081
)
50815082
tensors[output_detail['index']] = output_tensor
50825083

5084+
5085+
# MediaPipe v0.8.9
5086+
elif custom_op_type == 'Landmarks2TransformMatrix':
5087+
options = op['custom_options']
5088+
custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
5089+
output_detail = interpreter._get_tensor_details(op['outputs'][0])
5090+
tensors[output_detail['index']] = Landmarks2TransformMatrix(op, custom_options, tensors, interpreter)
5091+
5092+
elif custom_op_type == 'TransformTensorBilinear':
5093+
options = op['custom_options']
5094+
custom_options = read_flexbuffer(np.array(options, dtype=np.uint8).tobytes())
5095+
output_detail = interpreter._get_tensor_details(op['outputs'][0])
5096+
tensors[output_detail['index']] = TransformTensorBilinear(op, custom_options, tensors, interpreter)
5097+
5098+
elif custom_op_type == 'TransformLandmarks':
5099+
custom_options = None
5100+
output_detail = interpreter._get_tensor_details(op['outputs'][0])
5101+
tensors[output_detail['index']] = TransformLandmarks(op, custom_options, tensors, interpreter)
5102+
50835103
elif custom_op_type == 'FlexRFFT':
50845104
input_tensor1 = None
50855105
try:

0 commit comments

Comments
 (0)