Skip to content

Commit b404e96

Browse files
authored
arm64 conditional jumps support (#161)
* arm64 conditional jumps support Signed-off-by: Prabhu Subramanian <prabhu@appthreat.com> * tweaks Signed-off-by: Prabhu Subramanian <prabhu@appthreat.com> --------- Signed-off-by: Prabhu Subramanian <prabhu@appthreat.com>
1 parent bf1d63a commit b404e96

File tree

3 files changed

+75
-51
lines changed

3 files changed

+75
-51
lines changed

blint/lib/disassembler.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,20 @@
88

99
ARITH_INST = ['add', 'sub', 'imul', 'mul', 'div', 'idiv', 'inc', 'dec', 'neg', 'not', 'and', 'or', 'xor', 'adc', 'sbb', 'xadd', 'cmpxchg']
1010
SHIFT_INST = ['shl', 'shr', 'sal', 'sar', 'rol', 'ror', 'rcl', 'rcr', 'psll', 'psrl', 'psra', 'vpsll', 'vpsrl', 'vpsra']
11-
CONDITIONAL_JMP_INST = ['je', 'jne', 'jz', 'jnz', 'jg', 'jge', 'jl', 'jle', 'ja', 'jae', 'jb', 'jbe',
11+
CONDITIONAL_JMP_INST_X86 = ['je', 'jne', 'jz', 'jnz', 'jg', 'jge', 'jl', 'jle', 'ja', 'jae', 'jb', 'jbe',
1212
'jp', 'jnp', 'jo', 'jno',
1313
'js', 'jns', 'loop', 'loopz', 'loopnz', 'jcxz', 'jecxz', 'jrcxz']
14+
X86_CALL_INST = {'call'}
15+
X86_UNCONDITIONAL_JMP_INST = {'jmp', 'jmpq', 'jmpl'}
16+
X86_RET_INST = {'ret', 'retn', 'retf', 'iret', 'iretd', 'iretq'}
17+
ARM64_B_COND_INST = [
18+
'beq', 'bne', 'bge', 'bgt', 'ble', 'blt', 'bhs', 'bcs',
19+
'blo', 'bcc', 'bvs', 'bvc', 'bmi', 'bpl', 'bhi', 'bls'
20+
]
21+
ARM64_CB_TB_INST = ['cbz', 'cbnz', 'tbz', 'tbnz']
22+
ARM64_CONDITIONAL_JMP_INST = ARM64_B_COND_INST + ARM64_CB_TB_INST
23+
CONDITIONAL_JMP_INST = CONDITIONAL_JMP_INST_X86 + ARM64_CONDITIONAL_JMP_INST
24+
1425
ARM64_GENERAL_REGS_64 = {f'x{i}' for i in range(31)}
1526
ARM64_GENERAL_REGS_32 = {f'w{i}' for i in range(31)}
1627
ARM64_SPECIAL_REGS = {'sp', 'xzr', 'wzr'}
@@ -21,6 +32,11 @@
2132
ARM64_ALL_REGS = (
2233
ARM64_GENERAL_REGS_64 | ARM64_GENERAL_REGS_32 | ARM64_SPECIAL_REGS | ARM64_VFP_NEON_REGS
2334
)
35+
ARM64_CALL_INST = {'bl', 'blr'}
36+
ARM64_UNCONDITIONAL_JMP_INST = {'b', 'br'}
37+
ARM64_RET_INST = {'ret', 'eret'}
38+
TERMINATING_INST = X86_RET_INST | ARM64_RET_INST
39+
UNCONDITIONAL_JMP_INST_ALL = X86_UNCONDITIONAL_JMP_INST | ARM64_UNCONDITIONAL_JMP_INST
2440
SORTED_ARM64_ALL_REGS = sorted(ARM64_ALL_REGS, key=len, reverse=True)
2541

2642
COMMON_REGS_64 = {'rax', 'rbx', 'rcx', 'rdx', 'rsi', 'rdi', 'rbp', 'rsp',
@@ -32,9 +48,6 @@
3248
COMMON_REGS_8l = {'al', 'bl', 'cl', 'dl', 'sil', 'dil', 'bpl', 'spl',
3349
'r8b', 'r9b', 'r10b', 'r11b', 'r12b', 'r13b', 'r14b', 'r15b'}
3450
COMMON_REGS_8h = {'ah', 'bh', 'ch', 'dh'}
35-
36-
TERMINATING_INST = {'ret', 'retn', 'retf', 'iret', 'iretd', 'iretq'}
37-
UNCONDITIONAL_JMP_INST = {'jmp', 'jmpq', 'jmpl'}
3851
READ_WRITE_BOTH_OPS_INST = {'xadd', 'cmpxchg', 'cmpxchg8b', 'cmpxchg16b'}
3952
BIT_MANIPULATION_INST = {'bt', 'bts', 'bsf', 'bsr', 'btr', 'btc', 'popcnt', 'lzcnt', 'tzcnt'}
4053
READ_WRITE_ONE_OP_INST = {'inc', 'dec', 'not', 'neg', 'rol', 'ror', 'rcl', 'rcr', 'shl', 'shr', 'sal', 'sar'}
@@ -104,10 +117,10 @@ def _find_function_end_index(instr_list):
104117
if mnemonic in TERMINATING_INST:
105118
if i + 1 < len(instr_list):
106119
next_mnemonic = instr_list[i+1].assembly.split(None, 1)[0].lower()
107-
if next_mnemonic in ['int3', 'nop']:
120+
if next_mnemonic in ('int3', 'nop'):
108121
return i
109122
return i
110-
if mnemonic in UNCONDITIONAL_JMP_INST:
123+
if mnemonic in UNCONDITIONAL_JMP_INST_ALL:
111124
return i
112125
return len(instr_list) - 1
113126

@@ -130,7 +143,7 @@ def _get_abi_volatile_regs(parsed_obj, arch_target):
130143
def _get_function_ranges(parsed_obj, metadata):
131144
"""Calculates the address ranges for each function based on the next function or section end."""
132145
section_func_map = {}
133-
for func_list_key in ["functions", "ctor_functions", "exception_functions", "unwind_functions", "exports"]:
146+
for func_list_key in ("functions", "ctor_functions", "exception_functions", "unwind_functions", "exports"):
134147
for func_entry in metadata.get(func_list_key, []):
135148
func_addr_str = func_entry.get("address", "")
136149
if func_addr_str:
@@ -249,7 +262,7 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
249262
regs_read = set()
250263
regs_written = set()
251264
if not instr_assembly:
252-
return list(regs_read), list(regs_written)
265+
return sorted(regs_read), sorted(regs_written)
253266
is_aarch64 = "aarch64" in arch_target.lower() or "arm64" in arch_target.lower()
254267
if not sorted_arch_regs:
255268
sorted_arch_regs = get_arch_reg_set(arch_target)
@@ -279,7 +292,7 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
279292
regs_read.add(counter_reg)
280293
regs_written.add(counter_reg)
281294
if is_aarch64:
282-
if mnemonic in ['add', 'adds', 'sub', 'subs', 'neg', 'negs', 'mul', 'umull', 'smull', 'smulh', 'umulh', 'div', 'udiv']:
295+
if mnemonic in ('add', 'adds', 'sub', 'subs', 'neg', 'negs', 'mul', 'umull', 'smull', 'smulh', 'umulh', 'div', 'udiv'):
283296
if num_operands >= 2:
284297
dst_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
285298
src1_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
@@ -288,24 +301,24 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
288301
if num_operands >= 3:
289302
src2_regs = extract_regs_from_operand(operands[2].lower(), sorted_arch_regs)
290303
regs_read.update(src2_regs)
291-
elif mnemonic in ['mov', 'movz', 'movk', 'movn', 'fmov', 'fmov immediate']:
304+
elif mnemonic in ('mov', 'movz', 'movk', 'movn', 'fmov', 'fmov immediate'):
292305
if num_operands >= 1:
293306
dst_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
294307
regs_written.update(dst_regs)
295308
if num_operands >= 2 and not operands[1].lower().startswith('#'):
296309
src_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
297310
regs_read.update(src_regs)
298-
elif mnemonic in ['csel', 'csinc', 'csinv', 'cset', 'csetm', 'cinc', 'cinv']:
311+
elif mnemonic in ('csel', 'csinc', 'csinv', 'cset', 'csetm', 'cinc', 'cinv'):
299312
if num_operands >= 3:
300313
dst_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
301314
src1_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
302315
src2_regs = extract_regs_from_operand(operands[2].lower(), sorted_arch_regs)
303316
regs_written.update(dst_regs)
304317
regs_read.update(src1_regs)
305318
regs_read.update(src2_regs)
306-
if mnemonic in ['cinc', 'cinv']:
319+
if mnemonic in ('cinc', 'cinv'):
307320
regs_read.update(dst_regs)
308-
elif mnemonic in ['cmp', 'cmn', 'tst']:
321+
elif mnemonic in ('cmp', 'cmn', 'tst'):
309322
if num_operands >= 2:
310323
src1_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
311324
src2_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
@@ -343,17 +356,17 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
343356
if num_operands >= 1:
344357
src_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
345358
regs_read.update(src_regs)
346-
elif mnemonic.startswith('b') and mnemonic not in ['bl', 'blr', 'br']:
359+
elif mnemonic.startswith('b') and mnemonic not in ('bl', 'blr', 'br'):
347360
pass
348-
elif mnemonic in ['bl', 'blr', 'br']:
361+
elif mnemonic in ('bl', 'blr', 'br'):
349362
if num_operands >= 1 and mnemonic != 'bl':
350363
target_op = operands[0].lower()
351364
if not target_op.startswith('#') and not target_op.isdigit():
352365
target_regs = extract_regs_from_operand(target_op, sorted_arch_regs)
353366
regs_read.update(target_regs)
354-
elif mnemonic in ['ret']:
367+
elif mnemonic in ('ret', 'eret'):
355368
pass
356-
elif mnemonic in ['and', 'orr', 'eor', 'bic', 'tst']:
369+
elif mnemonic in ('and', 'orr', 'eor', 'bic', 'tst'):
357370
if num_operands >= 2:
358371
dst_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
359372
src1_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
@@ -362,7 +375,7 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
362375
if num_operands >= 3:
363376
src2_regs = extract_regs_from_operand(operands[2].lower(), sorted_arch_regs)
364377
regs_read.update(src2_regs)
365-
elif mnemonic in ['lsl', 'lsr', 'asr', 'ror', 'uxtw', 'sxtw', 'sxtx', 'uxtb', 'uxth', 'sxtb', 'sxth']:
378+
elif mnemonic in ('lsl', 'lsr', 'asr', 'ror', 'uxtw', 'sxtw', 'sxtx', 'uxtb', 'uxth', 'sxtb', 'sxth'):
366379
if num_operands >= 2:
367380
dst_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
368381
src1_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
@@ -378,7 +391,7 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
378391
src_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
379392
regs_written.update(dst_regs)
380393
regs_read.update(src_regs)
381-
if mnemonic not in ['mov', 'movzx', 'movsx', 'movsxd', 'lea'] and not mnemonic.startswith('cmov'):
394+
if mnemonic not in ('mov', 'movzx', 'movsx', 'movsxd', 'lea') and not mnemonic.startswith('cmov'):
382395
regs_read.update(dst_regs)
383396
elif mnemonic in READ_WRITE_BOTH_OPS_INST:
384397
if num_operands >= 2:
@@ -395,18 +408,18 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
395408
src_regs = extract_regs_from_operand(operands[1].lower(), sorted_arch_regs)
396409
regs_written.update(dst_regs)
397410
regs_read.update(src_regs)
398-
if mnemonic not in ['bsf', 'bsr', 'lzcnt', 'tzcnt', 'popcnt']:
411+
if mnemonic not in ('bsf', 'bsr', 'lzcnt', 'tzcnt', 'popcnt'):
399412
regs_read.update(dst_regs)
400413
elif mnemonic in READ_WRITE_ONE_OP_INST:
401414
if num_operands >= 1:
402415
op_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
403416
regs_read.update(op_regs)
404417
regs_written.update(op_regs)
405-
elif mnemonic in ['cmp', 'test']:
418+
elif mnemonic in ('cmp', 'test'):
406419
if num_operands >= 2:
407420
regs_read.update(extract_regs_from_operand(operands[0].lower(), sorted_arch_regs))
408421
regs_read.update(extract_regs_from_operand(operands[1].lower(), sorted_arch_regs))
409-
elif mnemonic in ['push', 'pop']:
422+
elif mnemonic in ('push', 'pop'):
410423
is_64bit = "64" in arch_target
411424
stack_reg = 'rsp' if is_64bit else 'esp'
412425
regs_read.add(stack_reg)
@@ -448,14 +461,23 @@ def _extract_register_usage(instr_assembly, parsed_obj=None, arch_target="", sor
448461
regs_written.update(op1_regs)
449462
regs_read.update(op2_regs)
450463
regs_written.update(op2_regs)
451-
if mnemonic in ['mul', 'imul', 'div', 'idiv'] and num_operands == 1:
464+
if mnemonic in ('mul', 'imul', 'div', 'idiv') and num_operands == 1:
452465
op_regs = extract_regs_from_operand(operands[0].lower(), sorted_arch_regs)
453466
regs_read.update(op_regs)
454467

455-
return list(regs_read), list(regs_written)
468+
return sorted(regs_read), sorted(regs_written)
456469

457470
def _analyze_instructions(instr_list, func_addr, next_func_addr_in_sec, instr_addresses, parsed_obj=None, arch_target=""):
458471
"""Analyzes the list of instructions for metrics, loops, and indirect calls."""
472+
is_aarch64 = "aarch64" in arch_target.lower() or "arm64" in arch_target.lower()
473+
if is_aarch64:
474+
CALL_INST = ARM64_CALL_INST
475+
UNCONDITIONAL_JMP_INST = ARM64_UNCONDITIONAL_JMP_INST
476+
RET_INST = ARM64_RET_INST
477+
else:
478+
CALL_INST = X86_CALL_INST
479+
UNCONDITIONAL_JMP_INST = X86_UNCONDITIONAL_JMP_INST
480+
RET_INST = X86_RET_INST
459481
instruction_mnemonics = []
460482
instruction_metrics = {
461483
"call_count": 0,
@@ -502,7 +524,7 @@ def _analyze_instructions(instr_list, func_addr, next_func_addr_in_sec, instr_ad
502524
if sreg_operand and sreg_operand in _SREG_TO_CATEGORY_MAP:
503525
sreg_interactions.add(_SREG_TO_CATEGORY_MAP[sreg_operand])
504526
instruction_mnemonics.append(mnemonic)
505-
if mnemonic in ('call'):
527+
if mnemonic in CALL_INST:
506528
instruction_metrics["call_count"] += 1
507529
elif mnemonic in CONDITIONAL_JMP_INST:
508530
instruction_metrics["conditional_jump_count"] += 1
@@ -515,36 +537,27 @@ def _analyze_instructions(instr_list, func_addr, next_func_addr_in_sec, instr_ad
515537
has_loop = True
516538
except ValueError:
517539
continue
540+
elif mnemonic in UNCONDITIONAL_JMP_INST:
541+
instruction_metrics["jump_count"] += 1
518542
elif mnemonic == 'xor':
519543
instruction_metrics["xor_count"] += 1
520544
elif mnemonic in SHIFT_INST:
521545
instruction_metrics["shift_count"] += 1
522546
elif mnemonic in ARITH_INST:
523547
instruction_metrics["arith_count"] += 1
524-
elif mnemonic == 'ret':
548+
elif mnemonic in RET_INST:
525549
instruction_metrics["ret_count"] += 1
526-
elif mnemonic in ['jmp', 'jmpq', 'jmpl']:
527-
instruction_metrics["jump_count"] += 1
528-
if instr_assembly.startswith(('call ', 'jmp ')):
529-
parts = instr_assembly.split(None, 1)
530-
if len(parts) > 1:
531-
operand = parts[1].lower().strip()
532-
if operand.startswith('[') and operand.endswith(']'):
533-
has_indirect_call = True
534-
elif any(operand.startswith(reg) for reg in sorted_arch_regs):
535-
if operand.isalnum() or '_' in operand:
536-
has_indirect_call = True
537550
# Check for ARM64 indirect calls and jumps
538-
elif instr_assembly.startswith(('bl ', 'blr ', 'br ')):
539-
parts = instr_assembly.split(None, 1)
551+
if mnemonic in (CALL_INST | UNCONDITIONAL_JMP_INST):
552+
is_indirect = False
540553
if len(parts) > 1:
541554
operand = parts[1].lower().strip()
542-
if any(operand.startswith(reg) for reg in sorted_arch_regs):
543-
has_indirect_call = True
544-
elif '[' in operand and ']' in operand:
555+
if any(operand.startswith(reg) for reg in sorted_arch_regs) and (operand.isalnum() or '_' in operand):
556+
is_indirect = True
557+
elif '[' in operand and ']' in operand:
558+
is_indirect = True
559+
if is_indirect:
545560
has_indirect_call = True
546-
elif operand.startswith('#') or operand.startswith(('+', '-')) or operand.startswith('0x'):
547-
instruction_metrics["call_count"] += 1
548561
regs_read, regs_written = _extract_register_usage(instr_assembly, parsed_obj, arch_target, sorted_arch_regs)
549562
all_instr_regs = set(regs_read) | set(regs_written)
550563
is_simd_fpu = False
@@ -578,12 +591,12 @@ def _analyze_instructions(instr_list, func_addr, next_func_addr_in_sec, instr_ad
578591
})
579592
instruction_metrics["unique_regs_read_count"] = len(all_regs_read)
580593
instruction_metrics["unique_regs_written_count"] = len(all_regs_written)
581-
return instruction_metrics, instruction_mnemonics, has_indirect_call, has_loop, list(all_regs_read), list(all_regs_written), instructions_with_registers, list(used_simd_reg_types), list(proprietary_instr_found), list(sreg_interactions)
594+
return instruction_metrics, instruction_mnemonics, has_indirect_call, has_loop, sorted(all_regs_read), sorted(all_regs_written), instructions_with_registers, sorted(used_simd_reg_types), sorted(proprietary_instr_found), sorted(sreg_interactions)
582595

583596
def _build_addr_to_name_map(metadata):
584597
"""Builds a lookup map from address (int) to name from metadata functions."""
585598
addr_to_name_map = {}
586-
for func_list_key in ["functions", "ctor_functions", "exception_functions", "unwind_functions", "exports", "imports", "symtab_symbols", "dynamic_symbols"]:
599+
for func_list_key in ("functions", "ctor_functions", "exception_functions", "unwind_functions", "exports", "imports", "symtab_symbols", "dynamic_symbols"):
587600
for func_entry in metadata.get(func_list_key, []):
588601
addr_str = func_entry.get("address", "")
589602
name = func_entry.get("name", "")
@@ -634,7 +647,18 @@ def _resolve_direct_calls(instr_list, addr_to_name_map, arch_target=""):
634647
def _classify_function(instruction_metrics, instruction_count, plain_assembly_text, has_system_call, has_indirect_call):
635648
"""Classifies the function based on metrics and other flags."""
636649
function_type = ""
637-
if instruction_metrics["jump_count"] > 0 and instruction_count <= 5 and all(mnem in ['jmp', 'push', 'sub'] for mnem in [i.split(None, 1)[0].lower() for i in plain_assembly_text.split('\n') if i.strip()]):
650+
if (
651+
instruction_metrics["jump_count"] > 0
652+
and instruction_count <= 5
653+
and all(
654+
mnem in ("jmp", "push", "sub")
655+
for mnem in [
656+
i.split(None, 1)[0].lower()
657+
for i in plain_assembly_text.split("\n")
658+
if i.strip()
659+
]
660+
)
661+
):
638662
function_type = "PLT_Thunk"
639663
elif instruction_count == 1 and instruction_metrics["ret_count"] == 1:
640664
function_type = "Simple_Return"

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/test_disassembler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_analyze_instructions_basic(mock_instructions):
108108
regs_read, regs_written, instrs_with_regs, _, _, _) = _analyze_instructions(
109109
mock_instructions, func_addr, next_func_addr_in_sec, instr_addresses, {}, "x86_64"
110110
)
111-
assert metrics["call_count"] == 3
111+
assert metrics["call_count"] == 2
112112
assert metrics["arith_count"] == 1
113113
assert metrics["ret_count"] == 1
114114
assert metrics["conditional_jump_count"] == 1

0 commit comments

Comments
 (0)