diff --git a/decalib/datasets/datasets.py b/decalib/datasets/datasets.py index 112c77c1..333e7d19 100755 --- a/decalib/datasets/datasets.py +++ b/decalib/datasets/datasets.py @@ -139,5 +139,7 @@ def __getitem__(self, index): return {'image': torch.tensor(dst_image).float(), 'imagename': imagename, # 'tform': tform, - # 'original_image': torch.tensor(image.transpose(2,0,1)).float(), + 'center': center, + 'size': size, + 'original_image': torch.tensor(image.transpose(2,0,1)).float(), } \ No newline at end of file diff --git a/decalib/deca.py b/decalib/deca.py index ad4c536b..40e98f7e 100644 --- a/decalib/deca.py +++ b/decalib/deca.py @@ -224,12 +224,28 @@ def decode(self, codedict, rendering=True, iddict=None, vis_lmk=True, return_vis uv_texture_gt = uv_gt[:,:3,:,:]*self.uv_face_eye_mask + (torch.ones_like(uv_gt[:,:3,:,:])*(1-self.uv_face_eye_mask)*0.7) opdict['uv_texture_gt'] = uv_texture_gt + + offsets = [1, -1] + shifted_shapes = [] + for offset1 in offsets: + for offset2 in offsets: + shift_offset = torch.zeros_like(verts) + shift_offset[:,:,0] = shift_offset[:,:,0] + offset2 + shift_offset[:,:,1] = shift_offset[:,:,1] + offset1 + shifted_shapes.append(self.render.render_shape(verts, trans_verts+shift_offset)) + + shifted_shape_lt, shifted_shape_rt, shifted_shape_lb, shifted_shape_rb = shifted_shapes + visdict = { 'inputs': images, 'landmarks2d': util.tensor_vis_landmarks(images, landmarks2d), 'landmarks3d': util.tensor_vis_landmarks(images, landmarks3d), 'shape_images': shape_images, - 'shape_detail_images': shape_detail_images + 'shape_detail_images': shape_detail_images, + 'shifted_shape_lt': shifted_shape_lt, + 'shifted_shape_rt': shifted_shape_rt, + 'shifted_shape_lb': shifted_shape_lb, + 'shifted_shape_rb': shifted_shape_rb } if self.cfg.model.use_tex: visdict['rendered_images'] = ops['images'] diff --git a/decalib/utils/util.py b/decalib/utils/util.py index 94f7ff53..24735171 100755 --- a/decalib/utils/util.py +++ b/decalib/utils/util.py @@ -575,6 +575,23 @@ def dict_tensor2npy(tensor_dict): return npy_dict # ---------------------------------- visualization +def render_overlap(src_img, trg_img_data): + ''' + warp source image to match with target background image + ''' + center = trg_img_data['center'] + size = trg_img_data['size'] + src_pts = np.array([[0,0], [0,src_img.shape[2] - 1], + [src_img.shape[3] - 1, 0], [src_img.shape[3] - 1, src_img.shape[2] - 1]]) + dst_pts = np.array([[center[0]-size, center[1]-size], [center[0]-size, center[1]+size], + [center[0]+size, center[1]-size], [center[0]+size, center[1]+size]]) + + tform = estimate_transform('similarity', src_pts, dst_pts) + dst_image = warp(src_img[0].detach().cpu().numpy().transpose(1,2,0), tform.inverse, + output_shape=(trg_img_data['original_image'].shape[1], trg_img_data['original_image'].shape[2])) + dst_image = torch.tensor(dst_image.transpose(2,0,1)).float() + return dst_image + end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype = np.int32) - 1 def plot_kpts(image, kpts, color = 'r'): ''' Draw 68 key points diff --git a/demos/demo_reconstruct.py b/demos/demo_reconstruct.py index db677e95..4cbf5020 100755 --- a/demos/demo_reconstruct.py +++ b/demos/demo_reconstruct.py @@ -70,6 +70,25 @@ def main(args): continue image =util.tensor2image(visdict[vis_name][0]) cv2.imwrite(os.path.join(savefolder, name, name + '_' + vis_name +'.jpg'), util.tensor2image(visdict[vis_name][0])) + # overlap on input (face segmented) image + wfp = os.path.join(savefolder, name, name + '_' + 'overlap' +'.jpg') + alpha = 0.6 + deca.render.render_shape(opdict['verts'], opdict['trans_verts']) + res = cv2.addWeighted(util.tensor2image(visdict['inputs'][0]), 1 - alpha, util.tensor2image(visdict['shape_detail_images'][0]), alpha, 0) + cv2.imwrite(wfp, res) + # save original image + cv2.imwrite(os.path.join(savefolder, name, name + '_' + 'original_image' +'.jpg'), util.tensor2image(testdata[i]['original_image'])) + # get full shape image + shape_full = torch.cat((torch.cat((visdict['shifted_shape_lt'], visdict['shifted_shape_rt']), 3), + torch.cat((visdict['shifted_shape_lb'], visdict['shifted_shape_rb']), 3)),2) + cv2.imwrite(os.path.join(savefolder, name, name + '_' + 'shape_full' +'.jpg'), util.tensor2image(shape_full[0])) + # overlap full shape image to original image + wfp = os.path.join(savefolder, name, name + '_' + 'original_overlap' +'.jpg') + dst_image = util.render_overlap(shape_full, testdata[i]) + #cv2.imwrite(os.path.join(savefolder, name, name + '_' + 'shape_trans' +'.jpg'), util.tensor2image(dst_image)) #.transpose(2,0,1).float() + alpha = 0.6 + res = cv2.addWeighted(util.tensor2image(testdata[i]['original_image']), 1 - alpha, util.tensor2image(dst_image), alpha, 0) + cv2.imwrite(wfp, res) print(f'-- please check the results in {savefolder}') if __name__ == '__main__':