diff --git a/README.md b/README.md index 6e5994a305..77857ca6a0 100644 --- a/README.md +++ b/README.md @@ -155,6 +155,8 @@ Join our discord community via [this invite link](https://discord.gg/bxgXW8jJGh) | [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no | | [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. This key must be in the current account. | `string` | `null` | no | | [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no | +| [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no | +| [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no | | [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. |
<pre>list(object({
    type        = string
    identifiers = list(string)
  }))</pre> | `[]` | no | | [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no | | [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no |
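The two new batching inputs configure the SQS event source mapping of the scale-up Lambda. A minimal sketch of how they might be set on the module (the module source and all other required inputs are elided; the values shown are illustrative only):

```hcl
module "runners" {
  # source, VPC, GitHub app configuration and all other required inputs
  # are omitted for brevity.
  # ...

  # Hand up to 20 queued workflow-job messages to a single scale-up invocation.
  lambda_event_source_mapping_batch_size = 20

  # AWS requires a batching window greater than 0 whenever batch_size exceeds 10.
  lambda_event_source_mapping_maximum_batching_window_in_seconds = 5
}
```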
diff --git a/lambdas/functions/control-plane/src/aws/runners.test.ts index c4ec328c9b..8d508fbf5e 100644 --- a/lambdas/functions/control-plane/src/aws/runners.test.ts +++ b/lambdas/functions/control-plane/src/aws/runners.test.ts @@ -419,9 +419,12 @@ describe('create runner with errors', () => { }); it('test ScaleError with multiple error.', async () => { - createFleetMockWithErrors(['UnfulfillableCapacity', 'SomeError']); + createFleetMockWithErrors(['UnfulfillableCapacity', 'MaxSpotInstanceCountExceeded', 'NotMappedError']); - await expect(createRunner(createRunnerConfig(defaultRunnerConfig))).rejects.toBeInstanceOf(ScaleError); + await expect(createRunner(createRunnerConfig(defaultRunnerConfig))).rejects.toMatchObject({ + name: 'ScaleError', + failedInstanceCount: 2, + }); expect(mockEC2Client).toHaveReceivedCommandWith( CreateFleetCommand, expectedCreateFleetRequest(defaultExpectedFleetRequestValues), @@ -521,7 +524,7 @@ describe('create runner with errors fail over to OnDemand', () => { expect(mockEC2Client).toHaveReceivedCommandTimes(CreateFleetCommand, 2); - // first call with spot failuer + // first call with spot failure expect(mockEC2Client).toHaveReceivedNthCommandWith(1, CreateFleetCommand, { ...expectedCreateFleetRequest({ ...defaultExpectedFleetRequestValues, }), }); - // second call with with OnDemand failback + // second call with OnDemand fallback expect(mockEC2Client).toHaveReceivedNthCommandWith(2, CreateFleetCommand, { ...expectedCreateFleetRequest({ ...defaultExpectedFleetRequestValues, }); }); - it('test InsufficientInstanceCapacity no failback.', async () => { + it('test InsufficientInstanceCapacity no fallback.', async () => { await expect( createRunner(createRunnerConfig({ ...defaultRunnerConfig, onDemandFailoverOnError: [] })), ).rejects.toBeInstanceOf(Error); }); - it('test InsufficientInstanceCapacity with mutlipte instances and fallback to on demand .', async () => { + it('test InsufficientInstanceCapacity with multiple instances and fallback to on demand.', async () => { const instancesIds = ['i-123', 'i-456']; createFleetMockWithWithOnDemandFallback(['InsufficientInstanceCapacity'], instancesIds); diff --git a/lambdas/functions/control-plane/src/aws/runners.ts index 6779dd39d2..d95dc99fa4 100644 --- a/lambdas/functions/control-plane/src/aws/runners.ts +++ b/lambdas/functions/control-plane/src/aws/runners.ts @@ -166,53 +166,62 @@ async function processFleetResult( ): Promise<string[]> { const instances: string[] = fleet.Instances?.flatMap((i) => i.InstanceIds?.flatMap((j) => j) || []) || []; - if (instances.length !== runnerParameters.numberOfRunners) { - logger.warn( - `${ - instances.length === 0 ? 
'No' : instances.length + ' off ' + runnerParameters.numberOfRunners - } instances created.`, - { data: fleet }, - ); - const errors = fleet.Errors?.flatMap((e) => e.ErrorCode || '') || []; - - // Educated guess of errors that would make sense to retry based on the list - // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html - const scaleErrors = [ - 'UnfulfillableCapacity', - 'MaxSpotInstanceCountExceeded', - 'TargetCapacityLimitExceededException', - 'RequestLimitExceeded', - 'ResourceLimitExceeded', - 'MaxSpotInstanceCountExceeded', - 'MaxSpotFleetRequestCountExceeded', - 'InsufficientInstanceCapacity', - ]; - - if ( - errors.some((e) => runnerParameters.onDemandFailoverOnError?.includes(e)) && - runnerParameters.ec2instanceCriteria.targetCapacityType === 'spot' - ) { - logger.warn(`Create fleet failed, initatiing fall back to on demand instances.`); - logger.debug('Create fleet failed.', { data: fleet.Errors }); - const numberOfInstances = runnerParameters.numberOfRunners - instances.length; - const instancesOnDemand = await createRunner({ - ...runnerParameters, - numberOfRunners: numberOfInstances, - onDemandFailoverOnError: ['InsufficientInstanceCapacity'], - ec2instanceCriteria: { ...runnerParameters.ec2instanceCriteria, targetCapacityType: 'on-demand' }, - }); - instances.push(...instancesOnDemand); - return instances; - } else if (errors.some((e) => scaleErrors.includes(e))) { - logger.warn('Create fleet failed, ScaleError will be thrown to trigger retry for ephemeral runners.'); - logger.debug('Create fleet failed.', { data: fleet.Errors }); - throw new ScaleError('Failed to create instance, create fleet failed.'); - } else { - logger.warn('Create fleet failed, error not recognized as scaling error.', { data: fleet.Errors }); - throw Error('Create fleet failed, no instance created.'); - } + if (instances.length === runnerParameters.numberOfRunners) { + return instances; } - return instances; + + logger.warn( + `${ + instances.length === 0 ? 
'No' : instances.length + ' of ' + runnerParameters.numberOfRunners + } instances created.`, + { data: fleet }, + ); + + const errors = fleet.Errors?.flatMap((e) => e.ErrorCode || '') || []; + + if ( + errors.some((e) => runnerParameters.onDemandFailoverOnError?.includes(e)) && + runnerParameters.ec2instanceCriteria.targetCapacityType === 'spot' + ) { + logger.warn(`Create fleet failed, initiating fallback to on demand instances.`); + logger.debug('Create fleet failed.', { data: fleet.Errors }); + const numberOfInstances = runnerParameters.numberOfRunners - instances.length; + const instancesOnDemand = await createRunner({ + ...runnerParameters, + numberOfRunners: numberOfInstances, + onDemandFailoverOnError: ['InsufficientInstanceCapacity'], + ec2instanceCriteria: { ...runnerParameters.ec2instanceCriteria, targetCapacityType: 'on-demand' }, + }); + instances.push(...instancesOnDemand); + return instances; + } + + // Educated guess of errors that would make sense to retry based on the list + // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html + const scaleErrors = [ + 'UnfulfillableCapacity', + 'MaxSpotInstanceCountExceeded', + 'TargetCapacityLimitExceededException', + 'RequestLimitExceeded', + 'ResourceLimitExceeded', + 'MaxSpotFleetRequestCountExceeded', + 'InsufficientInstanceCapacity', + ]; + + const failedCount = countScaleErrors(errors, scaleErrors); + if (failedCount > 0) { + logger.warn('Create fleet failed, ScaleError will be thrown to trigger retry for ephemeral runners.'); + logger.debug('Create fleet failed.', { data: fleet.Errors }); + throw new ScaleError(failedCount); + } + + logger.warn('Create fleet failed, error not recognized as scaling error.', { data: fleet.Errors }); + throw Error('Create fleet failed, no instance created.'); +} + +function countScaleErrors(errors: string[], scaleErrors: string[]): number {
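+ // Count how many of the reported CreateFleet error codes appear in the retryable scale error list.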
+ return errors.reduce((acc, e) => (scaleErrors.includes(e) ? acc + 1 : acc), 0); +} async function getAmiIdOverride(runnerParameters: Runners.RunnerInputParameters): Promise<string> { diff --git a/lambdas/functions/control-plane/src/lambda.test.ts index 2c54a4d541..3e6a897e88 100644 --- a/lambdas/functions/control-plane/src/lambda.test.ts +++ b/lambdas/functions/control-plane/src/lambda.test.ts @@ -28,11 +28,11 @@ const sqsRecord: SQSRecord = { }, awsRegion: '', body: JSON.stringify(body), - eventSource: 'aws:SQS', + eventSource: 'aws:sqs', eventSourceARN: '', md5OfBody: '', messageAttributes: {}, - messageId: '', + messageId: 'abcd1234', receiptHandle: '', }; @@ -70,19 +70,33 @@ vi.mock('@aws-github-runner/aws-powertools-util'); vi.mock('@aws-github-runner/aws-ssm-util'); describe('Test scale up lambda wrapper.', () => { - it('Do not handle multiple record sets.', async () => { - await testInvalidRecords([sqsRecord, sqsRecord]); + it('Do not handle empty record sets.', async () => { + const sqsEventEmptyRecords: SQSEvent = { + Records: [], + }; + + await expect(scaleUpHandler(sqsEventEmptyRecords, context)).resolves.not.toThrow(); }); - it('Do not handle empty record sets.', async () => { - await testInvalidRecords([]); + it('Ignores non-sqs event sources.', async () => { + const record = { + ...sqsRecord, + eventSource: 'aws:non-sqs', + }; + + const sqsEventNonSqs: SQSEvent = { + Records: [record], + }; + + await expect(scaleUpHandler(sqsEventNonSqs, context)).resolves.not.toThrow(); + expect(scaleUp).toHaveBeenCalledWith([]); }); it('Scale without error should resolve.', async () => { const mock = vi.fn(scaleUp); mock.mockImplementation(() => { return new Promise((resolve) => { - resolve(); + resolve([]); }); }); await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow(); @@ -95,37 +109,150 @@ describe('Test scale up lambda wrapper.', () => { await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow(); }); - it('Scale should be rejected', async () => { - const error = new ScaleError('Scale should be rejected'); + it('Scale should create a batch failure message', async () => { + const error = new ScaleError(); const mock = vi.fn() as MockedFunction<typeof scaleUp>; mock.mockImplementation(() => { return Promise.reject(error); }); vi.mocked(scaleUp).mockImplementation(mock); - await expect(scaleUpHandler(sqsEvent, context)).rejects.toThrow(error); + await expect(scaleUpHandler(sqsEvent, context)).resolves.toEqual({ + batchItemFailures: [{ itemIdentifier: sqsRecord.messageId }], + }); }); -}); -async function testInvalidRecords(sqsRecords: SQSRecord[]) { - const mock = vi.fn(scaleUp); - const logWarnSpy = vi.spyOn(logger, 'warn'); - mock.mockImplementation(() => { - return new Promise((resolve) => { - resolve(); + describe('Batch processing', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const createMultipleRecords = (count: number, eventSource = 'aws:sqs'): SQSRecord[] => { + return Array.from({ length: count }, (_, i) => ({ + ...sqsRecord, + eventSource, + messageId: `message-${i}`, + body: JSON.stringify({ + ...body, + id: i + 1, + }), + })); + }; + + it('Should handle multiple SQS records in a single invocation', async () => { + const records = createMultipleRecords(3); + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve([])); + vi.mocked(scaleUp).mockImplementation(mock); + + await expect(scaleUpHandler(multiRecordEvent, context)).resolves.not.toThrow();
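+ // Every record parsed from the batch should reach scaleUp in a single call.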
+ expect(scaleUp).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ messageId: 'message-0' }), + expect.objectContaining({ messageId: 'message-1' }), + expect.objectContaining({ messageId: 'message-2' }), + ]), + ); + }); + + it('Should return batch item failures for rejected messages', async () => { + const records = createMultipleRecords(3); + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve(['message-1', 'message-2'])); + vi.mocked(scaleUp).mockImplementation(mock); + + const result = await scaleUpHandler(multiRecordEvent, context); + expect(result).toEqual({ + batchItemFailures: [{ itemIdentifier: 'message-1' }, { itemIdentifier: 'message-2' }], + }); + }); + + it('Should filter out non-SQS event sources', async () => { + const sqsRecords = createMultipleRecords(2, 'aws:sqs'); + const nonSqsRecords = createMultipleRecords(1, 'aws:sns').map((record) => ({ ...record, messageId: 'non-sqs-0' })); + const mixedEvent: SQSEvent = { + Records: [...sqsRecords, ...nonSqsRecords], + }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve([])); + vi.mocked(scaleUp).mockImplementation(mock); + + await scaleUpHandler(mixedEvent, context); + expect(scaleUp).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ messageId: 'message-0' }), + expect.objectContaining({ messageId: 'message-1' }), + ]), + ); + expect(scaleUp).not.toHaveBeenCalledWith( + expect.arrayContaining([expect.objectContaining({ messageId: 'non-sqs-0' })]), + ); + }); + + it('Should sort messages by retry count', async () => { + const records = [ + { + ...sqsRecord, + messageId: 'high-retry', + body: JSON.stringify({ ...body, retryCounter: 5 }), + }, + { + ...sqsRecord, + messageId: 'low-retry', + body: JSON.stringify({ ...body, retryCounter: 1 }), + }, + { + ...sqsRecord, + messageId: 'no-retry', + body: JSON.stringify({ ...body }), + }, + ]; + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve([])); + vi.mocked(scaleUp).mockImplementation(mock); + + await scaleUpHandler(multiRecordEvent, context); + + // Verify messages are sorted by retry count (ascending). Assert after the call rather than + // inside the mock: the handler swallows non-ScaleError rejections, so an expect() failure + // thrown inside the mock could never fail the test. + const received = mock.mock.calls[0][0]; + expect(received[0].messageId).toBe('no-retry'); + expect(received[1].messageId).toBe('low-retry'); + expect(received[2].messageId).toBe('high-retry'); + }); + + it('Should return no batch item failures when scaleUp throws non-ScaleError', async () => { + const records = createMultipleRecords(2); + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.reject(new Error('Generic error'))); + vi.mocked(scaleUp).mockImplementation(mock); + + const result = await scaleUpHandler(multiRecordEvent, context); + expect(result).toEqual({ batchItemFailures: [] }); + }); + + it('Should return batch item failures for all messages when scaleUp throws ScaleError', async () => { + const records = createMultipleRecords(2); + const multiRecordEvent: SQSEvent = { Records: records }; + + const error = new ScaleError(2); + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.reject(error)); + vi.mocked(scaleUp).mockImplementation(mock); + + await expect(scaleUpHandler(multiRecordEvent, context)).resolves.toEqual({ + batchItemFailures: [{ itemIdentifier: 'message-0' }, { itemIdentifier: 'message-1' }], + }); }); }); - const sqsEventMultipleRecords: SQSEvent = { - Records: sqsRecords, - }; - - await 
expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow(); - - expect(logWarnSpy).toHaveBeenCalledWith( - expect.stringContaining( - 'Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.', - ), - ); -} +}); describe('Test scale down lambda wrapper.', () => { it('Scaling down no error.', async () => { diff --git a/lambdas/functions/control-plane/src/lambda.ts index 3e3ab90557..e2a0451c95 100644 --- a/lambdas/functions/control-plane/src/lambda.ts +++ b/lambdas/functions/control-plane/src/lambda.ts @@ -1,34 +1,66 @@ import middy from '@middy/core'; import { logger, setContext } from '@aws-github-runner/aws-powertools-util'; import { captureLambdaHandler, tracer } from '@aws-github-runner/aws-powertools-util'; -import { Context, SQSEvent } from 'aws-lambda'; +import { Context, type SQSBatchItemFailure, type SQSBatchResponse, SQSEvent } from 'aws-lambda'; import { PoolEvent, adjust } from './pool/pool'; import ScaleError from './scale-runners/ScaleError'; import { scaleDown } from './scale-runners/scale-down'; -import { scaleUp } from './scale-runners/scale-up'; +import { type ActionRequestMessage, type ActionRequestMessageSQS, scaleUp } from './scale-runners/scale-up'; import { SSMCleanupOptions, cleanSSMTokens } from './scale-runners/ssm-housekeeper'; import { checkAndRetryJob } from './scale-runners/job-retry'; -export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<void> { +export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<SQSBatchResponse> { setContext(context, 'lambda.ts'); logger.logEventIfEnabled(event); - if (event.Records.length !== 1) { - logger.warn('Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.'); - return Promise.resolve(); + const sqsMessages: ActionRequestMessageSQS[] = []; + const warnedEventSources = new Set<string>(); + + for (const { body, eventSource, messageId } of event.Records) { + if (eventSource !== 'aws:sqs') { + if (!warnedEventSources.has(eventSource)) { + logger.warn('Ignoring non-sqs event source', { eventSource }); + warnedEventSources.add(eventSource); + } + + continue; + } + + const payload = JSON.parse(body) as ActionRequestMessage; + sqsMessages.push({ ...payload, messageId }); } + // Sort messages by their retry count, so that we retry the same messages if + // there's a persistent failure. This should cause messages to be dropped + // quicker than if we retried in an arbitrary order. + sqsMessages.sort((l, r) => { + return (l.retryCounter ?? 0) - (r.retryCounter ?? 0); + });
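+ // Message ids pushed to batchItemFailures below are redelivered by SQS; all other records in the batch are deleted from the queue.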
0); + }); + + const batchItemFailures: SQSBatchItemFailure[] = []; + try { - await scaleUp(event.Records[0].eventSource, JSON.parse(event.Records[0].body)); - return Promise.resolve(); + const rejectedMessageIds = await scaleUp(sqsMessages); + + for (const messageId of rejectedMessageIds) { + batchItemFailures.push({ + itemIdentifier: messageId, + }); + } + + return { batchItemFailures }; } catch (e) { if (e instanceof ScaleError) { - return Promise.reject(e); + batchItemFailures.push(...e.toBatchItemFailures(sqsMessages)); + logger.warn(`${e.detailedMessage} A retry will be attempted via SQS.`, { error: e }); } else { - logger.warn(`Ignoring error: ${e}`); - return Promise.resolve(); + logger.error(`Error processing batch (size: ${sqsMessages.length}): ${(e as Error).message}, ignoring batch`, { + error: e, + }); } + + return { batchItemFailures }; } } diff --git a/lambdas/functions/control-plane/src/local.ts b/lambdas/functions/control-plane/src/local.ts index 2166da58fd..0b06335c8a 100644 --- a/lambdas/functions/control-plane/src/local.ts +++ b/lambdas/functions/control-plane/src/local.ts @@ -1,21 +1,21 @@ import { logger } from '@aws-github-runner/aws-powertools-util'; -import { ActionRequestMessage, scaleUp } from './scale-runners/scale-up'; +import { scaleUpHandler } from './lambda'; +import { Context, SQSEvent } from 'aws-lambda'; -const sqsEvent = { +const sqsEvent: SQSEvent = { Records: [ { messageId: 'e8d74d08-644e-42ca-bf82-a67daa6c4dad', receiptHandle: - // eslint-disable-next-line max-len 'AQEBCpLYzDEKq4aKSJyFQCkJduSKZef8SJVOperbYyNhXqqnpFG5k74WygVAJ4O0+9nybRyeOFThvITOaS21/jeHiI5fgaM9YKuI0oGYeWCIzPQsluW5CMDmtvqv1aA8sXQ5n2x0L9MJkzgdIHTC3YWBFLQ2AxSveOyIHwW+cHLIFCAcZlOaaf0YtaLfGHGkAC4IfycmaijV8NSlzYgDuxrC9sIsWJ0bSvk5iT4ru/R4+0cjm7qZtGlc04k9xk5Fu6A+wRxMaIyiFRY+Ya19ykcevQldidmEjEWvN6CRToLgclk=', - body: { + body: JSON.stringify({ repositoryName: 'self-hosted', repositoryOwner: 'test-runners', eventType: 'workflow_job', id: 987654, installationId: 123456789, - }, + }), attributes: { ApproximateReceiveCount: '1', SentTimestamp: '1626450047230', @@ -34,12 +34,34 @@ const sqsEvent = { ], }; +const context: Context = { + awsRequestId: '1', + callbackWaitsForEmptyEventLoop: false, + functionName: '', + functionVersion: '', + getRemainingTimeInMillis: () => 0, + invokedFunctionArn: '', + logGroupName: '', + logStreamName: '', + memoryLimitInMB: '', + done: () => { + return; + }, + fail: () => { + return; + }, + succeed: () => { + return; + }, +}; + export function run(): void { - scaleUp(sqsEvent.Records[0].eventSource, sqsEvent.Records[0].body as ActionRequestMessage) - .then() - .catch((e) => { - logger.error(e); - }); + try { + scaleUpHandler(sqsEvent, context); + } catch (e: unknown) { + const message = e instanceof Error ? e.message : `${e}`; + logger.error(message, e instanceof Error ? 
} run(); diff --git a/lambdas/functions/control-plane/src/pool/pool.test.ts index 3a7ba3ab1c..76c3d61d4c 100644 --- a/lambdas/functions/control-plane/src/pool/pool.test.ts +++ b/lambdas/functions/control-plane/src/pool/pool.test.ts @@ -188,11 +188,7 @@ describe('Test simple pool.', () => { it('Top up pool with pool size 2 registered.', async () => { await adjust({ poolSize: 3 }); expect(createRunners).toHaveBeenCalledTimes(1); - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 1 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 1, expect.anything()); }); it('Should not top up if pool size is reached.', async () => { @@ -268,11 +264,7 @@ describe('Test simple pool.', () => { it('Top up if the pool size is set to 5', async () => { await adjust({ poolSize: 5 }); // 2 idle, top up with 3 to match a pool of 5 - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 3 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 3, expect.anything()); }); }); @@ -287,11 +279,7 @@ describe('Test simple pool.', () => { it('Top up if the pool size is set to 5', async () => { await adjust({ poolSize: 5 }); // 2 idle, top up with 3 to match a pool of 5 - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 3 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 3, expect.anything()); }); }); @@ -341,11 +329,7 @@ describe('Test simple pool.', () => { await adjust({ poolSize: 5 }); // 2 idle, 2 prefixed idle top up with 1 to match a pool of 5 - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 1 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 1, expect.anything()); }); }); }); diff --git a/lambdas/functions/control-plane/src/pool/pool.ts index 162a7d0f6d..a333c8da61 100644 --- a/lambdas/functions/control-plane/src/pool/pool.ts +++ b/lambdas/functions/control-plane/src/pool/pool.ts @@ -92,11 +92,11 @@ export async function adjust(event: PoolEvent): Promise<void> { environment, launchTemplateName, subnets, - numberOfRunners: topUp, amiIdSsmParameterName, tracingEnabled, onDemandFailoverOnError, }, + topUp, githubInstallationClient, ); } else { diff --git a/lambdas/functions/control-plane/src/scale-runners/ScaleError.test.ts b/lambdas/functions/control-plane/src/scale-runners/ScaleError.test.ts new file mode 100644 index 0000000000..0a7478c12f --- /dev/null +++ b/lambdas/functions/control-plane/src/scale-runners/ScaleError.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, it } from 'vitest'; +import type { ActionRequestMessageSQS } from './scale-up'; +import ScaleError from './ScaleError'; + +describe('ScaleError', () => { + describe('detailedMessage', () => { + it('should format message for single instance failure', () => { + const error = new ScaleError(1); + + expect(error.detailedMessage).toBe( + 'Failed to create instance, create fleet failed. 
(Failed to create 1 instance)', + ); + }); + + it('should format message for multiple instance failures', () => { + const error = new ScaleError(3); + + expect(error.detailedMessage).toBe( + 'Failed to create instance, create fleet failed. (Failed to create 3 instances)', + ); + }); + }); + + describe('toBatchItemFailures', () => { + const mockMessages: ActionRequestMessageSQS[] = [ + { messageId: 'msg-1', id: 1, eventType: 'workflow_job' }, + { messageId: 'msg-2', id: 2, eventType: 'workflow_job' }, + { messageId: 'msg-3', id: 3, eventType: 'workflow_job' }, + { messageId: 'msg-4', id: 4, eventType: 'workflow_job' }, + ]; + + it.each([ + { failedCount: 1, expected: [{ itemIdentifier: 'msg-1' }], description: 'default instance count' }, + { + failedCount: 2, + expected: [{ itemIdentifier: 'msg-1' }, { itemIdentifier: 'msg-2' }], + description: 'less than message count', + }, + { + failedCount: 4, + expected: [ + { itemIdentifier: 'msg-1' }, + { itemIdentifier: 'msg-2' }, + { itemIdentifier: 'msg-3' }, + { itemIdentifier: 'msg-4' }, + ], + description: 'equal to message count', + }, + { + failedCount: 10, + expected: [ + { itemIdentifier: 'msg-1' }, + { itemIdentifier: 'msg-2' }, + { itemIdentifier: 'msg-3' }, + { itemIdentifier: 'msg-4' }, + ], + description: 'more than message count', + }, + { failedCount: 0, expected: [], description: 'zero failed instances' }, + { failedCount: -1, expected: [], description: 'negative failed instances' }, + { failedCount: -10, expected: [], description: 'large negative failed instances' }, + ])('should handle $description (failedCount=$failedCount)', ({ failedCount, expected }) => { + const error = new ScaleError(failedCount); + const failures = error.toBatchItemFailures(mockMessages); + + expect(failures).toEqual(expected); + }); + + it('should handle empty message array', () => { + const error = new ScaleError(3); + const failures = error.toBatchItemFailures([]); + + expect(failures).toEqual([]); + }); + }); +}); diff --git a/lambdas/functions/control-plane/src/scale-runners/ScaleError.ts b/lambdas/functions/control-plane/src/scale-runners/ScaleError.ts index d7e71f8c33..9c1f474d17 100644 --- a/lambdas/functions/control-plane/src/scale-runners/ScaleError.ts +++ b/lambdas/functions/control-plane/src/scale-runners/ScaleError.ts @@ -1,8 +1,28 @@ +import type { SQSBatchItemFailure } from 'aws-lambda'; +import type { ActionRequestMessageSQS } from './scale-up'; + class ScaleError extends Error { - constructor(public message: string) { - super(message); + constructor(public readonly failedInstanceCount: number = 1) { + super('Failed to create instance, create fleet failed.'); this.name = 'ScaleError'; - this.stack = new Error().stack; + } + + /** + * Gets a formatted error message including the failed instance count + */ + public get detailedMessage(): string { + return `${this.message} (Failed to create ${this.failedInstanceCount} instance${this.failedInstanceCount !== 1 ? 
's' : ''})`; + } + + /** + * Generate SQS batch item failures for the failed instances + */ + public toBatchItemFailures(messages: ActionRequestMessageSQS[]): SQSBatchItemFailure[] { + // Ensure we don't retry negative counts or more messages than available + const messagesToRetry = Math.max(0, Math.min(this.failedInstanceCount, messages.length)); + return messages.slice(0, messagesToRetry).map(({ messageId }) => ({ + itemIdentifier: messageId, + })); } } diff --git a/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts b/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts index 1edfefb69b..3ee24fb6d4 100644 --- a/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts +++ b/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts @@ -2,9 +2,11 @@ import { publishMessage } from '../aws/sqs'; import { publishRetryMessage, checkAndRetryJob } from './job-retry'; import { ActionRequestMessage, ActionRequestMessageRetry } from './scale-up'; import { getOctokit } from '../github/octokit'; +import { jobRetryCheck } from '../lambda'; import { Octokit } from '@octokit/rest'; import { createSingleMetric } from '@aws-github-runner/aws-powertools-util'; import { describe, it, expect, beforeEach, vi } from 'vitest'; +import type { SQSRecord } from 'aws-lambda'; vi.mock('../aws/sqs', async () => ({ publishMessage: vi.fn(), @@ -267,3 +269,93 @@ describe(`Test job retry check`, () => { expect(publishMessage).not.toHaveBeenCalled(); }); }); + +describe('Test job retry handler (batch processing)', () => { + const context = { + requestId: 'request-id', + functionName: 'function-name', + functionVersion: 'function-version', + invokedFunctionArn: 'invoked-function-arn', + memoryLimitInMB: '128', + awsRequestId: 'aws-request-id', + logGroupName: 'log-group-name', + logStreamName: 'log-stream-name', + remainingTimeInMillis: () => 30000, + done: () => {}, + fail: () => {}, + succeed: () => {}, + getRemainingTimeInMillis: () => 30000, + callbackWaitsForEmptyEventLoop: false, + }; + + function createSQSRecord(messageId: string): SQSRecord { + return { + messageId, + receiptHandle: 'receipt-handle', + body: JSON.stringify({ + eventType: 'workflow_job', + id: 123, + installationId: 456, + repositoryName: 'test-repo', + repositoryOwner: 'test-owner', + repoOwnerType: 'Organization', + retryCounter: 0, + }), + attributes: { + ApproximateReceiveCount: '1', + SentTimestamp: '1234567890', + SenderId: 'sender-id', + ApproximateFirstReceiveTimestamp: '1234567891', + }, + messageAttributes: {}, + md5OfBody: 'md5', + eventSource: 'aws:sqs', + eventSourceARN: 'arn:aws:sqs:region:account:queue', + awsRegion: 'us-east-1', + }; + } + + beforeEach(() => { + vi.clearAllMocks(); + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.JOB_QUEUE_SCALE_UP_URL = 'https://sqs.example.com/queue'; + }); + + it('should handle multiple records in a single batch', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ + data: { + status: 'queued', + }, + headers: {}, + })); + + const event = { + Records: [createSQSRecord('msg-1'), createSQSRecord('msg-2'), createSQSRecord('msg-3')], + }; + + await expect(jobRetryCheck(event, context)).resolves.not.toThrow(); + expect(publishMessage).toHaveBeenCalledTimes(3); + }); + + it('should continue processing other records when one fails', async () => { + mockCreateOctokitClient + .mockResolvedValueOnce(new Octokit()) // First record succeeds + .mockRejectedValueOnce(new Error('API error')) // Second record 
fails + .mockResolvedValueOnce(new Octokit()); // Third record succeeds + + mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ + data: { + status: 'queued', + }, + headers: {}, + })); + + const event = { + Records: [createSQSRecord('msg-1'), createSQSRecord('msg-2'), createSQSRecord('msg-3')], + }; + + await expect(jobRetryCheck(event, context)).resolves.not.toThrow(); + // There were two successful calls to publishMessage + expect(publishMessage).toHaveBeenCalledTimes(2); + }); +}); diff --git a/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts b/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts index 14c0a0422e..1cf6ace8b7 100644 --- a/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts +++ b/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts @@ -1,5 +1,4 @@ import { PutParameterCommand, SSMClient } from '@aws-sdk/client-ssm'; -import { Octokit } from '@octokit/rest'; import { mockClient } from 'aws-sdk-client-mock'; import 'aws-sdk-client-mock-jest/vitest'; // Using vi.mocked instead of jest-mock @@ -9,10 +8,10 @@ import { performance } from 'perf_hooks'; import * as ghAuth from '../github/auth'; import { createRunner, listEC2Runners } from './../aws/runners'; import { RunnerInputParameters } from './../aws/runners.d'; -import ScaleError from './ScaleError'; import * as scaleUpModule from './scale-up'; import { getParameter } from '@aws-github-runner/aws-ssm-util'; import { describe, it, expect, beforeEach, vi } from 'vitest'; +import type { Octokit } from '@octokit/rest'; const mockOctokit = { paginate: vi.fn(), @@ -29,6 +28,7 @@ const mockOctokit = { getRepoInstallation: vi.fn(), }, }; + const mockCreateRunner = vi.mocked(createRunner); const mockListRunners = vi.mocked(listEC2Runners); const mockSSMClient = mockClient(SSMClient); @@ -66,26 +66,33 @@ export type RunnerType = 'ephemeral' | 'non-ephemeral'; // for ephemeral and non-ephemeral runners const RUNNER_TYPES: RunnerType[] = ['ephemeral', 'non-ephemeral']; -const mocktokit = Octokit as vi.MockedClass; const mockedAppAuth = vi.mocked(ghAuth.createGithubAppAuth); const mockedInstallationAuth = vi.mocked(ghAuth.createGithubInstallationAuth); const mockCreateClient = vi.mocked(ghAuth.createOctokitClient); -const TEST_DATA: scaleUpModule.ActionRequestMessage = { +const TEST_DATA_SINGLE: scaleUpModule.ActionRequestMessageSQS = { id: 1, eventType: 'workflow_job', repositoryName: 'hello-world', repositoryOwner: 'Codertocat', installationId: 2, repoOwnerType: 'Organization', + messageId: 'foobar', }; +const TEST_DATA: scaleUpModule.ActionRequestMessageSQS[] = [ + { + ...TEST_DATA_SINGLE, + messageId: 'foobar', + }, +]; + const cleanEnv = process.env; const EXPECTED_RUNNER_PARAMS: RunnerInputParameters = { environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, numberOfRunners: 1, launchTemplateName: 'lt-1', ec2instanceCriteria: { @@ -132,14 +139,14 @@ beforeEach(() => { instanceId: 'i-1234', launchTime: new Date(), type: 'Org', - owner: TEST_DATA.repositoryOwner, + owner: TEST_DATA_SINGLE.repositoryOwner, }, ]); mockedAppAuth.mockResolvedValue({ type: 'app', token: 'token', - appId: TEST_DATA.installationId, + appId: TEST_DATA_SINGLE.installationId, expiresAt: 'some-date', }); mockedInstallationAuth.mockResolvedValue({ @@ -153,7 +160,7 @@ beforeEach(() => { installationId: 0, }); - mockCreateClient.mockResolvedValue(new mocktokit()); + 
mockCreateClient.mockResolvedValue(mockOctokit as unknown as Octokit); }); describe('scaleUp with GHES', () => { @@ -161,17 +168,12 @@ describe('scaleUp with GHES', () => { process.env.GHES_URL = 'https://github.enterprise.something'; }); - it('ignores non-sqs events', async () => { - expect.assertions(1); - await expect(scaleUpModule.scaleUp('aws:s3', TEST_DATA)).rejects.toEqual(Error('Cannot handle non-SQS events!')); - }); - it('checks queued workflows', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalledWith({ - job_id: TEST_DATA.id, - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + job_id: TEST_DATA_SINGLE.id, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); @@ -179,7 +181,7 @@ describe('scaleUp with GHES', () => { mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ data: { total_count: 0 }, })); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toBeCalled(); }); @@ -198,18 +200,18 @@ describe('scaleUp with GHES', () => { }); it('gets the current org level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); @@ -217,35 +219,35 @@ describe('scaleUp with GHES', () => { it('does create a runner if maximum is set to -1', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '-1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toHaveBeenCalled(); expect(createRunner).toHaveBeenCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, }); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a runner with correct config', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with labels in a specific group', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with ami id override from ssm parameter', async () => { process.env.AMI_ID_SSM_PARAMETER_NAME = 'my-ami-id-param'; - await scaleUpModule.scaleUp('aws:sqs', 
TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith({ ...expectedRunnerParams, amiIdSsmParameterName: 'my-ami-id-param' }); }); @@ -254,15 +256,15 @@ describe('scaleUp with GHES', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toBeInstanceOf(Error); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toBeInstanceOf(Error); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); }); it('Discards event if it is a User repo and org level runners is enabled', async () => { process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; - const USER_REPO_TEST_DATA = { ...TEST_DATA }; - USER_REPO_TEST_DATA.repoOwnerType = 'User'; - await scaleUpModule.scaleUp('aws:sqs', USER_REPO_TEST_DATA); + const USER_REPO_TEST_DATA = structuredClone(TEST_DATA); + USER_REPO_TEST_DATA[0].repoOwnerType = 'User'; + await scaleUpModule.scaleUp(USER_REPO_TEST_DATA); expect(createRunner).not.toHaveBeenCalled(); }); @@ -270,7 +272,7 @@ describe('scaleUp with GHES', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 2); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -281,7 +283,7 @@ describe('scaleUp with GHES', () => { }); it('Does not create SSM parameter for runner group id if it exists', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.paginate).toHaveBeenCalledTimes(0); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 1); }); @@ -289,9 +291,9 @@ describe('scaleUp with GHES', () => { it('create start runner config for ephemeral runners ', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, name: 'unit-test-i-12345', runner_group_id: 1, labels: ['label1', 'label2'], @@ -312,7 +314,7 @@ describe('scaleUp with GHES', () => { it('create start runner config for non-ephemeral runners ', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalled(); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -383,7 +385,7 @@ describe('scaleUp with GHES', () => { 'i-150', 'i-151', ]; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); const endTime = performance.now(); expect(endTime - startTime).toBeGreaterThan(1000); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 40); @@ -397,87 +399,307 @@ describe('scaleUp with GHES', () => { process.env.RUNNER_NAME_PREFIX = 'unit-test'; expectedRunnerParams = { ...EXPECTED_RUNNER_PARAMS }; expectedRunnerParams.runnerType = 'Repo'; - expectedRunnerParams.runnerOwner = 
`${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`; - // `--url https://github.enterprise.something/${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + expectedRunnerParams.runnerOwner = `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`; + // `--url https://github.enterprise.something/${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, // `--token 1234abcd`, // ]; }); it('gets the current repo level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Repo', - runnerOwner: `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + runnerOwner: `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('uses the default runner max count', async () => { process.env.RUNNERS_MAXIMUM_COUNT = undefined; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('creates a runner with correct config and labels', async () => { process.env.RUNNER_LABELS = 'label1,label2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner and ensure the group argument is ignored', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP_IGNORED'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('Check error is thrown', async () => { const mockCreateRunners = vi.mocked(createRunner); mockCreateRunners.mockRejectedValue(new Error('no retry')); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toThrow('no retry'); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toThrow('no retry'); mockCreateRunners.mockReset(); }); }); -}); -describe('scaleUp with public GH', () => { - it('ignores non-sqs events', async () => { - expect.assertions(1); - await expect(scaleUpModule.scaleUp('aws:s3', TEST_DATA)).rejects.toEqual(Error('Cannot handle non-SQS events!')); + describe('Batch processing', () => { + beforeEach(() => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.ENABLE_EPHEMERAL_RUNNERS = 
'true'; + process.env.RUNNERS_MAXIMUM_COUNT = '10'; + }); + + const createTestMessages = ( + count: number, + overrides: Partial<scaleUpModule.ActionRequestMessageSQS>[] = [], + ): scaleUpModule.ActionRequestMessageSQS[] => { + return Array.from({ length: count }, (_, i) => ({ + ...TEST_DATA_SINGLE, + id: i + 1, + messageId: `message-${i}`, + ...overrides[i], + })); + }; + + it('Should handle multiple messages for the same organization', async () => { + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(1); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 3, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, + }), + ); + }); + + it('Should handle multiple messages for different organizations', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'org1' }, + { repositoryOwner: 'org2' }, + { repositoryOwner: 'org1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'org1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'org2', + }), + ); + }); + + it('Should handle multiple messages for different repositories when org-level is disabled', async () => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'false'; + const messages = createTestMessages(3, [ + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + { repositoryOwner: 'owner1', repositoryName: 'repo2' }, + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'owner1/repo1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'owner1/repo2', + }), + ); + }); + + it('Should reject messages when maximum runners limit is reached', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '1'; // Set to 1 so with 1 existing, no new ones can be created + mockListRunners.mockImplementation(async () => [ + { + instanceId: 'i-existing', + launchTime: new Date(), + type: 'Org', + owner: TEST_DATA_SINGLE.repositoryOwner, + }, + ]); + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); // No runners should be created + expect(rejectedMessages).toHaveLength(3); // All 3 messages should be rejected + }); + + it('Should handle partial EC2 instance creation failures', async () => { + mockCreateRunner.mockImplementation(async () => ['i-12345']); // Only creates 1 instead of requested 3 + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(rejectedMessages).toHaveLength(2); // 3 requested - 1 created = 2 failed + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should filter out invalid event types for ephemeral runners', async () => { + const messages = createTestMessages(3, [ + { eventType: 'workflow_job' }, + { eventType: 'check_run' }, + { eventType: 'workflow_job' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only 
workflow_job events processed + }), + ); + expect(rejectedMessages).toContain('message-1'); // check_run event rejected + }); + + it('Should skip invalid repo owner types but not reject them', async () => { + const messages = createTestMessages(3, [ + { repoOwnerType: 'Organization' }, + { repoOwnerType: 'User' }, // Invalid for org-level runners + { repoOwnerType: 'Organization' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only Organization events processed + }), + ); + expect(rejectedMessages).not.toContain('message-1'); // User repo not rejected, just skipped + }); + + it('Should skip messages when jobs are not queued', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation((params) => { + const isQueued = params.job_id === 1 || params.job_id === 3; // Only jobs 1 and 3 are queued + return { + data: { + status: isQueued ? 'queued' : 'completed', + }, + }; + }); + + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only queued jobs processed + }), + ); + }); + + it('Should create separate GitHub clients for different installations', async () => { + // Override the default mock to return different installation IDs + mockOctokit.apps.getOrgInstallation.mockReset(); + mockOctokit.apps.getOrgInstallation.mockImplementation((params) => ({ + data: { + id: params.org === 'org1' ? 100 : 200, + }, + })); + + const messages = createTestMessages(2, [ + { repositoryOwner: 'org1', installationId: 0 }, + { repositoryOwner: 'org2', installationId: 0 }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(3); // 1 app client, 2 repo installation clients + expect(mockedInstallationAuth).toHaveBeenCalledWith(100, 'https://github.enterprise.something/api/v3'); + expect(mockedInstallationAuth).toHaveBeenCalledWith(200, 'https://github.enterprise.something/api/v3'); + }); + + it('Should reuse GitHub clients for same installation', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(2); // 1 app client, 1 installation client + expect(mockedInstallationAuth).toHaveBeenCalledTimes(1); + }); + + it('Should return empty array when no valid messages to process', async () => { + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + const messages = createTestMessages(2, [ + { eventType: 'check_run' }, // Invalid for ephemeral + { eventType: 'check_run' }, // Invalid for ephemeral + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should handle unlimited runners configuration', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '-1'; + const messages = createTestMessages(10); + + await scaleUpModule.scaleUp(messages); + + expect(listEC2Runners).not.toHaveBeenCalled(); // No need to check current runners + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 10, // All messages processed + }), + ); + }); }); +}); +describe('scaleUp with public GH', () => { it('checks queued workflows', async () => { - await 
scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalledWith({ - job_id: TEST_DATA.id, - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + job_id: TEST_DATA_SINGLE.id, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('not checking queued workflows', async () => { process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); }); @@ -485,7 +707,7 @@ describe('scaleUp with public GH', () => { mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ data: { status: 'completed' }, })); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toBeCalled(); }); @@ -497,38 +719,38 @@ describe('scaleUp with public GH', () => { }); it('gets the current org level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, }); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a runner with correct config', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with labels in s specific group', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); }); @@ -541,44 +763,44 @@ describe('scaleUp with public GH', () => { process.env.RUNNER_NAME_PREFIX = 'unit-test'; expectedRunnerParams = { ...EXPECTED_RUNNER_PARAMS }; expectedRunnerParams.runnerType = 'Repo'; - expectedRunnerParams.runnerOwner = `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`; + expectedRunnerParams.runnerOwner = `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`; }); it('gets the current repo level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Repo', - runnerOwner: `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + runnerOwner: `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, }); }); it('does not create a token when 
maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('creates a runner with correct config and labels', async () => { process.env.RUNNER_LABELS = 'label1,label2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with correct config and labels and on demand failover enabled.', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.ENABLE_ON_DEMAND_FAILOVER_FOR_ERRORS = JSON.stringify(['InsufficientInstanceCapacity']); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith({ ...expectedRunnerParams, onDemandFailoverOnError: ['InsufficientInstanceCapacity'], @@ -588,26 +810,25 @@ describe('scaleUp with public GH', () => { it('creates a runner and ensure the group argument is ignored', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP_IGNORED'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('ephemeral runners only run with workflow_job event, others should fail.', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; - await expect( - scaleUpModule.scaleUp('aws:sqs', { - ...TEST_DATA, - eventType: 'check_run', - }), - ).rejects.toBeInstanceOf(Error); + + const USER_REPO_TEST_DATA = structuredClone(TEST_DATA); + USER_REPO_TEST_DATA[0].eventType = 'check_run'; + + await expect(scaleUpModule.scaleUp(USER_REPO_TEST_DATA)).resolves.toEqual(['foobar']); }); it('creates a ephemeral runner with JIT config.', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; process.env.SSM_TOKEN_PATH = '/github-action-runners/default/runners/config'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); @@ -629,7 +850,7 @@ describe('scaleUp with public GH', () => { process.env.ENABLE_JIT_CONFIG = 'false'; process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; process.env.SSM_TOKEN_PATH = '/github-action-runners/default/runners/config'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); @@ -652,7 +873,7 @@ describe('scaleUp with public GH', () => { process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; process.env.RUNNER_LABELS = 'jit'; 
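// The scale-up lambda stores the runner's registration/JIT config under this SSM path; the instance reads it back at boot.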
process.env.SSM_TOKEN_PATH = '/github-action-runners/default/runners/config'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); @@ -672,21 +893,247 @@ describe('scaleUp with public GH', () => { it('creates a ephemeral runner after checking job is queued.', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; process.env.ENABLE_JOB_QUEUED_CHECK = 'true'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('disable auto update on the runner.', async () => { process.env.DISABLE_RUNNER_AUTOUPDATE = 'true'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); - it('Scaling error should cause reject so retry can be triggered.', async () => { + it('Scaling error should return failed message IDs so retry can be triggered.', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toBeInstanceOf(ScaleError); + await expect(scaleUpModule.scaleUp(TEST_DATA)).resolves.toEqual(['foobar']); + }); + }); + + describe('Batch processing', () => { + const createTestMessages = ( + count: number, + overrides: Partial[] = [], + ): scaleUpModule.ActionRequestMessageSQS[] => { + return Array.from({ length: count }, (_, i) => ({ + ...TEST_DATA_SINGLE, + id: i + 1, + messageId: `message-${i}`, + ...overrides[i], + })); + }; + + beforeEach(() => { + setDefaults(); + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + process.env.RUNNERS_MAXIMUM_COUNT = '10'; + }); + + it('Should handle multiple messages for the same organization', async () => { + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(1); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 3, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, + }), + ); + }); + + it('Should handle multiple messages for different organizations', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'org1' }, + { repositoryOwner: 'org2' }, + { repositoryOwner: 'org1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'org1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'org2', + }), + ); + }); + + it('Should handle multiple messages for different repositories when org-level is disabled', async () => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'false'; + const messages = createTestMessages(3, [ + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + { repositoryOwner: 'owner1', repositoryName: 'repo2' }, + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'owner1/repo1', + }), + 
); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'owner1/repo2', + }), + ); + }); + + it('Should reject messages when maximum runners limit is reached', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '1'; // Set to 1 so with 1 existing, no new ones can be created + mockListRunners.mockImplementation(async () => [ + { + instanceId: 'i-existing', + launchTime: new Date(), + type: 'Org', + owner: TEST_DATA_SINGLE.repositoryOwner, + }, + ]); + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); // No runners should be created + expect(rejectedMessages).toHaveLength(3); // All 3 messages should be rejected + }); + + it('Should handle partial EC2 instance creation failures', async () => { + mockCreateRunner.mockImplementation(async () => ['i-12345']); // Only creates 1 instead of requested 3 + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(rejectedMessages).toHaveLength(2); // 3 requested - 1 created = 2 failed + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should filter out invalid event types for ephemeral runners', async () => { + const messages = createTestMessages(3, [ + { eventType: 'workflow_job' }, + { eventType: 'check_run' }, + { eventType: 'workflow_job' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only workflow_job events processed + }), + ); + expect(rejectedMessages).toContain('message-1'); // check_run event rejected + }); + + it('Should skip invalid repo owner types but not reject them', async () => { + const messages = createTestMessages(3, [ + { repoOwnerType: 'Organization' }, + { repoOwnerType: 'User' }, // Invalid for org-level runners + { repoOwnerType: 'Organization' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only Organization events processed + }), + ); + expect(rejectedMessages).not.toContain('message-1'); // User repo not rejected, just skipped + }); + + it('Should skip messages when jobs are not queued', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation((params) => { + const isQueued = params.job_id === 1 || params.job_id === 3; // Only jobs 1 and 3 are queued + return { + data: { + status: isQueued ? 'queued' : 'completed', + }, + }; + }); + + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only queued jobs processed + }), + ); + }); + + it('Should create separate GitHub clients for different installations', async () => { + // Override the default mock to return different installation IDs + mockOctokit.apps.getOrgInstallation.mockReset(); + mockOctokit.apps.getOrgInstallation.mockImplementation((params) => ({ + data: { + id: params.org === 'org1' ? 
100 : 200, + }, + })); + + const messages = createTestMessages(2, [ + { repositoryOwner: 'org1', installationId: 0 }, + { repositoryOwner: 'org2', installationId: 0 }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(3); // 1 app client, 2 repo installation clients + expect(mockedInstallationAuth).toHaveBeenCalledWith(100, ''); + expect(mockedInstallationAuth).toHaveBeenCalledWith(200, ''); + }); + + it('Should reuse GitHub clients for same installation', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(2); // 1 app client, 1 installation client + expect(mockedInstallationAuth).toHaveBeenCalledTimes(1); + }); + + it('Should return empty array when no valid messages to process', async () => { + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + const messages = createTestMessages(2, [ + { eventType: 'check_run' }, // Invalid for ephemeral + { eventType: 'check_run' }, // Invalid for ephemeral + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should handle unlimited runners configuration', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '-1'; + const messages = createTestMessages(10); + + await scaleUpModule.scaleUp(messages); + + expect(listEC2Runners).not.toHaveBeenCalled(); // No need to check current runners + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 10, // All messages processed + }), + ); }); }); }); @@ -696,17 +1143,12 @@ describe('scaleUp with Github Data Residency', () => { process.env.GHES_URL = 'https://companyname.ghe.com'; }); - it('ignores non-sqs events', async () => { - expect.assertions(1); - await expect(scaleUpModule.scaleUp('aws:s3', TEST_DATA)).rejects.toEqual(Error('Cannot handle non-SQS events!')); - }); - it('checks queued workflows', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalledWith({ - job_id: TEST_DATA.id, - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + job_id: TEST_DATA_SINGLE.id, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); @@ -714,7 +1156,7 @@ describe('scaleUp with Github Data Residency', () => { mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ data: { total_count: 0 }, })); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toBeCalled(); }); @@ -733,18 +1175,18 @@ describe('scaleUp with Github Data Residency', () => { }); it('gets the current org level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); 
expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); @@ -752,35 +1194,35 @@ describe('scaleUp with Github Data Residency', () => { it('does create a runner if maximum is set to -1', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '-1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toHaveBeenCalled(); expect(createRunner).toHaveBeenCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, }); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a runner with correct config', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with labels in a specific group', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with ami id override from ssm parameter', async () => { process.env.AMI_ID_SSM_PARAMETER_NAME = 'my-ami-id-param'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith({ ...expectedRunnerParams, amiIdSsmParameterName: 'my-ami-id-param' }); }); @@ -789,15 +1231,15 @@ describe('scaleUp with Github Data Residency', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toBeInstanceOf(Error); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toBeInstanceOf(Error); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); }); it('Discards event if it is a User repo and org level runners is enabled', async () => { process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; - const USER_REPO_TEST_DATA = { ...TEST_DATA }; - USER_REPO_TEST_DATA.repoOwnerType = 'User'; - await scaleUpModule.scaleUp('aws:sqs', USER_REPO_TEST_DATA); + const USER_REPO_TEST_DATA = structuredClone(TEST_DATA); + USER_REPO_TEST_DATA[0].repoOwnerType = 'User'; + await scaleUpModule.scaleUp(USER_REPO_TEST_DATA); expect(createRunner).not.toHaveBeenCalled(); }); @@ -805,7 +1247,7 @@ describe('scaleUp with Github Data Residency', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 2); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -816,7 +1258,7 @@ describe('scaleUp with Github Data Residency', () => { }); it('Does not create SSM parameter for runner group id if it exists', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); 
expect(mockOctokit.paginate).toHaveBeenCalledTimes(0); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 1); }); @@ -824,9 +1266,9 @@ describe('scaleUp with Github Data Residency', () => { it('create start runner config for ephemeral runners ', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, name: 'unit-test-i-12345', runner_group_id: 1, labels: ['label1', 'label2'], @@ -847,7 +1289,7 @@ describe('scaleUp with Github Data Residency', () => { it('create start runner config for non-ephemeral runners ', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalled(); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -918,7 +1360,7 @@ describe('scaleUp with Github Data Residency', () => { 'i-150', 'i-151', ]; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); const endTime = performance.now(); expect(endTime - startTime).toBeGreaterThan(1000); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 40); @@ -932,67 +1374,295 @@ describe('scaleUp with Github Data Residency', () => { process.env.RUNNER_NAME_PREFIX = 'unit-test'; expectedRunnerParams = { ...EXPECTED_RUNNER_PARAMS }; expectedRunnerParams.runnerType = 'Repo'; - expectedRunnerParams.runnerOwner = `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`; - // `--url https://companyname.ghe.com${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + expectedRunnerParams.runnerOwner = `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`; + // `--url https://companyname.ghe.com${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, // `--token 1234abcd`, // ]; }); it('gets the current repo level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Repo', - runnerOwner: `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + runnerOwner: `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: 
TEST_DATA_SINGLE.repositoryName, }); }); it('uses the default runner max count', async () => { process.env.RUNNERS_MAXIMUM_COUNT = undefined; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('creates a runner with correct config and labels', async () => { process.env.RUNNER_LABELS = 'label1,label2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner and ensure the group argument is ignored', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP_IGNORED'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('Check error is thrown', async () => { const mockCreateRunners = vi.mocked(createRunner); mockCreateRunners.mockRejectedValue(new Error('no retry')); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toThrow('no retry'); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toThrow('no retry'); mockCreateRunners.mockReset(); }); }); + + describe('Batch processing', () => { + const createTestMessages = ( + count: number, + overrides: Partial[] = [], + ): scaleUpModule.ActionRequestMessageSQS[] => { + return Array.from({ length: count }, (_, i) => ({ + ...TEST_DATA_SINGLE, + id: i + 1, + messageId: `message-${i}`, + ...overrides[i], + })); + }; + + beforeEach(() => { + setDefaults(); + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + process.env.RUNNERS_MAXIMUM_COUNT = '10'; + }); + + it('Should handle multiple messages for the same organization', async () => { + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(1); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 3, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, + }), + ); + }); + + it('Should handle multiple messages for different organizations', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'org1' }, + { repositoryOwner: 'org2' }, + { repositoryOwner: 'org1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'org1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'org2', + }), + ); + }); + + it('Should handle multiple messages for different repositories when org-level is disabled', async () => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'false'; + const messages = createTestMessages(3, [ + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + { repositoryOwner: 'owner1', repositoryName: 'repo2' }, + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'owner1/repo1', + }), + ); + 
expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'owner1/repo2', + }), + ); + }); + + it('Should reject messages when maximum runners limit is reached', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '2'; + mockListRunners.mockImplementation(async () => [ + { + instanceId: 'i-existing', + launchTime: new Date(), + type: 'Org', + owner: TEST_DATA_SINGLE.repositoryOwner, + }, + ]); + + const messages = createTestMessages(5); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, // 2 max - 1 existing = 1 new + }), + ); + expect(rejectedMessages).toHaveLength(4); // 5 requested - 1 created = 4 rejected + }); + + it('Should handle partial EC2 instance creation failures', async () => { + mockCreateRunner.mockImplementation(async () => ['i-12345']); // Only creates 1 instead of requested 3 + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(rejectedMessages).toHaveLength(2); // 3 requested - 1 created = 2 failed + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should filter out invalid event types for ephemeral runners', async () => { + const messages = createTestMessages(3, [ + { eventType: 'workflow_job' }, + { eventType: 'check_run' }, + { eventType: 'workflow_job' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only workflow_job events processed + }), + ); + expect(rejectedMessages).toContain('message-1'); // check_run event rejected + }); + + it('Should skip invalid repo owner types but not reject them', async () => { + const messages = createTestMessages(3, [ + { repoOwnerType: 'Organization' }, + { repoOwnerType: 'User' }, // Invalid for org-level runners + { repoOwnerType: 'Organization' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only Organization events processed + }), + ); + expect(rejectedMessages).not.toContain('message-1'); // User repo not rejected, just skipped + }); + + it('Should skip messages when jobs are not queued', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation((params) => { + const isQueued = params.job_id === 1 || params.job_id === 3; // Only jobs 1 and 3 are queued + return { + data: { + status: isQueued ? 'queued' : 'completed', + }, + }; + }); + + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only queued jobs processed + }), + ); + }); + + it('Should create separate GitHub clients for different installations', async () => { + mockOctokit.apps.getOrgInstallation.mockImplementation((params) => ({ + data: { + id: params.org === 'org1' ? 
100 : 200, + }, + })); + + const messages = createTestMessages(2, [ + { repositoryOwner: 'org1', installationId: 0 }, + { repositoryOwner: 'org2', installationId: 0 }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(3); // 1 app client, 2 repo installation clients + expect(mockedInstallationAuth).toHaveBeenCalledWith(100, ''); + expect(mockedInstallationAuth).toHaveBeenCalledWith(200, ''); + }); + + it('Should reuse GitHub clients for same installation', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(2); // 1 app client, 1 installation client + expect(mockedInstallationAuth).toHaveBeenCalledTimes(1); + }); + + it('Should return empty array when no valid messages to process', async () => { + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + const messages = createTestMessages(2, [ + { eventType: 'check_run' }, // Invalid for ephemeral + { eventType: 'check_run' }, // Invalid for ephemeral + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should handle unlimited runners configuration', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '-1'; + const messages = createTestMessages(10); + + await scaleUpModule.scaleUp(messages); + + expect(listEC2Runners).not.toHaveBeenCalled(); // No need to check current runners + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 10, // All messages processed + }), + ); + }); + }); }); function defaultOctokitMockImpl() { @@ -1032,12 +1702,12 @@ function defaultOctokitMockImpl() { }; const mockInstallationIdReturnValueOrgs = { data: { - id: TEST_DATA.installationId, + id: TEST_DATA_SINGLE.installationId, }, }; const mockInstallationIdReturnValueRepos = { data: { - id: TEST_DATA.installationId, + id: TEST_DATA_SINGLE.installationId, }, }; diff --git a/lambdas/functions/control-plane/src/scale-runners/scale-up.ts b/lambdas/functions/control-plane/src/scale-runners/scale-up.ts index 638edd3232..35df7ea5d7 100644 --- a/lambdas/functions/control-plane/src/scale-runners/scale-up.ts +++ b/lambdas/functions/control-plane/src/scale-runners/scale-up.ts @@ -6,8 +6,6 @@ import yn from 'yn'; import { createGithubAppAuth, createGithubInstallationAuth, createOctokitClient } from '../github/auth'; import { createRunner, listEC2Runners, tag } from './../aws/runners'; import { RunnerInputParameters } from './../aws/runners.d'; -import ScaleError from './ScaleError'; -import { publishRetryMessage } from './job-retry'; import { metricGitHubAppRateLimit } from '../github/rate-limit'; const logger = createChildLogger('scale-up'); @@ -33,6 +31,10 @@ export interface ActionRequestMessage { retryCounter?: number; } +export interface ActionRequestMessageSQS extends ActionRequestMessage { + messageId: string; +} + export interface ActionRequestMessageRetry extends ActionRequestMessage { retryCounter: number; } @@ -114,7 +116,7 @@ function removeTokenFromLogging(config: string[]): string[] { } export async function getInstallationId( - ghesApiUrl: string, + githubAppClient: Octokit, enableOrgLevel: boolean, payload: ActionRequestMessage, ): Promise { @@ -122,16 +124,14 @@ export async function getInstallationId( return 
payload.installationId; } - const ghAuth = await createGithubAppAuth(undefined, ghesApiUrl); - const githubClient = await createOctokitClient(ghAuth.token, ghesApiUrl); return enableOrgLevel ? ( - await githubClient.apps.getOrgInstallation({ + await githubAppClient.apps.getOrgInstallation({ org: payload.repositoryOwner, }) ).data.id : ( - await githubClient.apps.getRepoInstallation({ + await githubAppClient.apps.getRepoInstallation({ owner: payload.repositoryOwner, repo: payload.repositoryName, }) @@ -211,23 +211,27 @@ async function getRunnerGroupByName(ghClient: Octokit, githubRunnerConfig: Creat export async function createRunners( githubRunnerConfig: CreateGitHubRunnerConfig, ec2RunnerConfig: CreateEC2RunnerConfig, + numberOfRunners: number, ghClient: Octokit, -): Promise { +): Promise { const instances = await createRunner({ runnerType: githubRunnerConfig.runnerType, runnerOwner: githubRunnerConfig.runnerOwner, - numberOfRunners: 1, + numberOfRunners, ...ec2RunnerConfig, }); if (instances.length !== 0) { await createStartRunnerConfig(githubRunnerConfig, instances, ghClient); } + + return instances; } -export async function scaleUp(eventSource: string, payload: ActionRequestMessage): Promise { - logger.info(`Received ${payload.eventType} from ${payload.repositoryOwner}/${payload.repositoryName}`); +export async function scaleUp(payloads: ActionRequestMessageSQS[]): Promise { + logger.info('Received scale up requests', { + n_requests: payloads.length, + }); - if (eventSource !== 'aws:sqs') throw Error('Cannot handle non-SQS events!'); const enableOrgLevel = yn(process.env.ENABLE_ORGANIZATION_RUNNERS, { default: true }); const maximumRunners = parseInt(process.env.RUNNERS_MAXIMUM_COUNT || '3'); const runnerLabels = process.env.RUNNER_LABELS || ''; @@ -252,103 +256,202 @@ export async function scaleUp(eventSource: string, payload: ActionRequestMessage ? (JSON.parse(process.env.ENABLE_ON_DEMAND_FAILOVER_FOR_ERRORS) as [string]) : []; - if (ephemeralEnabled && payload.eventType !== 'workflow_job') { - logger.warn(`${payload.eventType} event is not supported in combination with ephemeral runners.`); - throw Error( - `The event type ${payload.eventType} is not supported in combination with ephemeral runners.` + - `Please ensure you have enabled workflow_job events.`, - ); - } + const { ghesApiUrl, ghesBaseUrl } = getGitHubEnterpriseApiUrl(); - if (!isValidRepoOwnerTypeIfOrgLevelEnabled(payload, enableOrgLevel)) { - logger.warn( - `Repository ${payload.repositoryOwner}/${payload.repositoryName} does not belong to a GitHub` + - `organization and organization runners are enabled. This is not supported. Not scaling up for this event.` + - `Not throwing error to prevent re-queueing and just ignoring the event.`, - ); - return; + const ghAuth = await createGithubAppAuth(undefined, ghesApiUrl); + const githubAppClient = await createOctokitClient(ghAuth.token, ghesApiUrl); + + // A map of either owner or owner/repo name to Octokit client, so we use a + // single client per installation (set of messages), depending on how the app + // is installed. This is for a couple of reasons: + // - Sharing clients opens up the possibility of caching API calls. + // - Fetching a client for an installation actually requires a couple of API + // calls itself, which would get expensive if done for every message in a + // batch. 
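+  // For example, a batch of ten messages spread across two organizations needs
+  // only two installation clients (plus the single app client used to resolve
+  // installation IDs), rather than a fresh app-auth and installation-auth
+  // round trip for every message.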
+ type MessagesWithClient = { + messages: ActionRequestMessageSQS[]; + githubInstallationClient: Octokit; + }; + + const validMessages = new Map(); + const invalidMessages: string[] = []; + for (const payload of payloads) { + const { eventType, messageId, repositoryName, repositoryOwner } = payload; + if (ephemeralEnabled && eventType !== 'workflow_job') { + logger.warn( + 'Event is not supported in combination with ephemeral runners. Please ensure you have enabled workflow_job events.', + { eventType, messageId }, + ); + + invalidMessages.push(messageId); + + continue; + } + + if (!isValidRepoOwnerTypeIfOrgLevelEnabled(payload, enableOrgLevel)) { + logger.warn( + `Repository does not belong to a GitHub organization and organization runners are enabled. This is not supported. Not scaling up for this event. Not throwing error to prevent re-queueing and just ignoring the event.`, + { + repository: `${repositoryOwner}/${repositoryName}`, + messageId, + }, + ); + + continue; + } + + const key = enableOrgLevel ? payload.repositoryOwner : `${payload.repositoryOwner}/${payload.repositoryName}`; + + let entry = validMessages.get(key); + + // If we've not seen this owner/repo before, we'll need to create a GitHub + // client for it. + if (entry === undefined) { + const installationId = await getInstallationId(githubAppClient, enableOrgLevel, payload); + const ghAuth = await createGithubInstallationAuth(installationId, ghesApiUrl); + const githubInstallationClient = await createOctokitClient(ghAuth.token, ghesApiUrl); + + entry = { + messages: [], + githubInstallationClient, + }; + + validMessages.set(key, entry); + } + + entry.messages.push(payload); } - const ephemeral = ephemeralEnabled && payload.eventType === 'workflow_job'; const runnerType = enableOrgLevel ? 'Org' : 'Repo'; - const runnerOwner = enableOrgLevel ? payload.repositoryOwner : `${payload.repositoryOwner}/${payload.repositoryName}`; addPersistentContextToChildLogger({ runner: { + ephemeral: ephemeralEnabled, type: runnerType, - owner: runnerOwner, namePrefix: runnerNamePrefix, - }, - github: { - event: payload.eventType, - workflow_job_id: payload.id.toString(), + n_events: Array.from(validMessages.values()).reduce((acc, group) => acc + group.messages.length, 0), }, }); - logger.info(`Received event`); + logger.info(`Received events`); - const { ghesApiUrl, ghesBaseUrl } = getGitHubEnterpriseApiUrl(); + for (const [group, { githubInstallationClient, messages }] of validMessages.entries()) { + // Work out how much we want to scale up by. 
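+    // Each message in this group that passes the queued-job check below asks
+    // for exactly one runner.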
+ let scaleUp = 0; - const installationId = await getInstallationId(ghesApiUrl, enableOrgLevel, payload); - const ghAuth = await createGithubInstallationAuth(installationId, ghesApiUrl); - const githubInstallationClient = await createOctokitClient(ghAuth.token, ghesApiUrl); + for (const message of messages) { + const messageLogger = logger.createChild({ + persistentKeys: { + eventType: message.eventType, + group, + messageId: message.messageId, + repository: `${message.repositoryOwner}/${message.repositoryName}`, + }, + }); - if (!enableJobQueuedCheck || (await isJobQueued(githubInstallationClient, payload))) { - let scaleUp = true; - if (maximumRunners !== -1) { - const currentRunners = await listEC2Runners({ - environment, - runnerType, - runnerOwner, + if (enableJobQueuedCheck && !(await isJobQueued(githubInstallationClient, message))) { + messageLogger.info('No runner will be created, job is not queued.'); + + continue; + } + + scaleUp++; + } + + if (scaleUp === 0) { + logger.info('No runners will be created for this group, no valid messages found.'); + + continue; + } + + // Don't call the EC2 API if we can create an unlimited number of runners. + const currentRunners = + maximumRunners === -1 ? 0 : (await listEC2Runners({ environment, runnerType, runnerOwner: group })).length; + + logger.info('Current runners', { + currentRunners, + maximumRunners, + }); + + // Calculate how many runners we want to create. + const newRunners = + maximumRunners === -1 + ? // If we don't have an upper limit, scale up by the number of new jobs. + scaleUp + : // Otherwise, we do have a limit, so work out if `scaleUp` would exceed it. + Math.min(scaleUp, maximumRunners - currentRunners); + + const missingInstanceCount = Math.max(0, scaleUp - newRunners); + + if (missingInstanceCount > 0) { + logger.info('Not all runners will be created for this group, maximum number of runners reached.', { + desiredNewRunners: scaleUp, }); - logger.info(`Current runners: ${currentRunners.length} of ${maximumRunners}`); - scaleUp = currentRunners.length < maximumRunners; + + if (ephemeralEnabled) { + // This removes `missingInstanceCount` items from the start of the array + // so that, if we retry more messages later, we pick fresh ones. + invalidMessages.push(...messages.splice(0, missingInstanceCount).map(({ messageId }) => messageId)); + } + + // No runners will be created, so skip calling the EC2 API. 
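+      // (For example: maximumRunners = 10 with 10 runners already up and
+      // scaleUp = 3 gives newRunners = min(3, 10 - 10) = 0, so
+      // missingInstanceCount === scaleUp; all three message IDs were rejected
+      // above (when ephemeral runners are enabled) and the group is skipped.)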
+ if (missingInstanceCount === scaleUp) { + continue; + } } - if (scaleUp) { - logger.info(`Attempting to launch a new runner`); + logger.info(`Attempting to launch new runners`, { + newRunners, + }); - await createRunners( - { - ephemeral, - enableJitConfig, - ghesBaseUrl, - runnerLabels, - runnerGroup, - runnerNamePrefix, - runnerOwner, - runnerType, - disableAutoUpdate, - ssmTokenPath, - ssmConfigPath, - }, - { - ec2instanceCriteria: { - instanceTypes, - targetCapacityType: instanceTargetCapacityType, - maxSpotPrice: instanceMaxSpotPrice, - instanceAllocationStrategy: instanceAllocationStrategy, - }, - environment, - launchTemplateName, - subnets, - amiIdSsmParameterName, - tracingEnabled, - onDemandFailoverOnError, + const instances = await createRunners( + { + ephemeral: ephemeralEnabled, + enableJitConfig, + ghesBaseUrl, + runnerLabels, + runnerGroup, + runnerNamePrefix, + runnerOwner: group, + runnerType, + disableAutoUpdate, + ssmTokenPath, + ssmConfigPath, + }, + { + ec2instanceCriteria: { + instanceTypes, + targetCapacityType: instanceTargetCapacityType, + maxSpotPrice: instanceMaxSpotPrice, + instanceAllocationStrategy: instanceAllocationStrategy, }, - githubInstallationClient, - ); + environment, + launchTemplateName, + subnets, + amiIdSsmParameterName, + tracingEnabled, + onDemandFailoverOnError, + }, + newRunners, + githubInstallationClient, + ); - await publishRetryMessage(payload); - } else { - logger.info('No runner will be created, maximum number of runners reached.'); - if (ephemeral) { - throw new ScaleError('No runners create: maximum of runners reached.'); - } + // Not all of the runners we wanted were created; reject the same number of + // messages so those requests are retried. + if (instances.length !== newRunners) { + const failedInstanceCount = newRunners - instances.length; + + logger.warn('Some runners failed to be created, rejecting some messages so the requests are retried', { + wanted: newRunners, + got: instances.length, + failedInstanceCount, + }); + + invalidMessages.push(...messages.slice(0, failedInstanceCount).map(({ messageId }) => messageId)); } - } else { - logger.info('No runner will be created, job is not queued.'); } + + return invalidMessages; } export function getGitHubEnterpriseApiUrl() { diff --git a/lambdas/libs/aws-powertools-util/src/logger/index.ts b/lambdas/libs/aws-powertools-util/src/logger/index.ts index 195b552a74..2bad191a83 100644 --- a/lambdas/libs/aws-powertools-util/src/logger/index.ts +++ b/lambdas/libs/aws-powertools-util/src/logger/index.ts @@ -9,7 +9,7 @@ const defaultValues = { }; function setContext(context: Context, module?: string) { - logger.addPersistentLogAttributes({ + logger.appendPersistentKeys({ 'aws-request-id': context.awsRequestId, 'function-name': context.functionName, module: module, @@ -17,7 +17,7 @@ function setContext(context: Context, module?: string) { // Add the context to all child loggers childLoggers.forEach((childLogger) => { - childLogger.addPersistentLogAttributes({ + childLogger.appendPersistentKeys({ 'aws-request-id': context.awsRequestId, 'function-name': context.functionName, }); @@ -25,14 +25,14 @@ function setContext(context: Context, module?: string) { } const logger = new Logger({ - persistentLogAttributes: { + persistentKeys: { ...defaultValues, }, }); function createChildLogger(module: string): Logger { const childLogger = logger.createChild({ - persistentLogAttributes: { + persistentKeys: { module: module, }, }); @@ -47,7 +47,7 @@ type LogAttributes = { function 
addPersistentContextToChildLogger(attributes: LogAttributes) { childLoggers.forEach((childLogger) => { - childLogger.addPersistentLogAttributes(attributes); + childLogger.appendPersistentKeys(attributes); }); } diff --git a/main.tf b/main.tf index 9c72614808..27615a55ef 100644 --- a/main.tf +++ b/main.tf @@ -210,28 +210,30 @@ module "runners" { credit_specification = var.runner_credit_specification cpu_options = var.runner_cpu_options - enable_runner_binaries_syncer = var.enable_runner_binaries_syncer - lambda_s3_bucket = var.lambda_s3_bucket - runners_lambda_s3_key = var.runners_lambda_s3_key - runners_lambda_s3_object_version = var.runners_lambda_s3_object_version - lambda_runtime = var.lambda_runtime - lambda_architecture = var.lambda_architecture - lambda_zip = var.runners_lambda_zip - lambda_scale_up_memory_size = var.runners_scale_up_lambda_memory_size - lambda_scale_down_memory_size = var.runners_scale_down_lambda_memory_size - lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout - lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout - lambda_subnet_ids = var.lambda_subnet_ids - lambda_security_group_ids = var.lambda_security_group_ids - lambda_tags = var.lambda_tags - tracing_config = var.tracing_config - logging_retention_in_days = var.logging_retention_in_days - logging_kms_key_id = var.logging_kms_key_id - enable_cloudwatch_agent = var.enable_cloudwatch_agent - cloudwatch_config = var.cloudwatch_config - runner_log_files = var.runner_log_files - runner_group_name = var.runner_group_name - runner_name_prefix = var.runner_name_prefix + enable_runner_binaries_syncer = var.enable_runner_binaries_syncer + lambda_s3_bucket = var.lambda_s3_bucket + runners_lambda_s3_key = var.runners_lambda_s3_key + runners_lambda_s3_object_version = var.runners_lambda_s3_object_version + lambda_runtime = var.lambda_runtime + lambda_architecture = var.lambda_architecture + lambda_event_source_mapping_batch_size = var.lambda_event_source_mapping_batch_size + lambda_event_source_mapping_maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds + lambda_zip = var.runners_lambda_zip + lambda_scale_up_memory_size = var.runners_scale_up_lambda_memory_size + lambda_scale_down_memory_size = var.runners_scale_down_lambda_memory_size + lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout + lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout + lambda_subnet_ids = var.lambda_subnet_ids + lambda_security_group_ids = var.lambda_security_group_ids + lambda_tags = var.lambda_tags + tracing_config = var.tracing_config + logging_retention_in_days = var.logging_retention_in_days + logging_kms_key_id = var.logging_kms_key_id + enable_cloudwatch_agent = var.enable_cloudwatch_agent + cloudwatch_config = var.cloudwatch_config + runner_log_files = var.runner_log_files + runner_group_name = var.runner_group_name + runner_name_prefix = var.runner_name_prefix scale_up_reserved_concurrent_executions = var.scale_up_reserved_concurrent_executions diff --git a/modules/multi-runner/README.md b/modules/multi-runner/README.md index 759cb61832..d05560e0f1 100644 --- a/modules/multi-runner/README.md +++ b/modules/multi-runner/README.md @@ -137,6 +137,8 @@ module "multi-runner" { | [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no | | [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. 
| `string` | `null` | no | | [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no | +| [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no | +| [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no | | [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. |
list(object({
type = string
identifiers = list(string)
}))
| `[]` | no | | [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no | | [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no | diff --git a/modules/multi-runner/runners.tf b/modules/multi-runner/runners.tf index 811ab36260..d58e61f6ac 100644 --- a/modules/multi-runner/runners.tf +++ b/modules/multi-runner/runners.tf @@ -58,28 +58,30 @@ module "runners" { credit_specification = each.value.runner_config.credit_specification cpu_options = each.value.runner_config.cpu_options - enable_runner_binaries_syncer = each.value.runner_config.enable_runner_binaries_syncer - lambda_s3_bucket = var.lambda_s3_bucket - runners_lambda_s3_key = var.runners_lambda_s3_key - runners_lambda_s3_object_version = var.runners_lambda_s3_object_version - lambda_runtime = var.lambda_runtime - lambda_architecture = var.lambda_architecture - lambda_zip = var.runners_lambda_zip - lambda_scale_up_memory_size = var.scale_up_lambda_memory_size - lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout - lambda_scale_down_memory_size = var.scale_down_lambda_memory_size - lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout - lambda_subnet_ids = var.lambda_subnet_ids - lambda_security_group_ids = var.lambda_security_group_ids - lambda_tags = var.lambda_tags - tracing_config = var.tracing_config - logging_retention_in_days = var.logging_retention_in_days - logging_kms_key_id = var.logging_kms_key_id - enable_cloudwatch_agent = each.value.runner_config.enable_cloudwatch_agent - cloudwatch_config = try(coalesce(each.value.runner_config.cloudwatch_config, var.cloudwatch_config), null) - runner_log_files = each.value.runner_config.runner_log_files - runner_group_name = each.value.runner_config.runner_group_name - runner_name_prefix = each.value.runner_config.runner_name_prefix + enable_runner_binaries_syncer = each.value.runner_config.enable_runner_binaries_syncer + lambda_s3_bucket = var.lambda_s3_bucket + runners_lambda_s3_key = var.runners_lambda_s3_key + runners_lambda_s3_object_version = var.runners_lambda_s3_object_version + lambda_runtime = var.lambda_runtime + lambda_architecture = var.lambda_architecture + lambda_zip = var.runners_lambda_zip + lambda_scale_up_memory_size = var.scale_up_lambda_memory_size + lambda_event_source_mapping_batch_size = var.lambda_event_source_mapping_batch_size + lambda_event_source_mapping_maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds + lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout + lambda_scale_down_memory_size = var.scale_down_lambda_memory_size + lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout + lambda_subnet_ids = var.lambda_subnet_ids + lambda_security_group_ids = var.lambda_security_group_ids + lambda_tags = var.lambda_tags + tracing_config = var.tracing_config + logging_retention_in_days = var.logging_retention_in_days + logging_kms_key_id = var.logging_kms_key_id + enable_cloudwatch_agent = each.value.runner_config.enable_cloudwatch_agent + cloudwatch_config = try(coalesce(each.value.runner_config.cloudwatch_config, var.cloudwatch_config), null) + runner_log_files = each.value.runner_config.runner_log_files + runner_group_name = each.value.runner_config.runner_group_name + runner_name_prefix = each.value.runner_config.runner_name_prefix scale_up_reserved_concurrent_executions = 
each.value.runner_config.scale_up_reserved_concurrent_executions diff --git a/modules/multi-runner/variables.tf b/modules/multi-runner/variables.tf index edbdb33059..2c7c7c69e5 100644 --- a/modules/multi-runner/variables.tf +++ b/modules/multi-runner/variables.tf @@ -718,3 +718,15 @@ variable "user_agent" { type = string default = "github-aws-runners" } + +variable "lambda_event_source_mapping_batch_size" { + description = "Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used." + type = number + default = 10 +} + +variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" { + description = "Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. Defaults to 0." + type = number + default = 0 +} diff --git a/modules/runners/README.md b/modules/runners/README.md index cf62c2c96a..45807f67bf 100644 --- a/modules/runners/README.md +++ b/modules/runners/README.md @@ -177,6 +177,8 @@ yarn run dist | [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no | | [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. | `string` | `null` | no | | [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no | +| [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no | +| [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no | | [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no | | [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no | | [lambda\_scale\_down\_memory\_size](#input\_lambda\_scale\_down\_memory\_size) | Memory size limit in MB for scale down lambda. | `number` | `512` | no | diff --git a/modules/runners/job-retry.tf b/modules/runners/job-retry.tf index e51c3903d4..130992667f 100644 --- a/modules/runners/job-retry.tf +++ b/modules/runners/job-retry.tf @@ -3,30 +3,32 @@ locals { job_retry_enabled = var.job_retry != null && var.job_retry.enable ? 
true : false job_retry = { - prefix = var.prefix - tags = local.tags - aws_partition = var.aws_partition - architecture = var.lambda_architecture - runtime = var.lambda_runtime - security_group_ids = var.lambda_security_group_ids - subnet_ids = var.lambda_subnet_ids - kms_key_arn = var.kms_key_arn - lambda_tags = var.lambda_tags - log_level = var.log_level - logging_kms_key_id = var.logging_kms_key_id - logging_retention_in_days = var.logging_retention_in_days - metrics = var.metrics - role_path = var.role_path - role_permissions_boundary = var.role_permissions_boundary - s3_bucket = var.lambda_s3_bucket - s3_key = var.runners_lambda_s3_key - s3_object_version = var.runners_lambda_s3_object_version - zip = var.lambda_zip - tracing_config = var.tracing_config - github_app_parameters = var.github_app_parameters - enable_organization_runners = var.enable_organization_runners - sqs_build_queue = var.sqs_build_queue - ghes_url = var.ghes_url + prefix = var.prefix + tags = local.tags + aws_partition = var.aws_partition + architecture = var.lambda_architecture + runtime = var.lambda_runtime + security_group_ids = var.lambda_security_group_ids + subnet_ids = var.lambda_subnet_ids + kms_key_arn = var.kms_key_arn + lambda_tags = var.lambda_tags + log_level = var.log_level + logging_kms_key_id = var.logging_kms_key_id + logging_retention_in_days = var.logging_retention_in_days + metrics = var.metrics + role_path = var.role_path + role_permissions_boundary = var.role_permissions_boundary + s3_bucket = var.lambda_s3_bucket + s3_key = var.runners_lambda_s3_key + s3_object_version = var.runners_lambda_s3_object_version + zip = var.lambda_zip + tracing_config = var.tracing_config + github_app_parameters = var.github_app_parameters + enable_organization_runners = var.enable_organization_runners + sqs_build_queue = var.sqs_build_queue + ghes_url = var.ghes_url + lambda_event_source_mapping_batch_size = var.lambda_event_source_mapping_batch_size + lambda_event_source_mapping_maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds } } diff --git a/modules/runners/job-retry/README.md b/modules/runners/job-retry/README.md index 91089a213b..4f4c80921c 100644 --- a/modules/runners/job-retry/README.md +++ b/modules/runners/job-retry/README.md @@ -42,7 +42,7 @@ The module is an inner module and used by the runner module when the opt-in feat | Name | Description | Type | Default | Required | |------|-------------|------|---------|:--------:| -| [config](#input\_config) | Configuration for the spot termination watcher lambda function.

`aws_partition`: Partition for the base arn if not 'aws'
`architecture`: AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions.
`environment_variables`: Environment variables for the lambda.
`enable_organization_runners`: Enable organization runners.
`enable_metric`: Enable metric for the lambda. If `spot_warning` is set to true, the lambda will emit a metric when it detects a spot termination warning.
'ghes\_url': Optional GitHub Enterprise Server URL.
'user\_agent': Optional User-Agent header for GitHub API requests.
'github\_app\_parameters': Parameter Store for GitHub App Parameters.
'kms\_key\_arn': Optional CMK Key ARN instead of using the default AWS managed key.
`lambda_principals`: Add extra principals to the role created for execution of the lambda, e.g. for local testing.
`lambda_tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`log_level`: Logging level for lambda logging. Valid values are 'silly', 'trace', 'debug', 'info', 'warn', 'error', 'fatal'.
`logging_kms_key_id`: Specifies the kms key id to encrypt the logs with
`logging_retention_in_days`: Specifies the number of days you want to retain log events for the lambda log group. Possible values are: 0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, and 3653.
`memory_size`: Memory size linit in MB of the lambda.
`metrics`: Configuration to enable metrics creation by the lambda.
`prefix`: The prefix used for naming resources.
`role_path`: The path that will be added to the role, if not set the environment name will be used.
`role_permissions_boundary`: Permissions boundary that will be added to the created role for the lambda.
`runtime`: AWS Lambda runtime.
`s3_bucket`: S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly.
`s3_key`: S3 key for syncer lambda function. Required if using S3 bucket to specify lambdas.
`s3_object_version`: S3 object version for syncer lambda function. Useful if S3 versioning is enabled on source bucket.
`security_group_ids`: List of security group IDs associated with the Lambda function.
'sqs\_build\_queue': SQS queue for build events to re-publish job request.
`subnet_ids`: List of subnets in which the action runners will be launched, the subnets needs to be subnets in the `vpc_id`.
`tag_filters`: Map of tags that will be used to filter the resources to be tracked. Only for which all tags are present and starting with the same value as the value in the map will be tracked.
`tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`timeout`: Time out of the lambda in seconds.
`tracing_config`: Configuration for lambda tracing.
`zip`: File location of the lambda zip file. |
object({
aws_partition = optional(string, null)
architecture = optional(string, null)
enable_organization_runners = bool
environment_variables = optional(map(string), {})
ghes_url = optional(string, null)
user_agent = optional(string, null)
github_app_parameters = object({
key_base64 = map(string)
id = map(string)
})
kms_key_arn = optional(string, null)
lambda_tags = optional(map(string), {})
log_level = optional(string, null)
logging_kms_key_id = optional(string, null)
logging_retention_in_days = optional(number, null)
memory_size = optional(number, null)
metrics = optional(object({
enable = optional(bool, false)
namespace = optional(string, null)
metric = optional(object({
enable_github_app_rate_limit = optional(bool, true)
enable_job_retry = optional(bool, true)
}), {})
}), {})
prefix = optional(string, null)
principals = optional(list(object({
type = string
identifiers = list(string)
})), [])
queue_encryption = optional(object({
kms_data_key_reuse_period_seconds = optional(number, null)
kms_master_key_id = optional(string, null)
sqs_managed_sse_enabled = optional(bool, true)
}), {})
role_path = optional(string, null)
role_permissions_boundary = optional(string, null)
runtime = optional(string, null)
security_group_ids = optional(list(string), [])
subnet_ids = optional(list(string), [])
s3_bucket = optional(string, null)
s3_key = optional(string, null)
s3_object_version = optional(string, null)
sqs_build_queue = object({
url = string
arn = string
})
tags = optional(map(string), {})
timeout = optional(number, 30)
tracing_config = optional(object({
mode = optional(string, null)
capture_http_requests = optional(bool, false)
capture_error = optional(bool, false)
}), {})
zip = optional(string, null)
})
| n/a | yes | +| [config](#input\_config) | Configuration for the job retry lambda function.

`aws_partition`: Partition for the base arn if not 'aws'
`architecture`: AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions.
`environment_variables`: Environment variables for the lambda.
`enable_organization_runners`: Enable organization runners.
`enable_metric`: Enable metric for the lambda. If `spot_warning` is set to true, the lambda will emit a metric when it detects a spot termination warning.
'ghes\_url': Optional GitHub Enterprise Server URL.
'user\_agent': Optional User-Agent header for GitHub API requests.
'github\_app\_parameters': Parameter Store for GitHub App Parameters.
'kms\_key\_arn': Optional CMK Key ARN instead of using the default AWS managed key.
`lambda_event_source_mapping_batch_size`: Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default will be used.
`lambda_event_source_mapping_maximum_batching_window_in_seconds`: Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10.
`lambda_principals`: Add extra principals to the role created for execution of the lambda, e.g. for local testing.
`lambda_tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`log_level`: Logging level for lambda logging. Valid values are 'silly', 'trace', 'debug', 'info', 'warn', 'error', 'fatal'.
`logging_kms_key_id`: Specifies the KMS key ID to encrypt the logs with.
`logging_retention_in_days`: Specifies the number of days you want to retain log events for the lambda log group. Possible values are: 0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, and 3653.
`memory_size`: Memory size limit in MB of the lambda.
`metrics`: Configuration to enable metrics creation by the lambda.
`prefix`: The prefix used for naming resources.
`role_path`: The path that will be added to the role, if not set the environment name will be used.
`role_permissions_boundary`: Permissions boundary that will be added to the created role for the lambda.
`runtime`: AWS Lambda runtime.
`s3_bucket`: S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly.
`s3_key`: S3 key for the lambda function. Required if using an S3 bucket to specify lambdas.
`s3_object_version`: S3 object version for the lambda function. Useful if S3 versioning is enabled on the source bucket.
`security_group_ids`: List of security group IDs associated with the Lambda function.
'sqs\_build\_queue': SQS queue for build events to re-publish job request.
`subnet_ids`: List of subnets in which the action runners will be launched; the subnets need to be in the `vpc_id`.
`tag_filters`: Map of tags that will be used to filter the resources to be tracked. Only resources for which all tags are present and whose values start with the corresponding value in the map will be tracked.
`tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`timeout`: Timeout of the lambda in seconds.
`tracing_config`: Configuration for lambda tracing.
`zip`: File location of the lambda zip file. |
object({
aws_partition = optional(string, null)
architecture = optional(string, null)
enable_organization_runners = bool
environment_variables = optional(map(string), {})
ghes_url = optional(string, null)
user_agent = optional(string, null)
github_app_parameters = object({
key_base64 = map(string)
id = map(string)
})
kms_key_arn = optional(string, null)
lambda_event_source_mapping_batch_size = optional(number, 10)
lambda_event_source_mapping_maximum_batching_window_in_seconds = optional(number, 0)
lambda_tags = optional(map(string), {})
log_level = optional(string, null)
logging_kms_key_id = optional(string, null)
logging_retention_in_days = optional(number, null)
memory_size = optional(number, null)
metrics = optional(object({
enable = optional(bool, false)
namespace = optional(string, null)
metric = optional(object({
enable_github_app_rate_limit = optional(bool, true)
enable_job_retry = optional(bool, true)
}), {})
}), {})
prefix = optional(string, null)
principals = optional(list(object({
type = string
identifiers = list(string)
})), [])
queue_encryption = optional(object({
kms_data_key_reuse_period_seconds = optional(number, null)
kms_master_key_id = optional(string, null)
sqs_managed_sse_enabled = optional(bool, true)
}), {})
role_path = optional(string, null)
role_permissions_boundary = optional(string, null)
runtime = optional(string, null)
security_group_ids = optional(list(string), [])
subnet_ids = optional(list(string), [])
s3_bucket = optional(string, null)
s3_key = optional(string, null)
s3_object_version = optional(string, null)
sqs_build_queue = object({
url = string
arn = string
})
tags = optional(map(string), {})
timeout = optional(number, 30)
tracing_config = optional(object({
mode = optional(string, null)
capture_http_requests = optional(bool, false)
capture_error = optional(bool, false)
}), {})
zip = optional(string, null)
})
| n/a | yes |

## Outputs

diff --git a/modules/runners/job-retry/main.tf b/modules/runners/job-retry/main.tf
index 9561c7db71..612c515f8c 100644
--- a/modules/runners/job-retry/main.tf
+++ b/modules/runners/job-retry/main.tf
@@ -44,9 +44,10 @@ module "job_retry" {
 }
 
 resource "aws_lambda_event_source_mapping" "job_retry" {
-  event_source_arn = aws_sqs_queue.job_retry_check_queue.arn
-  function_name    = module.job_retry.lambda.function.arn
-  batch_size       = 1
+  event_source_arn                   = aws_sqs_queue.job_retry_check_queue.arn
+  function_name                      = module.job_retry.lambda.function.arn
+  batch_size                         = var.config.lambda_event_source_mapping_batch_size
+  maximum_batching_window_in_seconds = var.config.lambda_event_source_mapping_maximum_batching_window_in_seconds
 }
 
 resource "aws_lambda_permission" "job_retry" {
diff --git a/modules/runners/job-retry/variables.tf b/modules/runners/job-retry/variables.tf
index 4a8fe19fbf..f40bec1ba7 100644
--- a/modules/runners/job-retry/variables.tf
+++ b/modules/runners/job-retry/variables.tf
@@ -11,6 +11,8 @@ variable "config" {
     'user_agent': Optional User-Agent header for GitHub API requests.
     'github_app_parameters': Parameter Store for GitHub App Parameters.
     'kms_key_arn': Optional CMK Key ARN instead of using the default AWS managed key.
+    `lambda_event_source_mapping_batch_size`: Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default will be used.
+    `lambda_event_source_mapping_maximum_batching_window_in_seconds`: Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10.
     `lambda_principals`: Add extra principals to the role created for execution of the lambda, e.g. for local testing.
     `lambda_tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
     `log_level`: Logging level for lambda logging. Valid values are 'silly', 'trace', 'debug', 'info', 'warn', 'error', 'fatal'.
@@ -45,12 +47,14 @@ variable "config" {
       key_base64 = map(string)
       id         = map(string)
     })
-    kms_key_arn               = optional(string, null)
-    lambda_tags               = optional(map(string), {})
-    log_level                 = optional(string, null)
-    logging_kms_key_id        = optional(string, null)
-    logging_retention_in_days = optional(number, null)
-    memory_size               = optional(number, null)
+    kms_key_arn                                                     = optional(string, null)
+    lambda_event_source_mapping_batch_size                          = optional(number, 10)
+    lambda_event_source_mapping_maximum_batching_window_in_seconds  = optional(number, 0)
+    lambda_tags                                                     = optional(map(string), {})
+    log_level                                                       = optional(string, null)
+    logging_kms_key_id                                              = optional(string, null)
+    logging_retention_in_days                                       = optional(number, null)
+    memory_size                                                     = optional(number, null)
     metrics = optional(object({
       enable    = optional(bool, false)
       namespace = optional(string, null)
diff --git a/modules/runners/scale-up.tf b/modules/runners/scale-up.tf
index 9230267c07..a9985aec40 100644
--- a/modules/runners/scale-up.tf
+++ b/modules/runners/scale-up.tf
@@ -87,9 +87,11 @@ resource "aws_cloudwatch_log_group" "scale_up" {
 }
 
 resource "aws_lambda_event_source_mapping" "scale_up" {
-  event_source_arn = var.sqs_build_queue.arn
-  function_name    = aws_lambda_function.scale_up.arn
-  batch_size       = 1
+  event_source_arn                   = var.sqs_build_queue.arn
+  function_name                      = aws_lambda_function.scale_up.arn
+  function_response_types            = ["ReportBatchItemFailures"]
+  batch_size                         = var.lambda_event_source_mapping_batch_size
+  maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds
 }
 
 resource "aws_lambda_permission" "scale_runners_lambda" {
diff --git a/modules/runners/variables.tf b/modules/runners/variables.tf
index a78231e7da..c71fa77898 100644
--- a/modules/runners/variables.tf
+++ b/modules/runners/variables.tf
@@ -770,3 +770,23 @@ variable "user_agent" {
   type        = string
   default     = null
 }
+
+variable "lambda_event_source_mapping_batch_size" {
+  description = "Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used."
+  type        = number
+  default     = 10
+  validation {
+    condition     = var.lambda_event_source_mapping_batch_size >= 1 && var.lambda_event_source_mapping_batch_size <= 1000
+    error_message = "The batch size for the lambda event source mapping must be between 1 and 1000."
+  }
+}
+
+variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" {
+  description = "Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. Defaults to 0."
+  type        = number
+  default     = 0
+  validation {
+    condition     = var.lambda_event_source_mapping_maximum_batching_window_in_seconds >= 0 && var.lambda_event_source_mapping_maximum_batching_window_in_seconds <= 300
+    error_message = "Maximum batching window must be between 0 and 300 seconds."
+  }
+}
diff --git a/variables.tf b/variables.tf
index f412d2a486..07deb83cf9 100644
--- a/variables.tf
+++ b/variables.tf
@@ -1016,3 +1016,19 @@ variable "user_agent" {
   type        = string
   default     = "github-aws-runners"
 }
+
+variable "lambda_event_source_mapping_batch_size" {
+  description = "Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used."
+  type        = number
+  default     = 10
+}
+
+variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" {
+  description = "Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. Defaults to 0."
+  type        = number
+  default     = 0
+  validation {
+    condition     = var.lambda_event_source_mapping_maximum_batching_window_in_seconds >= 0 && var.lambda_event_source_mapping_maximum_batching_window_in_seconds <= 300
+    error_message = "Maximum batching window must be between 0 and 300 seconds."
+  }
+}
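As a usage note for the new inputs, here is a minimal, hypothetical sketch of how a consumer of the root module might tune the batching behaviour; the module source and the omitted required inputs are placeholders, not part of this change. Keep the AWS coupling in mind: whenever `batch_size` is greater than 10, `maximum_batching_window_in_seconds` must be greater than 0.

```hcl
# Hypothetical consumer configuration; the source and omitted inputs are placeholders.
module "runners" {
  source = "<org>/github-runner/aws" # placeholder module source

  # ... required inputs (github_app, vpc_id, subnet_ids, ...) omitted ...

  # Pass up to 20 queued workflow-job events to one scale-up invocation.
  lambda_event_source_mapping_batch_size = 20

  # Must be > 0 because batch_size exceeds 10 (AWS requirement).
  lambda_event_source_mapping_maximum_batching_window_in_seconds = 5
}
```

Because the scale-up event source mapping now sets `function_response_types = ["ReportBatchItemFailures"]`, a larger batch does not let one failing record re-drive the whole batch: the lambda can report the failed message IDs and SQS redelivers only those.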