Skip to content

Commit aad0c97

Browse files
committed
feat(sagemaker): add support for serverless inference endpoints
Implements SageMaker Serverless Inference endpoints as requested in issue #23148.

- Add the `ServerlessProductionVariantProps` interface with `maxConcurrency`, `memorySizeInMB`, and `provisionedConcurrency`.
- Extend `EndpointConfig` to support a serverless variant alongside the existing instance variants.
- Add comprehensive validation for the serverless configuration parameters.
- Enforce mutual exclusivity between instance and serverless variants.
- Add CloudFormation template generation for the `ServerlessConfig` properties.
- Include extensive test coverage for validation scenarios and error cases.

Closes #23148
1 parent 8a26869 commit aad0c97

File tree

3 files changed

+431
-5
lines changed

3 files changed

+431
-5
lines changed

packages/@aws-cdk/aws-sagemaker-alpha/lib/endpoint-config.ts

Lines changed: 185 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,31 @@ export interface InstanceProductionVariantProps extends ProductionVariantProps {
7575
readonly instanceType?: InstanceType;
7676
}
7777

78+
/**
79+
* Construction properties for a serverless production variant.
*
* An EndpointConfig accepts at most one serverless variant, and a serverless variant cannot be
* combined with instance production variants.
80+
*/
81+
export interface ServerlessProductionVariantProps extends ProductionVariantProps {
82+
/**
83+
* The maximum number of concurrent invocations your serverless endpoint can process.
84+
*
85+
* Valid range: 1-200. Values outside this range are rejected when the variant is added.
86+
*/
87+
readonly maxConcurrency: number;
88+
/**
89+
* The memory size of your serverless endpoint, in MB. Valid values are in 1 GB increments:
90+
* 1024 MB, 2048 MB, 3072 MB, 4096 MB, 5120 MB, or 6144 MB.
91+
*/
92+
readonly memorySizeInMB: number;
93+
/**
94+
* The number of concurrent invocations that are provisioned and ready to respond to your endpoint.
95+
*
96+
* Valid range: 1-200, must be less than or equal to maxConcurrency.
97+
*
98+
* @default - no provisioned concurrency is configured
99+
*/
100+
readonly provisionedConcurrency?: number;
101+
}
102+
78103
/**
79104
* Represents common attributes of all production variant types (e.g., instance, serverless) once
80105
* associated to an EndpointConfig.
@@ -119,6 +144,26 @@ export interface InstanceProductionVariant extends ProductionVariant {
119144
readonly instanceType: InstanceType;
120145
}
121146

147+
/**
148+
* Represents a serverless production variant that has been associated with an EndpointConfig.
*
* Populated by `addServerlessProductionVariant` once the caller-supplied props have passed
* validation.
149+
*
150+
* @internal
151+
*/
152+
interface ServerlessProductionVariant extends ProductionVariant {
153+
/**
154+
* The maximum number of concurrent invocations your serverless endpoint can process.
155+
*/
156+
readonly maxConcurrency: number;
157+
/**
158+
* The memory size of your serverless endpoint, in MB.
159+
*/
160+
readonly memorySizeInMB: number;
161+
/**
162+
* The number of concurrent invocations that are provisioned and ready to respond to your endpoint.
* Left undefined when the caller did not request provisioned concurrency.
163+
*/
164+
readonly provisionedConcurrency?: number;
165+
}
166+
122167
/**
123168
* Construction properties for a SageMaker EndpointConfig.
124169
*/
@@ -142,9 +187,21 @@ export interface EndpointConfigProps {
142187
* A list of instance production variants. You can always add more variants later by calling
143188
* `EndpointConfig#addInstanceProductionVariant`.
144189
*
190+
* Cannot be specified if `serverlessProductionVariant` is specified.
191+
*
145192
* @default - none
146193
*/
147194
readonly instanceProductionVariants?: InstanceProductionVariantProps[];
195+
196+
/**
197+
* A serverless production variant. Serverless endpoints automatically launch compute resources
198+
* and scale them in and out depending on traffic.
199+
*
200+
* Cannot be specified if `instanceProductionVariants` is specified.
201+
*
202+
* @default - none
203+
*/
204+
readonly serverlessProductionVariant?: ServerlessProductionVariantProps;
148205
}
149206

150207
/**
@@ -207,6 +264,7 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
207264
public readonly endpointConfigName: string;
208265

209266
private readonly instanceProductionVariantsByName: { [key: string]: InstanceProductionVariant } = {};
267+
private serverlessProductionVariant?: ServerlessProductionVariant;
210268

211269
constructor(scope: Construct, id: string, props: EndpointConfigProps = {}) {
212270
super(scope, id, {
@@ -215,13 +273,22 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
215273
// Enhanced CDK Analytics Telemetry
216274
addConstructMetadata(this, props);
217275

276+
// Validate mutual exclusivity
277+
if (props.instanceProductionVariants && props.serverlessProductionVariant) {
278+
throw new Error('Cannot specify both instanceProductionVariants and serverlessProductionVariant. Choose one variant type.');
279+
}
280+
218281
(props.instanceProductionVariants || []).map(p => this.addInstanceProductionVariant(p));
219282

283+
if (props.serverlessProductionVariant) {
284+
this.addServerlessProductionVariant(props.serverlessProductionVariant);
285+
}
286+
220287
// create the endpoint configuration resource
221288
const endpointConfig = new CfnEndpointConfig(this, 'EndpointConfig', {
222289
kmsKeyId: (props.encryptionKey) ? props.encryptionKey.keyRef.keyArn : undefined,
223290
endpointConfigName: this.physicalName,
224-
productionVariants: cdk.Lazy.any({ produce: () => this.renderInstanceProductionVariants() }),
291+
productionVariants: cdk.Lazy.any({ produce: () => this.renderProductionVariants() }),
225292
});
226293
this.endpointConfigName = this.getResourceNameAttribute(endpointConfig.attrEndpointConfigName);
227294
this.endpointConfigArn = this.getResourceArnAttribute(endpointConfig.ref, {
@@ -238,6 +305,9 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
238305
*/
239306
@MethodMetadata()
240307
public addInstanceProductionVariant(props: InstanceProductionVariantProps): void {
308+
if (this.serverlessProductionVariant) {
309+
throw new Error('Cannot add instance production variant when serverless production variant is already configured');
310+
}
241311
if (props.variantName in this.instanceProductionVariantsByName) {
242312
throw new Error(`There is already a Production Variant with name '${props.variantName}'`);
243313
}
@@ -252,6 +322,30 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
252322
};
253323
}
254324

325+
/**
 * Add a serverless production variant to the endpoint configuration.
 *
 * An endpoint configuration holds at most one serverless variant, and serverless and
 * instance variants cannot be mixed.
 *
 * @param props The properties of a serverless production variant to add.
 * @throws Error if instance production variants are already configured, if a serverless
 * variant is already configured, or if `props` fails validation.
 */
@MethodMetadata()
public addServerlessProductionVariant(props: ServerlessProductionVariantProps): void {
  if (Object.keys(this.instanceProductionVariantsByName).length > 0) {
    throw new Error('Cannot add serverless production variant when instance production variants are already configured');
  }
  if (this.serverlessProductionVariant) {
    throw new Error('Cannot add more than one serverless production variant per endpoint configuration');
  }
  this.validateServerlessProductionVariantProps(props);
  this.serverlessProductionVariant = {
    // Use `??` rather than `||`: validation only rejects negative weights, so an explicit
    // weight of 0 is legal and must not be silently replaced by the default of 1.0.
    initialVariantWeight: props.initialVariantWeight ?? 1.0,
    maxConcurrency: props.maxConcurrency,
    memorySizeInMB: props.memorySizeInMB,
    modelName: props.model.modelName,
    provisionedConcurrency: props.provisionedConcurrency,
    variantName: props.variantName,
  };
}
348+
255349
/**
256350
* Get instance production variants associated with endpoint configuration.
257351
*
@@ -276,10 +370,21 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
276370
}
277371

278372
/**
 * Validate the overall set of production variants held by this endpoint configuration.
 *
 * Requires at least one variant (instance or serverless), rejects a mix of instance and
 * serverless variants, and rejects more than 10 instance variants. Called from
 * `renderProductionVariants` before the variants are rendered.
 *
 * NOTE(review): both `add*ProductionVariant` methods already reject mixing variant types, so
 * the mutual-exclusivity check here looks like a defensive re-check - confirm.
 *
 * @throws Error when any of the above constraints is violated.
 */
private validateProductionVariants(): void {
279-
// validate number of production variants
280-
if (this._instanceProductionVariants.length < 1) {
373+
const hasInstanceVariants = this._instanceProductionVariants.length > 0;
374+
const hasServerlessVariant = this.serverlessProductionVariant !== undefined;
375+
376+
// validate at least one production variant
377+
if (!hasInstanceVariants && !hasServerlessVariant) {
281378
throw new Error('Must configure at least 1 production variant');
282-
} else if (this._instanceProductionVariants.length > 10) {
379+
}
380+
381+
// validate mutual exclusivity
382+
if (hasInstanceVariants && hasServerlessVariant) {
383+
throw new Error('Cannot configure both instance and serverless production variants');
384+
}
385+
386+
// validate instance variant limits
387+
if (hasInstanceVariants && this._instanceProductionVariants.length > 10) {
283388
throw new Error('Can\'t have more than 10 production variants');
284389
}
285390
}
@@ -310,11 +415,65 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
310415
}
311416
}
312417

418+
/**
 * Validate the properties of a serverless production variant before it is stored.
 *
 * Every violation found is collected and reported together in a single error. Numeric
 * properties that are unresolved CDK tokens are skipped, because their concrete value is
 * not known at this point and their encoded form would otherwise fail the range checks.
 *
 * @param props The serverless production variant properties to check.
 * @throws Error listing all violations, separated by line breaks, when any check fails.
 */
private validateServerlessProductionVariantProps(props: ServerlessProductionVariantProps): void {
  const errors: string[] = [];
  const { initialVariantWeight, maxConcurrency, memorySizeInMB, provisionedConcurrency, model } = props;
  const isResolved = (value: number): boolean => !cdk.Token.isUnresolved(value);

  // check variant weight is not negative (0 is allowed)
  if (initialVariantWeight !== undefined && isResolved(initialVariantWeight) && initialVariantWeight < 0) {
    errors.push('Cannot have negative variant weight');
  }

  // check maxConcurrency range; the range test runs first so out-of-range values keep
  // their original message, then NaN and fractional values are rejected as non-integers
  if (isResolved(maxConcurrency)) {
    if (maxConcurrency < 1 || maxConcurrency > 200) {
      errors.push('maxConcurrency must be between 1 and 200');
    } else if (!Number.isInteger(maxConcurrency)) {
      errors.push('maxConcurrency must be an integer');
    }
  }

  // check memorySizeInMB valid values (1GB increments from 1024 to 6144)
  const validMemorySizes = [1024, 2048, 3072, 4096, 5120, 6144];
  if (isResolved(memorySizeInMB) && !validMemorySizes.includes(memorySizeInMB)) {
    errors.push(`memorySizeInMB must be one of: ${validMemorySizes.join(', ')} MB`);
  }

  // check provisionedConcurrency range and relationship to maxConcurrency
  if (provisionedConcurrency !== undefined && isResolved(provisionedConcurrency)) {
    if (provisionedConcurrency < 1 || provisionedConcurrency > 200) {
      errors.push('provisionedConcurrency must be between 1 and 200');
    } else if (!Number.isInteger(provisionedConcurrency)) {
      errors.push('provisionedConcurrency must be an integer');
    }
    // only comparable when maxConcurrency is a concrete value as well
    if (isResolved(maxConcurrency) && provisionedConcurrency > maxConcurrency) {
      errors.push('provisionedConcurrency cannot be greater than maxConcurrency');
    }
  }

  // check environment compatibility with model
  if (!sameEnv(model.env.account, this.env.account)) {
    errors.push(`Cannot use model in account ${model.env.account} for endpoint configuration in account ${this.env.account}`);
  } else if (!sameEnv(model.env.region, this.env.region)) {
    errors.push(`Cannot use model in region ${model.env.region} for endpoint configuration in region ${this.env.region}`);
  }

  if (errors.length > 0) {
    throw new Error(`Invalid Serverless Production Variant Props: ${errors.join(EOL)}`);
  }
}
459+
460+
/**
 * Produce the production variant list for the CloudFormation resource.
 *
 * Runs the variant-set validation first, then renders the serverless variant when one is
 * configured and the instance variants otherwise.
 */
private renderProductionVariants(): CfnEndpointConfig.ProductionVariantProperty[] {
  this.validateProductionVariants();

  return this.serverlessProductionVariant
    ? this.renderServerlessProductionVariant()
    : this.renderInstanceProductionVariants();
}
472+
313473
/**
314474
* Render the list of instance production variants.
315475
*/
316476
private renderInstanceProductionVariants(): CfnEndpointConfig.ProductionVariantProperty[] {
317-
this.validateProductionVariants();
318477
return this._instanceProductionVariants.map( v => ({
319478
acceleratorType: v.acceleratorType?.toString(),
320479
initialInstanceCount: v.initialInstanceCount,
@@ -324,4 +483,25 @@ export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
324483
variantName: v.variantName,
325484
}) );
326485
}
486+
487+
/**
 * Render the serverless production variant as a single-element list, or an empty list
 * when no serverless variant has been configured.
 */
private renderServerlessProductionVariant(): CfnEndpointConfig.ProductionVariantProperty[] {
  const serverless = this.serverlessProductionVariant;
  if (!serverless) {
    return [];
  }

  const { initialVariantWeight, modelName, variantName } = serverless;
  const serverlessConfig = {
    maxConcurrency: serverless.maxConcurrency,
    memorySizeInMb: serverless.memorySizeInMB,
    provisionedConcurrency: serverless.provisionedConcurrency,
  };

  return [{ initialVariantWeight, modelName, variantName, serverlessConfig }];
}
327507
}

0 commit comments

Comments
 (0)