Skip to content

Commit 0a931f4

Browse files
committed
Add support for calls via pointer, such as GOT calls
1 parent d82ddf7 commit 0a931f4

File tree

1 file changed

+68
-18
lines changed

1 file changed

+68
-18
lines changed

__init__.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ def define_cstring(bv, address):
259259
bv.define_data_var(address, Type.array(Type.char(), nul + 1))
260260
return data[:nul]
261261

262+
# TODO: add the TAILCALL operations here once "override call type" works on them
263+
LLIL_CALLS = {LowLevelILOperation.LLIL_CALL,
264+
LowLevelILOperation.LLIL_CALL_STACK_ADJUST}
265+
266+
MLIL_CALLS = {MediumLevelILOperation.MLIL_CALL,
267+
MediumLevelILOperation.MLIL_TAILCALL}
268+
262269
class PrintfTyperBase:
263270
def __init__(self, view):
264271
self.view = view
@@ -269,7 +276,9 @@ def __init__(self, view):
269276

270277
def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
271278
bv = self.view
272-
calls = list(bv.get_callers(symbol.address))
279+
# Using code refs instead of callers here to handle calls through named
280+
# function pointers
281+
calls = list(bv.get_code_refs(symbol.address))
273282
ncalls = len(calls)
274283
it = 1
275284

@@ -280,22 +289,48 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
280289
break
281290
thread.progress = "processing: {} ({}/{})".format(symbol.name, it, ncalls)
282291
it += 1
292+
293+
mlil = ref.mlil
294+
mlil_index = None
295+
if mlil is None:
296+
# If there is no MLIL at this address, we'll look at the LLIL
297+
# and scan forward until we see a call that seems to match up
298+
llil_instr = ref.llil
299+
llil = ref.function.llil
300+
if llil_instr is None:
301+
log_info(f"no llil for {ref.address:#x}")
302+
continue
303+
for idx in range(llil_instr.instr_index, len(llil)):
304+
if llil[idx].operation in LLIL_CALLS and llil[idx].dest.value.value == symbol.address:
305+
call_address = llil[idx].address
306+
mlil_index = ref.function.mlil.get_instruction_start(call_address)
307+
break
308+
if idx > llil_instr.instr_index + 128:
309+
# Don't scan forward forever...
310+
break
311+
else:
312+
call_address = ref.address
313+
mlil_index = mlil.instr_index
314+
283315
func = ref.function
284316
mlil = func.medium_level_il
285-
mlil_index = mlil.get_instruction_start(ref.address)
286317
if mlil_index is None:
318+
log_info(f"no mlil index for {ref.address:#x}")
287319
continue
288320

289321
il = mlil[mlil_index]
290-
call_expr = find_expr(il, {MediumLevelILOperation.MLIL_CALL,
291-
MediumLevelILOperation.MLIL_TAILCALL})
322+
call_expr = find_expr(il, MLIL_CALLS)
292323
if call_expr is None:
293-
log_warn("Cannot find call expr for ref {:#x}".format(ref.address))
324+
log_debug("Cannot find call expr for ref {:#x}".format(call_address))
325+
continue
326+
327+
if call_expr.dest.constant != symbol.address:
328+
log_warn("{:#x}: Call expression dest {!r} does not match {!r}".format(call_address, call_expr.dest, symbol))
294329
continue
295330

296331
call_args = call_expr.operands[2]
297332
if len(call_args) <= fmt_pos:
298-
log_warn("Call at {:#x} does not respect function type".format(ref.address))
333+
log_warn("Call at {:#x} does not respect function type".format(call_address))
299334
continue
300335

301336
fmt_arg = call_args[fmt_pos]
@@ -304,13 +339,13 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
304339
fmt_ptr = fmt_arg_value.value
305340
fmt = define_cstring(bv, fmt_ptr)
306341
if fmt is None:
307-
log_warn("{:#x}: Bad format string at {:#x}".format(ref.address, fmt_ptr))
342+
log_warn("{:#x}: Bad format string at {:#x}".format(call_address, fmt_ptr))
308343
continue
309344

310345
fmt_type_strs = format_types(self.local_extns, fmt)
311346
# print(fmt, fmt_type_strs)
312347
if fmt_type_strs is None:
313-
log_warn("{:#x}: Failed to parse format string {!r}".format(ref.address, fmt))
348+
log_warn("{:#x}: Failed to parse format string {!r}".format(call_address, fmt))
314349
continue
315350

316351

@@ -319,29 +354,29 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
319354
for fmt_ptr in fmt_arg_value.values:
320355
fmt = define_cstring(bv, fmt_ptr)
321356
if fmt is None:
322-
log_warn("{:#x}: Bad format string at {:#x}".format(ref.address, fmt_ptr))
357+
log_warn("{:#x}: Bad format string at {:#x}".format(call_address, fmt_ptr))
323358
break
324359
fmt_type_strs = format_types(self.local_extns, fmt)
325360
if fmt_type_strs is None:
326-
log_warn("{:#x}: Failed to parse format string {!r}".format(ref.address, fmt))
361+
log_warn("{:#x}: Failed to parse format string {!r}".format(call_address, fmt))
327362
fmt = None
328363
break
329364
fmts.update((tuple(fmt_type_strs),))
330365

331366
if fmt is None:
332367
continue
333368
elif not fmts:
334-
log_warn("{:#x}: Unable to resolve format string from {!r}".format(ref.address, fmt_arg))
369+
log_warn("{:#x}: Unable to resolve format string from {!r}".format(call_address, fmt_arg))
335370
continue
336371
elif len(fmts) > 1:
337-
log_warn("{:#x}: Differing format types passed to one call: {!r}".format(ref.address, fmts))
372+
log_warn("{:#x}: Differing format types passed to one call: {!r}".format(call_address, fmts))
338373
continue
339374

340375
# print(fmt, fmt_type_strs)
341376
fmt_type_strs = fmts.pop()
342377

343378
else:
344-
log_warn("{:#x}: Ooh, format bug? {!r} ({!r}) is not const".format(ref.address, fmt_arg, fmt_arg_value))
379+
log_warn("{:#x}: Ooh, format bug? {!r} ({!r}) is not const".format(call_address, fmt_arg, fmt_arg_value))
345380
continue
346381

347382
try:
@@ -357,8 +392,8 @@ def handle_function(self, symbol, variadic_type, fmt_pos, arg_pos, thread=None):
357392
variable_arguments=False,
358393
calling_convention=variadic_type.calling_convention,
359394
stack_adjust=variadic_type.stack_adjustment or None)
360-
log_debug("{:#x}: format string {!r}: explicit type {!r}".format(ref.address, fmt, explicit_type))
361-
func.set_call_type_adjustment(ref.address, explicit_type)
395+
log_debug("{:#x}: format string {!r}: explicit type {!r}".format(call_address, fmt, explicit_type))
396+
func.set_call_type_adjustment(call_address, explicit_type)
362397

363398
class PrintfTyperSingle(BackgroundTaskThread):
364399
def __init__(self, view, symbol, variadic_type, fmt_pos, arg_pos):
@@ -372,6 +407,7 @@ def __init__(self, view, symbol, variadic_type, fmt_pos, arg_pos):
372407

373408
def run(self):
374409
self.progress = "processing: {}".format(self.symbol.name)
410+
log_debug(self.symbol.name)
375411
PrintfTyperBase(self.view).handle_function(self.symbol, self.variadic_type, self.fmt_pos, self.arg_pos, self)
376412

377413
def update_analysis_and_handle(self):
@@ -407,9 +443,23 @@ def run(self):
407443
for decl, positions in printf_functions.items():
408444
decl_type, name = bv.parse_type_string(decl)
409445
for symbol in bv.get_symbols_by_name(str(name)):
410-
func = bv.get_function_at(symbol.address)
411-
if func:
412-
func.set_auto_type(decl_type)
446+
# Handle PLTs and local functions
447+
if symbol.type == SymbolType.FunctionSymbol:
448+
func = bv.get_function_at(symbol.address)
449+
if func is None:
450+
continue
451+
func.set_user_type(decl_type)
452+
symbols.append((symbol, decl_type, positions))
453+
# Handle GOT entries
454+
elif symbol.type == SymbolType.ImportAddressSymbol:
455+
var = bv.get_data_var_at(symbol.address)
456+
if var is None:
457+
continue
458+
if var.type.type_class != TypeClass.PointerTypeClass:
459+
continue
460+
var.type = Type.pointer(bv.arch,
461+
decl_type,
462+
const=var.type.const)
413463
symbols.append((symbol, decl_type, positions))
414464

415465
self.progress = ""

0 commit comments

Comments
 (0)