Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cmd/controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type flags struct {
metricsRequestHeaderAttributes string
metricsRequestHeaderLabels string // DEPRECATED: use metricsRequestHeaderAttributes instead.
spanRequestHeaderAttributes string
endpointPrefixes string
rootPrefix string
extProcExtraEnvVars string
extProcImagePullSecrets string
Expand Down Expand Up @@ -154,6 +155,11 @@ func parseAndValidateFlags(args []string) (flags, error) {
"",
"Comma-separated key-value pairs for mapping HTTP request headers to otel span attributes. Format: x-session-id:session.id,x-user-id:user.id.",
)
endpointPrefixes := fs.String(
"endpointPrefixes",
"",
"Comma-separated key-value pairs for endpoint prefixes. Format: openai:/,cohere:/cohere,anthropic:/anthropic.",
)
rootPrefix := fs.String(
"rootPrefix",
"/",
Expand Down Expand Up @@ -240,6 +246,13 @@ func parseAndValidateFlags(args []string) (flags, error) {
}
}

// Validate endpoint prefixes if provided.
if *endpointPrefixes != "" {
if _, err := internalapi.ParseEndpointPrefixes(*endpointPrefixes); err != nil {
return flags{}, fmt.Errorf("invalid endpoint prefixes: %w", err)
}
}

// Validate extProc extra env vars if provided.
if *extProcExtraEnvVars != "" {
_, err := controller.ParseExtraEnvVars(*extProcExtraEnvVars)
Expand Down Expand Up @@ -270,6 +283,7 @@ func parseAndValidateFlags(args []string) (flags, error) {
metricsRequestHeaderAttributes: *metricsRequestHeaderAttributes,
metricsRequestHeaderLabels: *metricsRequestHeaderLabels,
spanRequestHeaderAttributes: *spanRequestHeaderAttributes,
endpointPrefixes: *endpointPrefixes,
rootPrefix: *rootPrefix,
extProcExtraEnvVars: *extProcExtraEnvVars,
extProcImagePullSecrets: *extProcImagePullSecrets,
Expand Down Expand Up @@ -362,6 +376,7 @@ func main() {
UDSPath: extProcUDSPath,
MetricsRequestHeaderAttributes: parsedFlags.metricsRequestHeaderAttributes,
TracingRequestHeaderAttributes: parsedFlags.spanRequestHeaderAttributes,
EndpointPrefixes: parsedFlags.endpointPrefixes,
RootPrefix: parsedFlags.rootPrefix,
ExtProcExtraEnvVars: parsedFlags.extProcExtraEnvVars,
ExtProcImagePullSecrets: parsedFlags.extProcImagePullSecrets,
Expand Down
12 changes: 12 additions & 0 deletions cmd/controller/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func Test_parseAndValidateFlags(t *testing.T) {
tc.dash + "port=:8080",
tc.dash + "extProcExtraEnvVars=OTEL_SERVICE_NAME=test;OTEL_TRACES_EXPORTER=console",
tc.dash + "spanRequestHeaderAttributes=x-session-id:session.id",
tc.dash + "endpointPrefixes=openai:/v1,cohere:/cohere/v2,anthropic:/anthropic/v1",
tc.dash + "maxRecvMsgSize=33554432",
tc.dash + "watchNamespaces=default,envoy-ai-gateway-system",
tc.dash + "cacheSyncTimeout=5m",
Expand All @@ -65,6 +66,7 @@ func Test_parseAndValidateFlags(t *testing.T) {
require.Equal(t, ":8080", f.extensionServerPort)
require.Equal(t, "OTEL_SERVICE_NAME=test;OTEL_TRACES_EXPORTER=console", f.extProcExtraEnvVars)
require.Equal(t, "x-session-id:session.id", f.spanRequestHeaderAttributes)
require.Equal(t, "openai:/v1,cohere:/cohere/v2,anthropic:/anthropic/v1", f.endpointPrefixes)
require.Equal(t, 32*1024*1024, f.maxRecvMsgSize)
require.Equal(t, []string{"default", "envoy-ai-gateway-system"}, f.watchNamespaces)
require.Equal(t, 5*time.Minute, f.cacheSyncTimeout)
Expand Down Expand Up @@ -136,6 +138,16 @@ func Test_parseAndValidateFlags(t *testing.T) {
flags: []string{"--spanRequestHeaderAttributes=:session.id"},
expErr: "invalid tracing header attributes",
},
{
name: "invalid endpointPrefixes - unknown key",
flags: []string{"--endpointPrefixes=foo:/x"},
expErr: "invalid endpoint prefixes",
},
{
name: "invalid endpointPrefixes - missing colon",
flags: []string{"--endpointPrefixes=openai"},
expErr: "invalid endpoint prefixes",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := parseAndValidateFlags(tc.flags)
Expand Down
34 changes: 27 additions & 7 deletions cmd/extproc/mainlib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type extProcFlags struct {
rootPrefix string
// maxRecvMsgSize is the maximum message size in bytes that the gRPC server can receive.
maxRecvMsgSize int
// endpointPrefixes is the comma-separated key-value pairs for endpoint prefixes.
endpointPrefixes string
}

// parseAndValidateFlags parses and validates the flags passed to the external processor.
Expand Down Expand Up @@ -96,6 +98,11 @@ func parseAndValidateFlags(args []string) (extProcFlags, error) {
"/",
"The root path prefix for all the processors.",
)
fs.StringVar(&flags.endpointPrefixes,
"endpointPrefixes",
"",
"Comma-separated key-value pairs for endpoint prefixes. Format: openai:/,cohere:/cohere,anthropic:/anthropic.",
)
fs.IntVar(&flags.maxRecvMsgSize,
"maxRecvMsgSize",
4*1024*1024,
Expand Down Expand Up @@ -132,6 +139,11 @@ func parseAndValidateFlags(args []string) (extProcFlags, error) {
errs = append(errs, fmt.Errorf("failed to parse tracing header mapping: %w", err))
}
}
if flags.endpointPrefixes != "" {
if _, err := internalapi.ParseEndpointPrefixes(flags.endpointPrefixes); err != nil {
errs = append(errs, fmt.Errorf("failed to parse endpoint prefixes: %w", err))
}
}

return flags, errors.Join(errs...)
}
Expand Down Expand Up @@ -217,6 +229,14 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
return fmt.Errorf("failed to parse tracing header mapping: %w", err)
}

// Parse endpoint prefixes and apply defaults for any missing values.
endpointPrefixes, err := internalapi.ParseEndpointPrefixes(flags.endpointPrefixes)
if err != nil {
return fmt.Errorf("failed to parse endpoint prefixes: %w", err)
}
// Set defaults for any missing endpoint prefixes.
endpointPrefixes.SetDefaults()

// Create Prometheus registry and reader which automatically converts
// attribute to Prometheus-compatible format (e.g. dots to underscores).
promRegistry := prometheus.NewRegistry()
Expand Down Expand Up @@ -247,13 +267,13 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
if err != nil {
return fmt.Errorf("failed to create external processor server: %w", err)
}
server.Register(path.Join(flags.rootPrefix, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetrics))
server.Register(path.Join(flags.rootPrefix, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetrics))
server.Register(path.Join(flags.rootPrefix, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetrics))
server.Register(path.Join(flags.rootPrefix, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetrics))
server.Register(path.Join(flags.rootPrefix, "/cohere/v2/rerank"), extproc.RerankProcessorFactory(rerankMetrics))
server.Register(path.Join(flags.rootPrefix, "/v1/models"), extproc.NewModelsProcessor)
server.Register(path.Join(flags.rootPrefix, "/anthropic/v1/messages"), extproc.MessagesProcessorFactory(messagesMetrics))
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.OpenAI, "/v1/chat/completions"), extproc.ChatCompletionProcessorFactory(chatCompletionMetrics))
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.OpenAI, "/v1/completions"), extproc.CompletionsProcessorFactory(completionMetrics))
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.OpenAI, "/v1/embeddings"), extproc.EmbeddingsProcessorFactory(embeddingsMetrics))
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.OpenAI, "/v1/images/generations"), extproc.ImageGenerationProcessorFactory(imageGenerationMetrics))
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.Cohere, "/v2/rerank"), extproc.RerankProcessorFactory(rerankMetrics))
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.OpenAI, "/v1/models"), extproc.NewModelsProcessor)
server.Register(path.Join(flags.rootPrefix, *endpointPrefixes.Anthropic, "/v1/messages"), extproc.MessagesProcessorFactory(messagesMetrics))

if watchErr := extproc.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); watchErr != nil {
return fmt.Errorf("failed to start config watcher: %w", watchErr)
Expand Down
18 changes: 18 additions & 0 deletions cmd/extproc/mainlib/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ func Test_parseAndValidateFlags(t *testing.T) {
rootPrefix: "/foo/bar/",
logLevel: slog.LevelDebug,
},
{
name: "with endpoint prefixes",
args: []string{"-configPath", "/path/to/config.yaml", "-endpointPrefixes", "openai:/,cohere:/cohere,anthropic:/anthropic"},
configPath: "/path/to/config.yaml",
addr: ":1063",
rootPrefix: "/",
logLevel: slog.LevelInfo,
},
{
name: "with header mapping",
args: []string{
Expand Down Expand Up @@ -150,6 +158,16 @@ func Test_parseAndValidateFlags(t *testing.T) {
args: []string{"-logLevel", "invalid"},
expectedError: "configPath must be provided\nfailed to unmarshal log level: slog: level string \"invalid\": unknown name",
},
{
name: "invalid endpoint prefixes - unknown key",
args: []string{"-configPath", "/path/to/config.yaml", "-endpointPrefixes", "foo:/x"},
expectedError: "failed to parse endpoint prefixes: unknown endpointPrefixes key \"foo\" at position 1 (allowed: openai, cohere, anthropic)",
},
{
name: "invalid endpoint prefixes - missing colon",
args: []string{"-configPath", "/path/to/config.yaml", "-endpointPrefixes", "openai"},
expectedError: "failed to parse endpoint prefixes: invalid endpointPrefixes pair at position 1: \"openai\" (expected format: key:value)",
},
{
name: "invalid tracing header attributes - missing colon",
args: []string{"-configPath", "/path/to/config.yaml", "-spanRequestHeaderAttributes", "x-session-id"},
Expand Down
3 changes: 3 additions & 0 deletions internal/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ type Options struct {
ExtProcMaxRecvMsgSize int
// MCPSessionEncryptionSeed is the seed used to derive the encryption key for MCP session encryption.
MCPSessionEncryptionSeed string
// EndpointPrefixes is the comma-separated key-value pairs for endpoint prefixes.
EndpointPrefixes string
}

// StartControllers starts the controllers for the AI Gateway.
Expand Down Expand Up @@ -220,6 +222,7 @@ func StartControllers(ctx context.Context, mgr manager.Manager, config *rest.Con
options.MetricsRequestHeaderAttributes,
options.TracingRequestHeaderAttributes,
options.RootPrefix,
options.EndpointPrefixes,
options.ExtProcExtraEnvVars,
options.ExtProcImagePullSecrets,
options.ExtProcMaxRecvMsgSize,
Expand Down
9 changes: 8 additions & 1 deletion internal/controller/gateway_mutator.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type gatewayMutator struct {
metricsRequestHeaderAttributes string
spanRequestHeaderAttributes string
rootPrefix string
endpointPrefixes string
extProcExtraEnvVars []corev1.EnvVar
extProcImagePullSecrets []corev1.LocalObjectReference
extProcMaxRecvMsgSize int
Expand All @@ -55,7 +56,7 @@ type gatewayMutator struct {

func newGatewayMutator(c client.Client, kube kubernetes.Interface, logger logr.Logger,
extProcImage string, extProcImagePullPolicy corev1.PullPolicy, extProcLogLevel,
udsPath, metricsRequestHeaderAttributes, spanRequestHeaderAttributes, rootPrefix, extProcExtraEnvVars, extProcImagePullSecrets string, extProcMaxRecvMsgSize int,
udsPath, metricsRequestHeaderAttributes, spanRequestHeaderAttributes, rootPrefix, endpointPrefixes, extProcExtraEnvVars, extProcImagePullSecrets string, extProcMaxRecvMsgSize int,
extProcAsSideCar bool,
mcpSessionEncryptionSeed string,
) *gatewayMutator {
Expand Down Expand Up @@ -90,6 +91,7 @@ func newGatewayMutator(c client.Client, kube kubernetes.Interface, logger logr.L
metricsRequestHeaderAttributes: metricsRequestHeaderAttributes,
spanRequestHeaderAttributes: spanRequestHeaderAttributes,
rootPrefix: rootPrefix,
endpointPrefixes: endpointPrefixes,
extProcExtraEnvVars: parsedEnvVars,
extProcImagePullSecrets: parsedImagePullSecrets,
extProcMaxRecvMsgSize: extProcMaxRecvMsgSize,
Expand Down Expand Up @@ -142,6 +144,11 @@ func (g *gatewayMutator) buildExtProcArgs(filterConfigFullPath string, extProcAd
args = append(args, "-spanRequestHeaderAttributes", g.spanRequestHeaderAttributes)
}

// Add endpoint prefixes mapping if configured.
if g.endpointPrefixes != "" {
args = append(args, "-endpointPrefixes", g.endpointPrefixes)
}

return args
}

Expand Down
20 changes: 16 additions & 4 deletions internal/controller/gateway_mutator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
func TestGatewayMutator_Default(t *testing.T) {
fakeClient := requireNewFakeClientWithIndexes(t)
fakeKube := fake2.NewClientset()
g := newTestGatewayMutator(fakeClient, fakeKube, "", "", "", "", false)
g := newTestGatewayMutator(fakeClient, fakeKube, "", "", "", "", "", false)
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "test-namespace"},
Spec: corev1.PodSpec{
Expand All @@ -49,6 +49,7 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
name string
metricsRequestHeaderAttributes string
spanRequestHeaderAttributes string
endpointPrefixes string
extProcExtraEnvVars string
extProcImagePullSecrets string
extprocTest func(t *testing.T, container corev1.Container)
Expand Down Expand Up @@ -83,6 +84,17 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
require.True(t, foundMCPSeed)
},
},
{
name: "with endpoint prefixes",
endpointPrefixes: "openai:/v1,cohere:/cohere/v2,anthropic:/anthropic/v1",
extprocTest: func(t *testing.T, container corev1.Container) {
require.Contains(t, container.Args, "-endpointPrefixes")
require.Contains(t, container.Args, "openai:/v1,cohere:/cohere/v2,anthropic:/anthropic/v1")
},
podTest: func(t *testing.T, pod corev1.Pod) {
require.Empty(t, pod.Spec.ImagePullSecrets)
},
},
{
name: "with extra env vars",
extProcExtraEnvVars: "OTEL_SERVICE_NAME=ai-gateway-extproc;OTEL_TRACES_EXPORTER=otlp",
Expand Down Expand Up @@ -191,7 +203,7 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
t.Run(fmt.Sprintf("sidecar=%v", sidecar), func(t *testing.T) {
fakeClient := requireNewFakeClientWithIndexes(t)
fakeKube := fake2.NewClientset()
g := newTestGatewayMutator(fakeClient, fakeKube, tt.metricsRequestHeaderAttributes, tt.spanRequestHeaderAttributes, tt.extProcExtraEnvVars, tt.extProcImagePullSecrets, sidecar)
g := newTestGatewayMutator(fakeClient, fakeKube, tt.metricsRequestHeaderAttributes, tt.spanRequestHeaderAttributes, tt.endpointPrefixes, tt.extProcExtraEnvVars, tt.extProcImagePullSecrets, sidecar)

const gwName, gwNamespace = "test-gateway", "test-namespace"
err := fakeClient.Create(t.Context(), &aigv1a1.AIGatewayRoute{
Expand Down Expand Up @@ -274,11 +286,11 @@ func TestGatewayMutator_mutatePod(t *testing.T) {
}
}

func newTestGatewayMutator(fakeClient client.Client, fakeKube *fake2.Clientset, metricsRequestHeaderAttributes, spanRequestHeaderAttributes, extProcExtraEnvVars, extProcImagePullSecrets string, sidecar bool) *gatewayMutator {
func newTestGatewayMutator(fakeClient client.Client, fakeKube *fake2.Clientset, metricsRequestHeaderAttributes, spanRequestHeaderAttributes, endpointPrefixes, extProcExtraEnvVars, extProcImagePullSecrets string, sidecar bool) *gatewayMutator {
ctrl.SetLogger(zap.New(zap.UseFlagOptions(&zap.Options{Development: true, Level: zapcore.DebugLevel})))
return newGatewayMutator(
fakeClient, fakeKube, ctrl.Log, "docker.io/envoyproxy/ai-gateway-extproc:latest", corev1.PullIfNotPresent,
"info", "/tmp/extproc.sock", metricsRequestHeaderAttributes, spanRequestHeaderAttributes, "/v1", extProcExtraEnvVars, extProcImagePullSecrets, 512*1024*1024,
"info", "/tmp/extproc.sock", metricsRequestHeaderAttributes, spanRequestHeaderAttributes, "/v1", endpointPrefixes, extProcExtraEnvVars, extProcImagePullSecrets, 512*1024*1024,
sidecar, "seed",
)
}
Expand Down
76 changes: 76 additions & 0 deletions internal/internalapi/internalapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,82 @@ func ParseRequestHeaderAttributeMapping(s string) (map[string]string, error) {
return result, nil
}

// EndpointPrefixes holds the path prefix configured for each provider that the
// endpointPrefixes flag/value knows about. No other providers are recognized.
//
// A nil field means the provider was not configured; SetDefaults fills such
// fields in. A non-nil field is kept as given, even when it points at an empty
// string.
type EndpointPrefixes struct {
	// OpenAI is the prefix set through the "openai" key.
	OpenAI *string
	// Cohere is the prefix set through the "cohere" key.
	Cohere *string
	// Anthropic is the prefix set through the "anthropic" key.
	Anthropic *string
}

// SetDefaults assigns the default provider prefix (which carries no version
// component) to every field that is still nil. Fields that are already set,
// including ones pointing at an empty string, are left untouched.
//
// Defaults:
//
//	openai    -> /
//	cohere    -> /cohere
//	anthropic -> /anthropic
func (e *EndpointPrefixes) SetDefaults() {
	defaults := []struct {
		field    **string
		fallback string
	}{
		{&e.OpenAI, "/"},
		{&e.Cohere, "/cohere"},
		{&e.Anthropic, "/anthropic"},
	}
	for _, d := range defaults {
		if *d.field != nil {
			continue
		}
		// Copy into a fresh variable so every field owns its own string.
		prefix := d.fallback
		*d.field = &prefix
	}
}

// ParseEndpointPrefixes builds an EndpointPrefixes from a comma-separated list
// of key:value pairs.
//
// Recognized keys (case-sensitive):
//   - openai
//   - cohere
//   - anthropic
//
// Format example:
//
//	"openai:/,cohere:/cohere,anthropic:/anthropic"
//
// Parsing rules:
//   - An empty input yields the zero value (all fields nil) and no error.
//   - Surrounding whitespace is trimmed from each pair, key and value.
//   - A pair is split at its first colon, so a value may itself contain colons.
//   - The value is stored as-is after trimming; an empty value (e.g. "openai:")
//     is accepted and produces a non-nil pointer to "".
//   - When a key appears more than once, the last occurrence wins.
//
// An error is returned, together with the zero value, for an empty pair, a pair
// without a colon, or an unknown key. Positions in error messages are 1-based.
func ParseEndpointPrefixes(s string) (EndpointPrefixes, error) {
	var parsed EndpointPrefixes
	if s == "" {
		return parsed, nil
	}

	// Maps every accepted key to the field it populates.
	targets := map[string]**string{
		"openai":    &parsed.OpenAI,
		"cohere":    &parsed.Cohere,
		"anthropic": &parsed.Anthropic,
	}

	for idx, raw := range strings.Split(s, ",") {
		position := idx + 1
		entry := strings.TrimSpace(raw)
		if entry == "" {
			return EndpointPrefixes{}, fmt.Errorf("empty endpointPrefixes pair at position %d", position)
		}

		rawKey, rawValue, hasColon := strings.Cut(entry, ":")
		if !hasColon {
			return EndpointPrefixes{}, fmt.Errorf("invalid endpointPrefixes pair at position %d: %q (expected format: key:value)", position, entry)
		}

		name := strings.TrimSpace(rawKey)
		target, known := targets[name]
		if !known {
			return EndpointPrefixes{}, fmt.Errorf("unknown endpointPrefixes key %q at position %d (allowed: openai, cohere, anthropic)", name, position)
		}

		// A fresh variable per pair keeps the stored pointers independent.
		prefix := strings.TrimSpace(rawValue)
		*target = &prefix
	}
	return parsed, nil
}

// ModelNameHeaderKeyDefault is the default header key for the model name.
const ModelNameHeaderKeyDefault = aigv1a1.AIModelHeaderKey

Expand Down
Loading
Loading