@@ -358,23 +358,78 @@ vector<uint32_t> CallingConvention::GetRegisterArgumentListRegs(uint32_t regList
358358 return vector<uint32_t >();
359359}
360360
361+
361362vector<Variable> CallingConvention::GetVariablesForParameters (
362363 const vector<FunctionParameter>& params, const std::optional<set<uint32_t >>& permittedRegs)
363364{
364- vector<uint32_t > intArgs = GetIntegerArgumentRegisters ();
365- vector<uint32_t > floatArgs = GetFloatArgumentRegisters ();
365+ vector<uint32_t > classes = GetRegisterArgumentClasses ();
366+
367+ // Build register lists for all classes
368+ // The order of iterators matter here, for register class and register list
369+ // we have assumed the INTEGER_SEMANTICS should be the first ones to be processed
370+ vector<vector<uint32_t >> allRegLists;
371+ vector<BNRegisterListKind> allListKinds;
372+ vector<vector<uint32_t >::iterator> allIterators;
373+ vector<vector<uint32_t >::iterator> allEndIterators;
374+ bool hasSharedIndex = false ;
375+
376+ for (uint32_t classId : classes)
377+ {
378+ vector<uint32_t > registerLists = GetRegisterArgumentClassLists (classId);
379+ if (registerLists.size () > 1 )
380+ hasSharedIndex = true ;
381+
382+ for (uint32_t regListId : registerLists)
383+ {
384+ vector<uint32_t > regs = GetRegisterArgumentListRegs (regListId);
385+ BNRegisterListKind kind = GetRegisterArgumentListKind (regListId);
386+
387+ allRegLists.push_back (regs);
388+ allListKinds.push_back (kind);
389+ allIterators.push_back (allRegLists.back ().begin ());
390+ allEndIterators.push_back (allRegLists.back ().end ());
391+ }
392+ }
393+
394+ // Fallback to legacy API if no register classes defined
395+ if (allRegLists.empty ())
396+ {
397+ vector<uint32_t > intArgs = GetIntegerArgumentRegisters ();
398+ vector<uint32_t > floatArgs = GetFloatArgumentRegisters ();
399+
400+ if (!intArgs.empty ())
401+ {
402+ allRegLists.push_back (intArgs);
403+ allListKinds.push_back (REGISTER_LIST_KIND_INTEGER_SEMANTICS);
404+ allIterators.push_back (allRegLists.back ().begin ());
405+ allEndIterators.push_back (allRegLists.back ().end ());
406+ }
407+
408+ if (!floatArgs.empty ())
409+ {
410+ allRegLists.push_back (floatArgs);
411+ allListKinds.push_back (REGISTER_LIST_KIND_FLOAT_SEMANTICS);
412+ allIterators.push_back (allRegLists.back ().begin ());
413+ allEndIterators.push_back (allRegLists.back ().end ());
414+ }
415+
416+ hasSharedIndex = AreArgumentRegistersSharedIndex ();
417+ }
366418
367419 vector<Variable> result;
368- auto intArgIter = intArgs.begin ();
369- auto floatArgIter = floatArgs.begin ();
370420 size_t addrSize = GetArchitecture ()->GetAddressSize ();
371421 int64_t stackOffset = 0 ;
372- bool sharedIndex = AreArgumentRegistersSharedIndex ();
422+
373423 if (GetArchitecture ()->GetLinkRegister () == BN_INVALID_REGISTER)
374424 stackOffset = addrSize;
375425 if (IsStackReservedForArgumentRegisters ())
376- stackOffset += intArgs.size () * addrSize;
377-
426+ {
427+ // Count total registers for stack reservation
428+ size_t totalRegs = 0 ;
429+ for (const auto & list : allRegLists)
430+ totalRegs = std::max (totalRegs, list.size ());
431+ stackOffset += totalRegs * addrSize;
432+ }
378433
379434 // TODO: Structure in register and multi-reg parameters
380435 for (auto & param : params)
@@ -385,22 +440,26 @@ vector<Variable> CallingConvention::GetVariablesForParameters(
385440 {
386441 // Parameter not storage in a normal location, use custom variable
387442 result.push_back (param.location );
443+
388444 if (param.location .type == RegisterVariableSourceType)
389445 {
390- // If non-default location matches the next register in the register parameter
391- // lists, advance the iterators. It may just be a type mismatch, and we still
392- // want to maintain the state for future parameters.
393- if (intArgIter != intArgs.end () && *intArgIter == param.location .storage )
394- {
395- intArgIter++;
396- if (sharedIndex && floatArgIter != floatArgs.end ())
397- floatArgIter++;
398- }
399- else if (floatArgIter != floatArgs.end () && *floatArgIter == param.location .storage )
446+ for (size_t i = 0 ; i < allIterators.size (); ++i)
400447 {
401- floatArgIter++;
402- if (sharedIndex && intArgIter != intArgs.end ())
403- intArgIter++;
448+ if (allIterators[i] != allEndIterators[i] && *allIterators[i] == param.location .storage )
449+ {
450+ allIterators[i]++;
451+
452+ // Advance all other iterators if shared index
453+ if (hasSharedIndex)
454+ {
455+ for (size_t j = i + 1 ; j < allIterators.size (); ++j)
456+ {
457+ if (allIterators[j] != allEndIterators[j])
458+ allIterators[j]++;
459+ }
460+ }
461+ break ;
462+ }
404463 }
405464 }
406465 else if (param.location .type == StackVariableSourceType)
@@ -416,62 +475,74 @@ vector<Variable> CallingConvention::GetVariablesForParameters(
416475 continue ;
417476 }
418477
419- if (param.type ->IsFloat ())
478+ // Try to find a suitable register for this parameter
479+ bool paramPlaced = false ;
480+
481+ for (size_t i = 0 ; i < allIterators.size (); ++i)
420482 {
421- if (permittedRegs.has_value () && floatArgIter != floatArgs.end ()
422- && permittedRegs.value ().count (*floatArgIter) == 0 )
423- {
424- // Disallowed register parameter, start spilling to stack. This is used in calling
425- // conventions that place all variable argument parameters on the stack.
426- floatArgIter = floatArgs.end ();
427- if (sharedIndex)
428- intArgIter = intArgs.end ();
429- }
430- else if (floatArgIter != floatArgs.end ())
483+ if (allIterators[i] == allEndIterators[i])
484+ continue ;
485+
486+ // Check if this register is permitted
487+ if (permittedRegs.has_value () && permittedRegs.value ().count (*allIterators[i]) == 0 )
431488 {
432- BNRegisterInfo regInfo = GetArchitecture ()->GetRegisterInfo (*floatArgIter);
433- if (width <= regInfo.size )
489+ // Disallowed register parameter, mark this list as exhausted
490+ allIterators[i] = allEndIterators[i];
491+ if (hasSharedIndex)
434492 {
435- result.emplace_back (RegisterVariableSourceType, 0 , *floatArgIter);
436- floatArgIter++;
437- if (sharedIndex && intArgIter != intArgs.end ())
438- intArgIter++;
439- continue ;
493+ // Mark all lists as exhausted when shared index
494+ for (size_t j = 0 ; j < allIterators.size (); ++j)
495+ allIterators[j] = allEndIterators[j];
440496 }
497+ continue ;
441498 }
442- }
443- else
444- {
445- if (permittedRegs. has_value () && intArgIter != intArgs. end ()
446- && permittedRegs. value (). count (*intArgIter) == 0 )
447- {
448- // Disallowed register parameter, start spilling to stack. This is used in calling
449- // conventions that place all variable argument parameters on the stack.
450- intArgIter = intArgs. end () ;
451- if (sharedIndex )
452- floatArgIter = floatArgs. end () ;
453- }
454- else if (intArgIter != intArgs. end () )
499+
500+ // Check if the type matches the register semantics
501+ bool typeMatches = false ;
502+ BNRegisterListKind kind = allListKinds[i];
503+
504+ if (kind == REGISTER_LIST_KIND_INTEGER_SEMANTICS && !param. type -> IsFloat ())
505+ typeMatches = true ;
506+ else if (kind == REGISTER_LIST_KIND_FLOAT_SEMANTICS && param. type -> IsFloat ())
507+ typeMatches = true ;
508+ else if (kind == REGISTER_LIST_KIND_POINTER_SEMANTICS && param. type -> IsPointer () )
509+ typeMatches = true ;
510+
511+ if (typeMatches )
455512 {
456- BNRegisterInfo regInfo = GetArchitecture ()->GetRegisterInfo (*intArgIter );
513+ BNRegisterInfo regInfo = GetArchitecture ()->GetRegisterInfo (*allIterators[i] );
457514 if (width <= regInfo.size )
458515 {
459- result.emplace_back (RegisterVariableSourceType, 0 , *intArgIter);
460- intArgIter++;
461- if (sharedIndex && floatArgIter != floatArgs.end ())
462- floatArgIter++;
463- continue ;
516+ result.emplace_back (RegisterVariableSourceType, 0 , *allIterators[i]);
517+ allIterators[i]++;
518+
519+ // Advance all other iterators if shared index
520+ if (hasSharedIndex)
521+ {
522+ for (size_t j = i + 1 ; j < allIterators.size (); ++j)
523+ {
524+ if (allIterators[j] != allEndIterators[j])
525+ allIterators[j]++;
526+ }
527+ }
528+
529+ paramPlaced = true ;
530+ break ;
464531 }
465532 }
466533 }
534+
535+ // If not placed in register, place on stack
536+ if (!paramPlaced)
537+ {
538+ result.emplace_back (StackVariableSourceType, 0 , stackOffset);
467539
468- result.emplace_back (StackVariableSourceType, 0 , stackOffset);
469-
470- if (width < addrSize)
471- width = addrSize;
472- else if ((width % addrSize) != 0 )
473- width += addrSize - (width % addrSize);
474- stackOffset += width;
540+ if (width < addrSize)
541+ width = addrSize;
542+ else if ((width % addrSize) != 0 )
543+ width += addrSize - (width % addrSize);
544+ stackOffset += width;
545+ }
475546 }
476547
477548 return result;
0 commit comments