|
| 1 | +############################################################################### |
| 2 | +# TensorFlow implementation of MediaPipe custom operators |
| 3 | +############################################################################### |
| 4 | +# |
| 5 | +# MIT License |
| 6 | +# |
| 7 | +# Copyright (c) 2022 Akiya Research Institute, Inc. |
| 8 | +# |
| 9 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 10 | +# of this software and associated documentation files (the "Software"), to deal |
| 11 | +# in the Software without restriction, including without limitation the rights |
| 12 | +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 13 | +# copies of the Software, and to permit persons to whom the Software is |
| 14 | +# furnished to do so, subject to the following conditions: |
| 15 | +# |
| 16 | +# The above copyright notice and this permission notice shall be included in all |
| 17 | +# copies or substantial portions of the Software. |
| 18 | +# |
| 19 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 20 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 21 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 22 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 23 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 24 | +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 25 | +# SOFTWARE. |
| 26 | + |
| 27 | +import tensorflow.compat.v1 as tf |
| 28 | +import numpy as np |
| 29 | + |
| 30 | +#Affine transform points |
| 31 | +def TransformLandmarks(operator, custom_options, tensors, interpreter, landmarks2d=None, mat=None): |
| 32 | + if landmarks2d is None: |
| 33 | + landmarks2d = tensors[operator['inputs'][0]] #float32 [b,80,2] landmarks 2d |
| 34 | + if mat is None: |
| 35 | + mat = tensors[operator['inputs'][1]] #float32 [b,4,4] affine transform matrix |
| 36 | + b = landmarks2d.shape[0] |
| 37 | + |
| 38 | + # extract important values |
| 39 | + mat_rot = mat[:,0:2,0:2] #[b,2,2] |
| 40 | + translation = mat[:,0:2,3:4] #[b,2,1] |
| 41 | + translation = tf.reshape(translation, [b,1,2]) |
| 42 | + |
| 43 | + # Find the corresponding point in the input image |
| 44 | + landmarks2d_transformed = tf.matmul(landmarks2d, mat_rot, transpose_b=True) #[b,80,2] |
| 45 | + landmarks2d_transformed = tf.add(landmarks2d_transformed, translation) #[b,80,2] |
| 46 | + return landmarks2d_transformed |
| 47 | + |
| 48 | +#Affine transform images using bilinear interpolation |
| 49 | +def TransformTensorBilinear(operator, custom_options, tensors, interpreter, features=None, mat=None): |
| 50 | + if features is None: |
| 51 | + features = tensors[operator['inputs'][0]] #float32 [b,48,48,32] feature maps |
| 52 | + if mat is None: |
| 53 | + mat = tensors[operator['inputs'][1]] #float32 [b,4,4] affine transform matrix |
| 54 | + w = custom_options['output_width'] |
| 55 | + h = custom_options['output_height'] |
| 56 | + b = features.shape[0] |
| 57 | + input_h = features.shape[1] |
| 58 | + input_w = features.shape[2] |
| 59 | + |
| 60 | + # extract important values |
| 61 | + mat_rot = mat[:,0:2,0:2] #[b,2,2] |
| 62 | + translation = mat[:,0:2,3:4] #[b,2,1] |
| 63 | + translation = tf.reshape(translation, [b,1,1,2]) |
| 64 | + |
| 65 | + # construct output image coordinates |
| 66 | + # out_coord = [[[ 0,0],[ 0,1],[ 0,2],...,[0,15]], |
| 67 | + # [[ 1,0],[ 1,1],[ 1,2],...,[1,15]], |
| 68 | + # ... |
| 69 | + # [[15,0],[15,1],[15,2],...,[15,15]]] |
| 70 | + array_w = np.arange(w) #[0,1,2,...,15] |
| 71 | + array_h = np.arange(h) #[0,1,2,...,15] |
| 72 | + X, Y = np.meshgrid(array_w, array_h) #[h,w] |
| 73 | + out_coord = np.stack([X,Y], axis=2) #[h,w,2] |
| 74 | + out_coord = np.expand_dims(out_coord, axis=0).astype(np.float32) #[1,h,w,2] |
| 75 | + out_coord = tf.tile(out_coord, [b,1,1,1]) #[b,h,w,2] |
| 76 | + |
| 77 | + # Find the corresponding point in the input image |
| 78 | + in_coord = tf.matmul(out_coord, mat_rot, transpose_b=True) #[b,h,w,2] |
| 79 | + in_coord = tf.add(in_coord, translation) #[b,h,w,2] |
| 80 | + |
| 81 | + # Find the weights for the nearest 4 points |
| 82 | + in_coord_floor = tf.floor(in_coord) #[b,h,w,2] |
| 83 | + weight_ceil_ = tf.subtract(in_coord, in_coord_floor) #[b,h,w,2] |
| 84 | + weight_floor = tf.subtract(tf.ones(2), weight_ceil_) #[b,h,w,2] |
| 85 | + weight_ceilX = tf.multiply(weight_ceil_[:,:,:,0:1], weight_floor[:,:,:,1:2]) #[b,h,w] |
| 86 | + weight_ceilY = tf.multiply(weight_floor[:,:,:,0:1], weight_ceil_[:,:,:,1:2]) #[b,h,w] |
| 87 | + weight_ceil_ = tf.reduce_prod(weight_ceil_, axis=3, keepdims=True) #[b,h,w,1] |
| 88 | + weight_floor = tf.reduce_prod(weight_floor, axis=3, keepdims=True) #[b,h,w,1] |
| 89 | + |
| 90 | + # Find nearest 4 points. |
| 91 | + # Make sure they are in the input image |
| 92 | + in_coord_floor = tf.cast(in_coord_floor, dtype=tf.int32) #[b,h,w,XY] |
| 93 | + in_coord_floor = tf.maximum(in_coord_floor, tf.zeros(2, dtype=tf.int32)) #[b,h,w,XY] |
| 94 | + in_coord_floor = tf.minimum(in_coord_floor, [input_w, input_h]) #[b,h,w,XY] |
| 95 | + |
| 96 | + in_coord_ceil_ = tf.add(in_coord_floor, tf.ones(2, dtype=tf.int32)) #[b,h,w,XY] |
| 97 | + # in_coord_ceil_ = tf.maximum(in_coord_ceil_, tf.zeros(2, dtype=tf.int32)) #[b,h,w,XY] |
| 98 | + in_coord_ceil_ = tf.minimum(in_coord_ceil_, [input_w, input_h]) #[b,h,w,XY] |
| 99 | + |
| 100 | + in_coord_ceilX = tf.concat([in_coord_floor[:,:,:,1:2], in_coord_ceil_[:,:,:,0:1]], axis=3) #[b,h,w,YX] YX for BHWC |
| 101 | + in_coord_ceilY = tf.concat([in_coord_ceil_[:,:,:,1:2], in_coord_floor[:,:,:,0:1]], axis=3) #[b,h,w,YX] |
| 102 | + in_coord_floor = tf.concat([in_coord_floor[:,:,:,1:2], in_coord_floor[:,:,:,0:1]], axis=3) #[b,h,w,YX] |
| 103 | + in_coord_ceil_ = tf.concat([in_coord_ceil_[:,:,:,1:2], in_coord_ceil_[:,:,:,0:1]], axis=3) #[b,h,w,YX] |
| 104 | + |
| 105 | + # calc final pixel value |
| 106 | + value_floor = tf.gather_nd(params=features, indices=in_coord_floor, batch_dims=1) #[b,h,w,32] |
| 107 | + value_ceilX = tf.gather_nd(params=features, indices=in_coord_ceilX, batch_dims=1) #[b,h,w,32] |
| 108 | + value_ceilY = tf.gather_nd(params=features, indices=in_coord_ceilY, batch_dims=1) #[b,h,w,32] |
| 109 | + value_ceil_ = tf.gather_nd(params=features, indices=in_coord_ceil_, batch_dims=1) #[b,h,w,32] |
| 110 | + value_floor_fraction = tf.multiply(value_floor, weight_floor) |
| 111 | + value_ceil__fraction = tf.multiply(value_ceil_, weight_ceil_) |
| 112 | + value_ceilX_fraction = tf.multiply(value_ceilX, weight_ceilX) |
| 113 | + value_ceilY_fraction = tf.multiply(value_ceilY, weight_ceilY) |
| 114 | + |
| 115 | + #[b,h,w,32] |
| 116 | + value = tf.add( |
| 117 | + tf.add(value_floor_fraction, value_ceil__fraction), |
| 118 | + tf.add(value_ceilX_fraction, value_ceilY_fraction) |
| 119 | + ) |
| 120 | + |
| 121 | + return value |
| 122 | + |
| 123 | +# Left indexとRight indexで指定されたLandmarkを結ぶ線が水平になり、Subset indicesで指定されたLandmrakをちょうど含むような範囲をcropするように、元の画像をAffine変換する行列 |
| 124 | +# の逆行列を求める。なぜ、逆行列かといういうと、後の計算で使うのが逆行列だから。 |
| 125 | +# Calc inverse of the matrix which represetns the affine transform which crops the area which covers all the landmarks specified by "subset indices" and rotates so that the landmarks specified by "Left index" and "Right index" are horizontally aligned. |
| 126 | +def Landmarks2TransformMatrix(operator, custom_options, tensors, interpreter, landmarks3d=None): |
| 127 | + if landmarks3d is None: |
| 128 | + landmarks3d = tensors[operator['inputs'][0]] #float32 [b,468,3] landmarks |
| 129 | + landmarks2d = landmarks3d[:,:,0:2] # [b,468,2] |
| 130 | + b = landmarks3d.shape[0] |
| 131 | + |
| 132 | + ###################################### |
| 133 | + # calc rotation |
| 134 | + ###################################### |
| 135 | + rot90_t = tf.constant([[ 0.0, 1.0], |
| 136 | + [ -1.0, 0.0]]) #[2,2], already transposed |
| 137 | + |
| 138 | + idx_rot_l = custom_options['left_rotation_idx'] |
| 139 | + idx_rot_r = custom_options['right_rotation_idx'] |
| 140 | + left_ = landmarks2d[:,idx_rot_l:idx_rot_l+1,:] #[b,1,2] |
| 141 | + right = landmarks2d[:,idx_rot_r:idx_rot_r+1,:] #[b,1,2] |
| 142 | + |
| 143 | + delta = tf.subtract(right, left_) #[b,1,2] |
| 144 | + length = tf.norm(delta, axis=2, keepdims=True) #[b,1,1] |
| 145 | + |
| 146 | + u = tf.divide(delta, length) #[b,1,2] = [[ dx, dy]] |
| 147 | + v = tf.matmul(u, rot90_t) #[b,1,2] = [[-dy, dx]] |
| 148 | + |
| 149 | + # mat_rot_inv = [[ dx, dy], |
| 150 | + # [-dy, dx]] |
| 151 | + # mat_rot = [[ dx, -dy], |
| 152 | + # [ dy, dx]] |
| 153 | + mat_rot_inv = tf.concat([u, v], axis=1) #[b,2,2] 切り取り後の画像座標から、切り取り前の画像座標への回転 |
| 154 | + mat_rot = tf.transpose(mat_rot_inv, perm=[0,2,1]) #[b,2,2] 切り取り前の画像座標から、切り取り後の画像座標への回転 |
| 155 | + |
| 156 | + ###################################### |
| 157 | + # calc crop size and center |
| 158 | + ###################################### |
| 159 | + subset_idxs = custom_options['subset_idxs'] #[80] |
| 160 | + landmarks2d_subset = tf.gather(landmarks2d, indices=subset_idxs, axis=1) #[b,80,2] |
| 161 | + landmarks2d_subset_rotated = tf.matmul(landmarks2d_subset, mat_rot) #[b,80,2] 切り取り前の画像上でのLandmark座標を、切り取り後の画像上での向きにあわせて回転 |
| 162 | + landmarks2d_subset_rotated_min = tf.reduce_min(landmarks2d_subset_rotated, axis=1, keepdims=True) #[b,1,2] |
| 163 | + landmarks2d_subset_rotated_max = tf.reduce_max(landmarks2d_subset_rotated, axis=1, keepdims=True) #[b,1,2] |
| 164 | + |
| 165 | + crop_size = tf.subtract(landmarks2d_subset_rotated_max, landmarks2d_subset_rotated_min) #[b,1,2], max - min |
| 166 | + center = tf.multiply(tf.add(landmarks2d_subset_rotated_min, landmarks2d_subset_rotated_max), tf.constant(0.5)) #[b,1,2], 1/2 * (max + min) |
| 167 | + center = tf.matmul(center, mat_rot_inv) #[b,1,2] 切り取り後の画像上での向きから、切り取り前の画像上での向きに回転 |
| 168 | + |
| 169 | + ###################################### |
| 170 | + # calc scale |
| 171 | + ###################################### |
| 172 | + # s = [[scale_x * crop_size.x / output_w], |
| 173 | + # [scale_y * crop_size.y / output_h]]] |
| 174 | + output_w = custom_options['output_width'] |
| 175 | + output_h = custom_options['output_height'] |
| 176 | + scale_x = custom_options['scale_x'] |
| 177 | + scale_y = custom_options['scale_y'] |
| 178 | + scaling_const_x = scale_x / output_w |
| 179 | + scaling_const_y = scale_y / output_h |
| 180 | + scaling_const = tf.constant([[scaling_const_x, scaling_const_y]]) #[1,2] |
| 181 | + scale = tf.multiply(scaling_const, crop_size) #[b,1,2] |
| 182 | + |
| 183 | + ###################################### |
| 184 | + # calc translation and final mat |
| 185 | + ###################################### |
| 186 | + # mat = [[ sx*dx, -sy*dy, 0, tx], |
| 187 | + # [ sx*dy, sy*dx, 0, ty]] |
| 188 | + # where |
| 189 | + # |
| 190 | + # t = center - shift |
| 191 | + # |
| 192 | + # shift = -0.5 * output_w * sx * u |
| 193 | + # + -0.5 * output_h * sy * v |
| 194 | + sxu = tf.multiply(u, scale[:,:,0:1]) #[b,1,2] |
| 195 | + syv = tf.multiply(v, scale[:,:,1:2]) #[b,1,2] |
| 196 | + zeros = tf.zeros([b, 1, 2]) |
| 197 | + |
| 198 | + shift_u = tf.multiply(sxu, output_w * 0.5) #[b,1,2] |
| 199 | + shift_v = tf.multiply(syv, output_h * 0.5) #[b,1,2] |
| 200 | + shift = tf.add(shift_u, shift_v) #[b,1,2] |
| 201 | + translation = tf.subtract(center, shift) #[b,1,2] |
| 202 | + |
| 203 | + mat = tf.concat([sxu, syv, zeros, translation], axis=1) #[b,4,2] |
| 204 | + mat = tf.transpose(mat, perm=[0,2,1]) #[b,2,4] |
| 205 | + |
| 206 | + # mat = [[ sx*dx, -sy*dy, 0, tx], |
| 207 | + # [ sx*dy, sy*dx, 0, ty], |
| 208 | + # [ 0, 0, 1, 0], |
| 209 | + # [ 0, 0, 0, 1]] |
| 210 | + unit_zw = tf.tile(tf.constant([[[0.0, 0.0, 1.0, 0.0], |
| 211 | + [0.0, 0.0, 0.0, 1.0]]]), [b,1,1]) #[b,2,4] |
| 212 | + mat = tf.concat([mat, unit_zw], axis=1) #[b,4,4] |
| 213 | + return mat |
0 commit comments