@@ -259,6 +259,13 @@ def define_cstring(bv, address):
259
259
bv .define_data_var (address , Type .array (Type .char (), nul + 1 ))
260
260
return data [:nul ]
261
261
262
+ # Add the LLIL tail-call ops to LLIL_CALLS once "override call type" works on them
263
+ LLIL_CALLS = {LowLevelILOperation .LLIL_CALL ,
264
+ LowLevelILOperation .LLIL_CALL_STACK_ADJUST }
265
+
266
+ MLIL_CALLS = {MediumLevelILOperation .MLIL_CALL ,
267
+ MediumLevelILOperation .MLIL_TAILCALL }
268
+
262
269
class PrintfTyperBase :
263
270
def __init__ (self , view ):
264
271
self .view = view
@@ -269,7 +276,9 @@ def __init__(self, view):
269
276
270
277
def handle_function (self , symbol , variadic_type , fmt_pos , arg_pos , thread = None ):
271
278
bv = self .view
272
- calls = list (bv .get_callers (symbol .address ))
279
+ # Using code refs instead of callers here to handle calls through named
280
+ # function pointers
281
+ calls = list (bv .get_code_refs (symbol .address ))
273
282
ncalls = len (calls )
274
283
it = 1
275
284
@@ -280,22 +289,48 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
280
289
break
281
290
thread .progress = "processing: {} ({}/{})" .format (symbol .name , it , ncalls )
282
291
it += 1
292
+
293
+ mlil = ref .mlil
294
+ mlil_index = None
295
+ if mlil is None :
296
+ # If there is no MLIL at this address, fall back to the LLIL and scan
297
+ # forward, up to a fixed limit, for a call whose destination is this symbol
298
+ llil_instr = ref .llil
299
+ llil = ref .function .llil
300
+ if llil_instr is None :
301
+ log_info (f"no llil for { ref .address :#x} " )
302
+ continue
303
+ for idx in range (llil_instr .instr_index , len (llil )):
304
+ if llil [idx ].operation in LLIL_CALLS and llil [idx ].dest .value .value == symbol .address :
305
+ call_address = llil [idx ].address
306
+ mlil_index = ref .function .mlil .get_instruction_start (call_address )
307
+ break
308
+ if idx > llil_instr .instr_index + 128 :
309
+ # Don't scan forward forever...
310
+ break
311
+ else :
312
+ call_address = ref .address
313
+ mlil_index = mlil .instr_index
314
+
283
315
func = ref .function
284
316
mlil = func .medium_level_il
285
- mlil_index = mlil .get_instruction_start (ref .address )
286
317
if mlil_index is None :
318
+ log_info (f"no mlil index for { ref .address :#x} " )
287
319
continue
288
320
289
321
il = mlil [mlil_index ]
290
- call_expr = find_expr (il , {MediumLevelILOperation .MLIL_CALL ,
291
- MediumLevelILOperation .MLIL_TAILCALL })
322
+ call_expr = find_expr (il , MLIL_CALLS )
292
323
if call_expr is None :
293
- log_warn ("Cannot find call expr for ref {:#x}" .format (ref .address ))
324
+ log_debug ("Cannot find call expr for ref {:#x}" .format (call_address ))
325
+ continue
326
+
327
+ if call_expr .dest .constant != symbol .address :
328
+ log_warn ("{:#x}: Call expression dest {!r} does not match {!r}" .format (call_address , call_expr .dest , symbol ))
294
329
continue
295
330
296
331
call_args = call_expr .operands [2 ]
297
332
if len (call_args ) <= fmt_pos :
298
- log_warn ("Call at {:#x} does not respect function type" .format (ref . address ))
333
+ log_warn ("Call at {:#x} does not respect function type" .format (call_address ))
299
334
continue
300
335
301
336
fmt_arg = call_args [fmt_pos ]
@@ -304,13 +339,13 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
304
339
fmt_ptr = fmt_arg_value .value
305
340
fmt = define_cstring (bv , fmt_ptr )
306
341
if fmt is None :
307
- log_warn ("{:#x}: Bad format string at {:#x}" .format (ref . address , fmt_ptr ))
342
+ log_warn ("{:#x}: Bad format string at {:#x}" .format (call_address , fmt_ptr ))
308
343
continue
309
344
310
345
fmt_type_strs = format_types (self .local_extns , fmt )
311
346
# print(fmt, fmt_type_strs)
312
347
if fmt_type_strs is None :
313
- log_warn ("{:#x}: Failed to parse format string {!r}" .format (ref . address , fmt ))
348
+ log_warn ("{:#x}: Failed to parse format string {!r}" .format (call_address , fmt ))
314
349
continue
315
350
316
351
@@ -319,29 +354,29 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
319
354
for fmt_ptr in fmt_arg_value .values :
320
355
fmt = define_cstring (bv , fmt_ptr )
321
356
if fmt is None :
322
- log_warn ("{:#x}: Bad format string at {:#x}" .format (ref . address , fmt_ptr ))
357
+ log_warn ("{:#x}: Bad format string at {:#x}" .format (call_address , fmt_ptr ))
323
358
break
324
359
fmt_type_strs = format_types (self .local_extns , fmt )
325
360
if fmt_type_strs is None :
326
- log_warn ("{:#x}: Failed to parse format string {!r}" .format (ref . address , fmt ))
361
+ log_warn ("{:#x}: Failed to parse format string {!r}" .format (call_address , fmt ))
327
362
fmt = None
328
363
break
329
364
fmts .update ((tuple (fmt_type_strs ),))
330
365
331
366
if fmt is None :
332
367
continue
333
368
elif not fmts :
334
- log_warn ("{:#x}: Unable to resolve format string from {!r}" .format (ref . address , fmt_arg ))
369
+ log_warn ("{:#x}: Unable to resolve format string from {!r}" .format (call_address , fmt_arg ))
335
370
continue
336
371
elif len (fmts ) > 1 :
337
- log_warn ("{:#x}: Differing format types passed to one call: {!r}" .format (ref . address , fmts ))
372
+ log_warn ("{:#x}: Differing format types passed to one call: {!r}" .format (call_address , fmts ))
338
373
continue
339
374
340
375
# print(fmt, fmt_type_strs)
341
376
fmt_type_strs = fmts .pop ()
342
377
343
378
else :
344
- log_warn ("{:#x}: Ooh, format bug? {!r} ({!r}) is not const" .format (ref . address , fmt_arg , fmt_arg_value ))
379
+ log_warn ("{:#x}: Ooh, format bug? {!r} ({!r}) is not const" .format (call_address , fmt_arg , fmt_arg_value ))
345
380
continue
346
381
347
382
try :
@@ -357,8 +392,8 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
357
392
variable_arguments = False ,
358
393
calling_convention = variadic_type .calling_convention ,
359
394
stack_adjust = variadic_type .stack_adjustment or None )
360
- log_debug ("{:#x}: format string {!r}: explicit type {!r}" .format (ref . address , fmt , explicit_type ))
361
- func .set_call_type_adjustment (ref . address , explicit_type )
395
+ log_debug ("{:#x}: format string {!r}: explicit type {!r}" .format (call_address , fmt , explicit_type ))
396
+ func .set_call_type_adjustment (call_address , explicit_type )
362
397
363
398
class PrintfTyperSingle (BackgroundTaskThread ):
364
399
def __init__ (self , view , symbol , variadic_type , fmt_pos , arg_pos ):
@@ -372,6 +407,7 @@ def __init__(self, view, symbol, variadic_type, fmt_pos, arg_pos):
372
407
373
408
def run (self ):
374
409
self .progress = "processing: {}" .format (self .symbol .name )
410
+ log_debug (self .symbol .name )
375
411
PrintfTyperBase (self .view ).handle_function (self .symbol , self .variadic_type , self .fmt_pos , self .arg_pos , self )
376
412
377
413
def update_analysis_and_handle (self ):
@@ -407,9 +443,23 @@ def run(self):
407
443
for decl , positions in printf_functions .items ():
408
444
decl_type , name = bv .parse_type_string (decl )
409
445
for symbol in bv .get_symbols_by_name (str (name )):
410
- func = bv .get_function_at (symbol .address )
411
- if func :
412
- func .set_auto_type (decl_type )
446
+ # Handle PLTs and local functions
447
+ if symbol .type == SymbolType .FunctionSymbol :
448
+ func = bv .get_function_at (symbol .address )
449
+ if func is None :
450
+ continue
451
+ func .set_user_type (decl_type )
452
+ symbols .append ((symbol , decl_type , positions ))
453
+ # Handle GOT entries
454
+ elif symbol .type == SymbolType .ImportAddressSymbol :
455
+ var = bv .get_data_var_at (symbol .address )
456
+ if var is None :
457
+ continue
458
+ if var .type .type_class != TypeClass .PointerTypeClass :
459
+ continue
460
+ var .type = Type .pointer (bv .arch ,
461
+ decl_type ,
462
+ const = var .type .const )
413
463
symbols .append ((symbol , decl_type , positions ))
414
464
415
465
self .progress = ""
0 commit comments