@@ -84,6 +84,7 @@ def __call__(self, x, file_names=None):
84
84
if isinstance (x , dict ):
85
85
x = x ['logits' ]
86
86
assert isinstance (x , paddle .Tensor )
87
+
87
88
if file_names is not None :
88
89
assert x .shape [0 ] == len (file_names )
89
90
x = F .sigmoid (x ).numpy ()
@@ -98,6 +99,7 @@ def __call__(self, x, file_names=None):
98
99
'Skirt&Dress'
99
100
]
100
101
batch_res = []
102
+
101
103
for idx , res in enumerate (x ):
102
104
res = res .tolist ()
103
105
label_res = []
@@ -171,6 +173,66 @@ def __call__(self, x, file_names=None):
171
173
return batch_res
172
174
173
175
176
+ class FaceAttribute (object ):
177
+ def __init__ (self , threshold = 0.65 , convert_cn = False ):
178
+ self .threshold = threshold
179
+ self .convert_cn = convert_cn
180
+
181
+ def __call__ (self , x , file_names = None ):
182
+ if isinstance (x , dict ):
183
+ x = x ['logits' ]
184
+ assert isinstance (x , paddle .Tensor )
185
+
186
+ if file_names is not None :
187
+ assert x .shape [0 ] == len (file_names )
188
+ x = F .sigmoid (x ).numpy ()
189
+
190
+ attribute_list = [
191
+ ["CheekWhiskers" , "刚长出的双颊胡须" ], ["ArchedEyebrows" , "柳叶眉" ],
192
+ ["Attractive" , "吸引人的" ], ["BagsUnderEyes" , "眼袋" ], ["Bald" , "秃头" ],
193
+ ["Bangs" , "刘海" ], ["BigLips" , "大嘴唇" ], ["BigNose" , "大鼻子" ],
194
+ ["BlackHair" , "黑发" ], ["BlondHair" , "金发" ], ["Blurry" , "模糊的" ],
195
+ ["BrownHair" , "棕发" ], ["BushyEyebrows" , "浓眉" ], ["Chubby" , "圆胖的" ],
196
+ ["DoubleChin" , "双下巴" ], ["Eyeglasses" , "带眼镜" ], ["Goatee" , "山羊胡子" ],
197
+ ["GrayHair" , "灰发或白发" ], ["HeavyMakeup" , "浓妆" ],
198
+ ["HighCheekbones" , "高颧骨" ], ["Male" , "男性" ],
199
+ ["MouthSlightlyOpen" , "微微张开嘴巴" ], ["Mustache" , "胡子" ],
200
+ ["NarrowEyes" , "细长的眼睛" ], ["NoBeard" , "无胡子" ],
201
+ ["OvalFace" , "椭圆形的脸" ], ["PaleSkin" , "苍白的皮肤" ],
202
+ ["PointyNose" , "尖鼻子" ], ["RecedingHairline" , "发际线后移" ],
203
+ ["RosyCheeks" , "红润的双颊" ], ["Sideburns" , "连鬓胡子" ], ["Smiling" , "微笑" ],
204
+ ["StraightHair" , "直发" ], ["WavyHair" , "卷发" ],
205
+ ["WearingEarrings" , "戴着耳环" ], ["WearingHat" , "戴着帽子" ],
206
+ ["WearingLipstick" , "涂了唇膏" ], ["WearingNecklace" , "戴着项链" ],
207
+ ["WearingNecktie" , "戴着领带" ], ["Young" , "年轻人" ]
208
+ ]
209
+ gender_list = [["Male" , "男性" ], ["Female" , "女性" ]]
210
+ age_list = [["Young" , "年轻人" ], ["Old" , "老年人" ]]
211
+ batch_res = []
212
+ if self .convert_cn :
213
+ index = 1
214
+ else :
215
+ index = 0
216
+ for idx , res in enumerate (x ):
217
+ res = res .tolist ()
218
+ label_res = []
219
+ threshold_list = [self .threshold ] * len (res )
220
+ pred_res = (np .array (res ) > np .array (threshold_list )
221
+ ).astype (np .int8 ).tolist ()
222
+ for i , value in enumerate (pred_res ):
223
+ if i == 20 :
224
+ label_res .append (gender_list [0 ][index ]
225
+ if value == 1 else gender_list [1 ][index ])
226
+ elif i == 39 :
227
+ label_res .append (age_list [0 ][index ]
228
+ if value == 1 else age_list [1 ][index ])
229
+ else :
230
+ if value == 1 :
231
+ label_res .append (attribute_list [i ][index ])
232
+ batch_res .append ({"attributes" : label_res , "output" : pred_res })
233
+ return batch_res
234
+
235
+
174
236
class TableAttribute (object ):
175
237
def __init__ (
176
238
self ,
0 commit comments