From 85d9cc3a14192f86ec4c6a1f51161f1ede165717 Mon Sep 17 00:00:00 2001 From: yxia216 Date: Thu, 6 Nov 2025 01:38:42 -0500 Subject: [PATCH 1/6] fix-usage-chunk Signed-off-by: yxia216 --- .../extproc/translator/openai_gcpvertexai.go | 53 +++++++++++++------ .../translator/openai_gcpvertexai_test.go | 4 +- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/internal/extproc/translator/openai_gcpvertexai.go b/internal/extproc/translator/openai_gcpvertexai.go index 94f1d857d..1511b6409 100644 --- a/internal/extproc/translator/openai_gcpvertexai.go +++ b/internal/extproc/translator/openai_gcpvertexai.go @@ -175,16 +175,6 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) handleStreamingResponse( // Convert GCP chunk to OpenAI chunk. openAIChunk := o.convertGCPChunkToOpenAI(chunk) - // Extract token usage if present in this chunk (typically in the last chunk). - if chunk.UsageMetadata != nil { - tokenUsage = LLMTokenUsage{ - InputTokens: uint32(chunk.UsageMetadata.PromptTokenCount), //nolint:gosec - OutputTokens: uint32(chunk.UsageMetadata.CandidatesTokenCount), //nolint:gosec - TotalTokens: uint32(chunk.UsageMetadata.TotalTokenCount), //nolint:gosec - CachedInputTokens: uint32(chunk.UsageMetadata.CachedContentTokenCount), //nolint:gosec - } - } - // Serialize to SSE format as expected by OpenAI API. var chunkBytes []byte chunkBytes, err = json.Marshal(openAIChunk) @@ -198,6 +188,40 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) handleStreamingResponse( if span != nil { span.RecordResponseChunk(openAIChunk) } + + // Extract token usage only in the last chunk. + if chunk.UsageMetadata != nil && chunk.UsageMetadata.PromptTokenCount > 0 { + // Convert usage to pointer if available. 
+ usage := ptr.To(geminiUsageToOpenAIUsage(chunk.UsageMetadata)) + + usageChunk := openai.ChatCompletionResponseChunk{ + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionResponseChunkChoice{}, + // Only this trailing chunk carries usage; the content chunks above leave it nil. + Usage: usage, + } + + // Serialize to SSE format as expected by OpenAI API. + var chunkBytes []byte + chunkBytes, err = json.Marshal(usageChunk) + if err != nil { + return nil, nil, LLMTokenUsage{}, "", fmt.Errorf("error marshaling OpenAI chunk: %w", err) + } + sseChunkBuf.WriteString("data: ") + sseChunkBuf.Write(chunkBytes) + sseChunkBuf.WriteString("\n\n") + + if span != nil { + span.RecordResponseChunk(&usageChunk) + } + + tokenUsage = LLMTokenUsage{ + InputTokens: uint32(chunk.UsageMetadata.PromptTokenCount), //nolint:gosec + OutputTokens: uint32(chunk.UsageMetadata.CandidatesTokenCount), //nolint:gosec + TotalTokens: uint32(chunk.UsageMetadata.TotalTokenCount), //nolint:gosec + CachedInputTokens: uint32(chunk.UsageMetadata.CachedContentTokenCount), //nolint:gosec + } + } } mut := &extprocv3.BodyMutation_Body{ Body: sseChunkBuf.Bytes(), @@ -251,16 +275,11 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) convertGCPChunkToOpenAI( choices = []openai.ChatCompletionResponseChunkChoice{} } - // Convert usage to pointer if available. 
- var usage *openai.Usage - if chunk.UsageMetadata != nil { - usage = ptr.To(geminiUsageToOpenAIUsage(chunk.UsageMetadata)) - } - return &openai.ChatCompletionResponseChunk{ Object: "chat.completion.chunk", Choices: choices, - Usage: usage, + // usage is nil for all chunks other than the last chunk + Usage: nil, } } diff --git a/internal/extproc/translator/openai_gcpvertexai_test.go b/internal/extproc/translator/openai_gcpvertexai_test.go index e374e05c6..75dbe110d 100644 --- a/internal/extproc/translator/openai_gcpvertexai_test.go +++ b/internal/extproc/translator/openai_gcpvertexai_test.go @@ -1054,7 +1054,9 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseBody(t *testing.T wantHeaderMut: nil, wantBodyMut: &extprocv3.BodyMutation{ Mutation: &extprocv3.BodyMutation_Body{ - Body: []byte(`data: {"choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}],"object":"chat.completion.chunk","usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8,"completion_tokens_details":{},"prompt_tokens_details":{}}} + Body: []byte(`data: {"choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}],"object":"chat.completion.chunk"} + +data: {"object":"chat.completion.chunk","usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8,"completion_tokens_details":{},"prompt_tokens_details":{}}} data: [DONE] `), From 8e23bb243e21580c991dee5c5cf0eca8a5448928 Mon Sep 17 00:00:00 2001 From: CYJiang <86391540+googs1025@users.noreply.github.com> Date: Wed, 5 Nov 2025 18:17:58 +0800 Subject: [PATCH 2/6] fix: ai gateway mutating webhook should default failurePolicy to Fail (#1494) **Description** - ai gateway mutating webhook should default failurePolicy to Fail **Related Issues/PRs (if applicable)** fixes: https://github.com/envoyproxy/ai-gateway/issues/1493 **Special notes for reviewers (if applicable)** Signed-off-by: googs1025 Signed-off-by: yxia216 --- .../charts/ai-gateway-helm/templates/admission_webhook.yaml | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/manifests/charts/ai-gateway-helm/templates/admission_webhook.yaml b/manifests/charts/ai-gateway-helm/templates/admission_webhook.yaml index 5e9a6bd8c..24a9fff61 100644 --- a/manifests/charts/ai-gateway-helm/templates/admission_webhook.yaml +++ b/manifests/charts/ai-gateway-helm/templates/admission_webhook.yaml @@ -30,7 +30,7 @@ webhooks: sideEffects: None admissionReviewVersions: ["v1"] timeoutSeconds: 10 - failurePolicy: Ignore + failurePolicy: Fail --- {{- if .Values.controller.mutatingWebhook.certManager.enable }} apiVersion: cert-manager.io/v1 From a2405440ceb58e01babaebd3aea1c99ad54256f2 Mon Sep 17 00:00:00 2001 From: CYJiang <86391540+googs1025@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:46:59 +0800 Subject: [PATCH 3/6] docs: fix example in examples/inference-pool (#1488) **Description** fix: https://github.com/envoyproxy/ai-gateway/issues/1485 **Related Issues/PRs (if applicable)** **Special notes for reviewers (if applicable)** Signed-off-by: googs1025 Signed-off-by: yxia216 --- examples/inference-pool/README.md | 70 +++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/examples/inference-pool/README.md b/examples/inference-pool/README.md index 5920ca096..fbbcfa5ea 100644 --- a/examples/inference-pool/README.md +++ b/examples/inference-pool/README.md @@ -2,10 +2,19 @@ This example demonstrates how to use AI Gateway with the InferencePool feature, which enables intelligent request routing across multiple inference endpoints with load balancing and health checking capabilities. +The setup includes **three distinct backends**: + +- Two `InferencePool` resources for LLMs (`Llama-3.1-8B-Instruct` and `Mistral`) +- One standard `Backend` for non-InferencePool traffic + +Routing is controlled by the `x-ai-eg-model` HTTP header. + ## Files in This Directory - **`envoy-gateway-values-addon.yaml`**: Envoy Gateway values addon for InferencePool support. 
Combine with `../../manifests/envoy-gateway-values.yaml`. -- **`base.yaml`**: Complete example that includes Gateway, AIServiceBackend, InferencePool CRDs, and a sample application deployment. +- **`base.yaml`**: Deploys all inference backends and supporting resources using the **standard approach documented in the official guide**. This includes: + - A `mistral` backend with custom Endpoint Picker configuration + - A standard fallback backend (`envoy-ai-gateway-basic-testupstream`) for non-InferencePool routing - **`aigwroute.yaml`**: Example AIGatewayRoute that uses InferencePool as a backend. - **`httproute.yaml`**: Example HTTPRoute for traditional HTTP routing to InferencePool endpoints. - **`with-annotations.yaml`**: Advanced example showing InferencePool with Kubernetes annotations for fine-grained control. @@ -27,16 +36,63 @@ This example demonstrates how to use AI Gateway with the InferencePool feature, ```bash kubectl apply -f base.yaml + kubectl apply -f aigwroute.yaml ``` + > Note: The `aigwroute.yaml` file defines the InferencePool and routing logic, but does not deploy the actual inference backend (e.g., the vLLM server for Llama-3.1-8B-Instruct). + > You must deploy the backend separately by following [Step 3: Deploy Inference Backends](https://aigateway.envoyproxy.io/docs/capabilities/inference/aigatewayroute-inferencepool#step-3-deploy-inference-backends) + 3. Test the setup: - ```bash - GATEWAY_HOST=$(kubectl get gateway/ai-gateway -o jsonpath='{.status.addresses[0].value}') - curl -X POST "http://${GATEWAY_HOST}/v1/chat/completions" \ - -H "Content-Type: application/json" \ - -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}' - ``` +You can access the gateway in two ways, depending on your environment. 
+ +✅ Option A: Using External IP (e.g., cloud LoadBalancer, MetalLB) +If your cluster assigns an external address to the Gateway: + +```bash +GATEWAY_HOST=$(kubectl get gateway/inference-pool-with-aigwroute -n default -o jsonpath='{.status.addresses[0].value}') +echo "Gateway available at: http://${GATEWAY_HOST}" +``` + +Then send a request: + +```bash +curl -X POST "http://${GATEWAY_HOST}/v1/chat/completions" \ + -H "x-ai-eg-model: meta-llama/Llama-3.1-8B-Instruct" \ + -H "Authorization: sk-abcdefghijklmnopqrstuvwxyz" \ + -H "Content-Type: application/json" \ + -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "messages": [{"role": "user", "content": "Hello!"}]}' +``` + +✅ Option B: Using kubectl port-forward (ideal for local clusters like Minikube/Kind) +In one terminal, forward the gateway service: + +```bash +kubectl port-forward svc/envoy-default-inference-pool-with-aigwroute-d416582c 8080:80 -n envoy-gateway-system +``` + +In another terminal, send requests to localhost:8080: + +```bash +# Route to Llama (InferencePool) +curl -X POST "http://localhost:8080/v1/chat/completions" \ + -H "x-ai-eg-model: meta-llama/Llama-3.1-8B-Instruct" \ + -H "Authorization: sk-abcdefghijklmnopqrstuvwxyz" \ + -H "Content-Type: application/json" \ + -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "messages": [{"role": "user", "content": "Hello!"}]}' + +# Route to Mistral (InferencePool) +curl -X POST "http://localhost:8080/v1/chat/completions" \ + -H "x-ai-eg-model: mistral:latest" \ + -H "Content-Type: application/json" \ + -d '{"model": "mistral:latest", "messages": [{"role": "user", "content": "Hello!"}]}' + +# Route to fallback backend (Standard Backend) +curl -X POST "http://localhost:8080/v1/chat/completions" \ + -H "x-ai-eg-model: some-cool-self-hosted-model" \ + -H "Content-Type: application/json" \ + -d '{"model": "some-cool-self-hosted-model", "messages": [{"role": "user", "content": "Hello!"}]}' +``` ### Combining with Other Features From 
820525dce8afd61d228abb19040055787b8eafa6 Mon Sep 17 00:00:00 2001 From: hustxiayang Date: Wed, 5 Nov 2025 07:32:42 -0500 Subject: [PATCH 4/6] fix: finish reason should be tool calls when the model responded with a tool call (#1486) **Description** Finish reason should be tool calls if the model returns a tool call response. In vertex api, there is no tool call finish reason, thus need a work around to make it compatible. --------- Signed-off-by: yxia216 Co-authored-by: Dan Sun Signed-off-by: yxia216 --- internal/extproc/translator/gemini_helper.go | 36 +++++--- .../extproc/translator/gemini_helper_test.go | 88 ++++++++++++------- tests/extproc/testupstream_test.go | 2 +- 3 files changed, 82 insertions(+), 44 deletions(-) diff --git a/internal/extproc/translator/gemini_helper.go b/internal/extproc/translator/gemini_helper.go index b752bf953..0b5ff68f1 100644 --- a/internal/extproc/translator/gemini_helper.go +++ b/internal/extproc/translator/gemini_helper.go @@ -529,10 +529,12 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode // Create the choice. choice := openai.ChatCompletionResponseChoice{ - Index: int64(idx), - FinishReason: geminiFinishReasonToOpenAI(candidate.FinishReason), + Index: int64(idx), } + toolCalls := []openai.ChatCompletionMessageToolCallParam{} + var err error + if candidate.Content != nil { message := openai.ChatCompletionResponseChoiceMessage{ Role: openai.ChatMessageRoleAssistant, @@ -542,7 +544,7 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode message.Content = &content // Extract tool calls if any. 
- toolCalls, err := extractToolCallsFromGeminiParts(candidate.Content.Parts) + toolCalls, err = extractToolCallsFromGeminiParts(toolCalls, candidate.Content.Parts) if err != nil { return nil, fmt.Errorf("error extracting tool calls: %w", err) } @@ -569,16 +571,26 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode choice.Logprobs = geminiLogprobsToOpenAILogprobs(*candidate.LogprobsResult) } + choice.FinishReason = geminiFinishReasonToOpenAI(candidate.FinishReason, toolCalls) + choices = append(choices, choice) } return choices, nil } +// Define a type constraint that includes both stream and non-stream tool call slice types. +type toolCallSlice interface { + []openai.ChatCompletionMessageToolCallParam | []openai.ChatCompletionChunkChoiceDeltaToolCall +} + // geminiFinishReasonToOpenAI converts Gemini finish reason to OpenAI finish reason. -func geminiFinishReasonToOpenAI(reason genai.FinishReason) openai.ChatCompletionChoicesFinishReason { +func geminiFinishReasonToOpenAI[T toolCallSlice](reason genai.FinishReason, toolCalls T) openai.ChatCompletionChoicesFinishReason { switch reason { case genai.FinishReasonStop: + if len(toolCalls) > 0 { + return openai.ChatCompletionChoicesFinishReasonToolCalls + } return openai.ChatCompletionChoicesFinishReasonStop case genai.FinishReasonMaxTokens: return openai.ChatCompletionChoicesFinishReasonLength @@ -611,9 +623,7 @@ func extractTextFromGeminiParts(parts []*genai.Part, responseMode geminiResponse } // extractToolCallsFromGeminiParts extracts tool calls from Gemini parts. 
-func extractToolCallsFromGeminiParts(parts []*genai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) { - var toolCalls []openai.ChatCompletionMessageToolCallParam - +func extractToolCallsFromGeminiParts(toolCalls []openai.ChatCompletionMessageToolCallParam, parts []*genai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) { for _, part := range parts { if part == nil || part.FunctionCall == nil { continue @@ -650,8 +660,7 @@ func extractToolCallsFromGeminiParts(parts []*genai.Part) ([]openai.ChatCompleti // extractToolCallsFromGeminiPartsStream extracts tool calls from Gemini parts for streaming responses. // Each tool call is assigned an incremental index starting from 0, matching OpenAI's streaming protocol. // Returns ChatCompletionChunkChoiceDeltaToolCall types suitable for streaming responses, or nil if no tool calls are found. -func extractToolCallsFromGeminiPartsStream(parts []*genai.Part) ([]openai.ChatCompletionChunkChoiceDeltaToolCall, error) { - var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall +func extractToolCallsFromGeminiPartsStream(toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall, parts []*genai.Part) ([]openai.ChatCompletionChunkChoiceDeltaToolCall, error) { toolCallIndex := int64(0) for _, part := range parts { @@ -772,10 +781,11 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, res // Create the streaming choice. choice := openai.ChatCompletionResponseChunkChoice{ - Index: 0, - FinishReason: geminiFinishReasonToOpenAI(candidate.FinishReason), + Index: 0, } + toolCalls := []openai.ChatCompletionChunkChoiceDeltaToolCall{} + var err error if candidate.Content != nil { delta := &openai.ChatCompletionResponseChunkChoiceDelta{ Role: openai.ChatMessageRoleAssistant, @@ -788,7 +798,7 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, res } // Extract tool calls if any. 
- toolCalls, err := extractToolCallsFromGeminiPartsStream(candidate.Content.Parts) + toolCalls, err = extractToolCallsFromGeminiPartsStream(toolCalls, candidate.Content.Parts) if err != nil { return nil, fmt.Errorf("error extracting tool calls: %w", err) } @@ -796,7 +806,7 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, res choice.Delta = delta } - + choice.FinishReason = geminiFinishReasonToOpenAI(candidate.FinishReason, toolCalls) choices = append(choices, choice) } diff --git a/internal/extproc/translator/gemini_helper_test.go b/internal/extproc/translator/gemini_helper_test.go index 681d132dc..bd3225526 100644 --- a/internal/extproc/translator/gemini_helper_test.go +++ b/internal/extproc/translator/gemini_helper_test.go @@ -1271,6 +1271,7 @@ func TestGeminiLogprobsToOpenAILogprobs(t *testing.T) { } func TestExtractToolCallsFromGeminiParts(t *testing.T) { + toolCalls := []openai.ChatCompletionMessageToolCallParam{} tests := []struct { name string input []*genai.Part @@ -1360,7 +1361,7 @@ func TestExtractToolCallsFromGeminiParts(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - calls, err := extractToolCallsFromGeminiParts(tt.input) + calls, err := extractToolCallsFromGeminiParts(toolCalls, tt.input) if tt.wantErr { require.Error(t, err) @@ -1381,56 +1382,80 @@ func TestExtractToolCallsFromGeminiParts(t *testing.T) { func TestGeminiFinishReasonToOpenAI(t *testing.T) { tests := []struct { - name string - input genai.FinishReason - expected openai.ChatCompletionChoicesFinishReason + name string + input genai.FinishReason + toolCalls []openai.ChatCompletionMessageToolCallParam + expected openai.ChatCompletionChoicesFinishReason }{ { - name: "stop reason", - input: genai.FinishReasonStop, - expected: openai.ChatCompletionChoicesFinishReasonStop, + name: "stop reason", + input: genai.FinishReasonStop, + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: 
openai.ChatCompletionChoicesFinishReasonStop, + }, + { + name: "tool calls reason", + input: genai.FinishReasonStop, + toolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: ptr.To("tool_call_1"), + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "example_tool", + Arguments: "{\"param1\":\"value1\"}", + }, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + }, + }, + expected: openai.ChatCompletionChoicesFinishReasonToolCalls, }, { - name: "max tokens reason", - input: genai.FinishReasonMaxTokens, - expected: openai.ChatCompletionChoicesFinishReasonLength, + name: "max tokens reason", + input: genai.FinishReasonMaxTokens, + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: openai.ChatCompletionChoicesFinishReasonLength, }, { - name: "empty reason for streaming", - input: "", - expected: "", + name: "empty reason for streaming", + input: "", + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: "", }, { - name: "safety reason", - input: genai.FinishReasonSafety, - expected: openai.ChatCompletionChoicesFinishReasonContentFilter, + name: "safety reason", + input: genai.FinishReasonSafety, + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: openai.ChatCompletionChoicesFinishReasonContentFilter, }, { - name: "recitation reason", - input: genai.FinishReasonRecitation, - expected: openai.ChatCompletionChoicesFinishReasonContentFilter, + name: "recitation reason", + input: genai.FinishReasonRecitation, + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: openai.ChatCompletionChoicesFinishReasonContentFilter, }, { - name: "other reason", - input: genai.FinishReasonOther, - expected: openai.ChatCompletionChoicesFinishReasonContentFilter, + name: "other reason", + input: genai.FinishReasonOther, + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: openai.ChatCompletionChoicesFinishReasonContentFilter, }, { - name: "unknown reason", - input: 
genai.FinishReason("unknown_reason"), - expected: openai.ChatCompletionChoicesFinishReasonContentFilter, + name: "unknown reason", + input: genai.FinishReason("unknown_reason"), + toolCalls: []openai.ChatCompletionMessageToolCallParam{}, + expected: openai.ChatCompletionChoicesFinishReasonContentFilter, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := geminiFinishReasonToOpenAI(tt.input) + result := geminiFinishReasonToOpenAI(tt.input, tt.toolCalls) require.Equal(t, tt.expected, result) }) } } func TestExtractToolCallsFromGeminiPartsStream(t *testing.T) { + toolCalls := []openai.ChatCompletionChunkChoiceDeltaToolCall{} tests := []struct { name string input []*genai.Part @@ -1675,7 +1700,7 @@ func TestExtractToolCallsFromGeminiPartsStream(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - calls, err := extractToolCallsFromGeminiPartsStream(tt.input) + calls, err := extractToolCallsFromGeminiPartsStream(toolCalls, tt.input) if tt.wantErr { require.Error(t, err) @@ -1696,6 +1721,8 @@ func TestExtractToolCallsFromGeminiPartsStream(t *testing.T) { // TestExtractToolCallsStreamVsNonStream tests the differences between streaming and non-streaming extraction func TestExtractToolCallsStreamVsNonStream(t *testing.T) { + toolCalls := []openai.ChatCompletionMessageToolCallParam{} + toolCallsStream := []openai.ChatCompletionChunkChoiceDeltaToolCall{} parts := []*genai.Part{ { FunctionCall: &genai.FunctionCall{ @@ -1709,11 +1736,11 @@ func TestExtractToolCallsStreamVsNonStream(t *testing.T) { } // Get results from both functions - streamCalls, err := extractToolCallsFromGeminiPartsStream(parts) + streamCalls, err := extractToolCallsFromGeminiPartsStream(toolCallsStream, parts) require.NoError(t, err) require.Len(t, streamCalls, 1) - nonStreamCalls, err := extractToolCallsFromGeminiParts(parts) + nonStreamCalls, err := extractToolCallsFromGeminiParts(toolCalls, parts) require.NoError(t, err) require.Len(t, 
nonStreamCalls, 1) @@ -1749,6 +1776,7 @@ func TestExtractToolCallsStreamVsNonStream(t *testing.T) { // TestExtractToolCallsStreamIndexing specifically tests that multiple tool calls get correct indices func TestExtractToolCallsStreamIndexing(t *testing.T) { + toolCalls := []openai.ChatCompletionChunkChoiceDeltaToolCall{} parts := []*genai.Part{ { FunctionCall: &genai.FunctionCall{ @@ -1771,7 +1799,7 @@ func TestExtractToolCallsStreamIndexing(t *testing.T) { }, } - calls, err := extractToolCallsFromGeminiPartsStream(parts) + calls, err := extractToolCallsFromGeminiPartsStream(toolCalls, parts) require.NoError(t, err) require.Len(t, calls, 3) diff --git a/tests/extproc/testupstream_test.go b/tests/extproc/testupstream_test.go index a8c415d77..75d8abc8a 100644 --- a/tests/extproc/testupstream_test.go +++ b/tests/extproc/testupstream_test.go @@ -297,7 +297,7 @@ func TestWithTestUpstream(t *testing.T) { responseStatus: strconv.Itoa(http.StatusOK), responseBody: `{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_delivery_date","args":{"order_id":"123"}}}]},"finishReason":"STOP","avgLogprobs":0.000001220789272338152}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":11,"totalTokenCount":61,"trafficType":"ON_DEMAND","promptTokensDetails":[{"modality":"TEXT","tokenCount":50}],"candidatesTokensDetails":[{"modality":"TEXT","tokenCount":11}]},"modelVersion":"gemini-2.0-flash-001","createTime":"2025-07-11T22:15:44.956335Z","responseId":"EI5xaK-vOtqJm22IPmuCR14AI"}`, expStatus: http.StatusOK, - expResponseBody: `{"choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","tool_calls":[{"id":"703482f8-2e5b-4dcc-a872-d74bd66c3866","function":{"arguments":"{\"order_id\":\"123\"}","name":"get_delivery_date"},"type":"function"}]}}],"model":"gemini-2.0-flash-001","object":"chat.completion","usage":{"completion_tokens":11,"completion_tokens_details":{},"prompt_tokens":50,"total_tokens":61,"prompt_tokens_details":{}}}`, + 
expResponseBody: `{"choices":[{"finish_reason":"tool_calls","index":0,"message":{"role":"assistant","tool_calls":[{"id":"703482f8-2e5b-4dcc-a872-d74bd66c3866","function":{"arguments":"{\"order_id\":\"123\"}","name":"get_delivery_date"},"type":"function"}]}}],"model":"gemini-2.0-flash-001","object":"chat.completion","usage":{"completion_tokens":11,"completion_tokens_details":{},"prompt_tokens":50,"total_tokens":61,"prompt_tokens_details":{}}}`, }, { name: "gcp-anthropicai - /v1/chat/completions", From 35bff4cce3d40a17ff385c6d8bd41bd7f03c4f5c Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Wed, 5 Nov 2025 12:29:09 -0800 Subject: [PATCH 5/6] refactor: decouples backendauth & headermutator from extproc (#1491) **Description** This decouples backendauth & headermutator packages from extproc specifics. As we are looking to migrate to dynamic modules, this is a necessary refactoring work to make the code as reusable as possible. **Related Issues/PRs (if applicable)** Preliminary for #90 --------- Signed-off-by: Takeshi Yoneda Signed-off-by: yxia216 --- .../backendauth/anthropicapikey.go | 14 +- .../backendauth/anthropicapikey_test.go | 16 +- internal/{extproc => }/backendauth/api_key.go | 11 +- .../{extproc => }/backendauth/api_key_test.go | 25 +-- internal/{extproc => }/backendauth/auth.go | 8 +- .../{extproc => }/backendauth/auth_test.go | 0 internal/{extproc => }/backendauth/aws.go | 39 ++-- .../{extproc => }/backendauth/aws_test.go | 100 +++------- internal/{extproc => }/backendauth/azure.go | 11 +- .../{extproc => }/backendauth/azure_test.go | 28 +-- .../{extproc => }/backendauth/azureapikey.go | 14 +- .../backendauth/azureapikey_test.go | 18 +- internal/{extproc => }/backendauth/gcp.go | 46 +---- internal/backendauth/gcp_test.go | 116 +++++++++++ .../extproc/backendauth/aws_bench_test.go | 107 ---------- internal/extproc/backendauth/gcp_test.go | 187 ------------------ internal/extproc/chatcompletion_processor.go | 27 ++- 
.../extproc/chatcompletion_processor_test.go | 2 +- internal/extproc/completions_processor.go | 27 ++- .../extproc/completions_processor_test.go | 2 +- internal/extproc/embeddings_processor.go | 27 ++- internal/extproc/embeddings_processor_test.go | 50 +---- internal/extproc/imagegeneration_processor.go | 27 ++- internal/extproc/messages_processor.go | 27 ++- internal/extproc/messages_processor_test.go | 47 +---- internal/extproc/mocks_test.go | 10 +- internal/extproc/processor.go | 2 +- internal/extproc/rerank_processor.go | 27 ++- internal/extproc/rerank_processor_test.go | 2 +- internal/extproc/server.go | 2 +- .../headermutator/header_mutator.go | 24 +-- .../headermutator/header_mutator_test.go | 23 ++- internal/internalapi/headers.go | 19 ++ 33 files changed, 397 insertions(+), 688 deletions(-) rename internal/{extproc => }/backendauth/anthropicapikey.go (64%) rename internal/{extproc => }/backendauth/anthropicapikey_test.go (68%) rename internal/{extproc => }/backendauth/api_key.go (64%) rename internal/{extproc => }/backendauth/api_key_test.go (51%) rename internal/{extproc => }/backendauth/auth.go (80%) rename internal/{extproc => }/backendauth/auth_test.go (100%) rename internal/{extproc => }/backendauth/aws.go (78%) rename internal/{extproc => }/backendauth/aws_test.go (71%) rename internal/{extproc => }/backendauth/azure.go (64%) rename internal/{extproc => }/backendauth/azure_test.go (52%) rename internal/{extproc => }/backendauth/azureapikey.go (65%) rename internal/{extproc => }/backendauth/azureapikey_test.go (71%) rename internal/{extproc => }/backendauth/gcp.go (58%) create mode 100644 internal/backendauth/gcp_test.go delete mode 100644 internal/extproc/backendauth/aws_bench_test.go delete mode 100644 internal/extproc/backendauth/gcp_test.go rename internal/{extproc => }/headermutator/header_mutator.go (80%) rename internal/{extproc => }/headermutator/header_mutator_test.go (84%) create mode 100644 internal/internalapi/headers.go diff --git 
a/internal/extproc/backendauth/anthropicapikey.go b/internal/backendauth/anthropicapikey.go similarity index 64% rename from internal/extproc/backendauth/anthropicapikey.go rename to internal/backendauth/anthropicapikey.go index 69219f0ed..182b6f2dd 100644 --- a/internal/extproc/backendauth/anthropicapikey.go +++ b/internal/backendauth/anthropicapikey.go @@ -9,10 +9,8 @@ import ( "context" "strings" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) type anthropicAPIKeyHandler struct { @@ -27,13 +25,7 @@ func newAnthropicAPIKeyHandler(auth *filterapi.AnthropicAPIKeyAuth) (Handler, er // Anthropic uses "x-api-key" header instead of "Authorization: Bearer". // // https://docs.claude.com/en/api/overview#authentication -func (a *anthropicAPIKeyHandler) Do(_ context.Context, requestHeaders map[string]string, headerMut *extprocv3.HeaderMutation, _ *extprocv3.BodyMutation) error { +func (a *anthropicAPIKeyHandler) Do(_ context.Context, requestHeaders map[string]string, _ []byte) ([]internalapi.Header, error) { requestHeaders["x-api-key"] = a.apiKey - headerMut.SetHeaders = append(headerMut.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{ - Key: "x-api-key", - RawValue: []byte(a.apiKey), - }, - }) - return nil + return []internalapi.Header{{"x-api-key", a.apiKey}}, nil } diff --git a/internal/extproc/backendauth/anthropicapikey_test.go b/internal/backendauth/anthropicapikey_test.go similarity index 68% rename from internal/extproc/backendauth/anthropicapikey_test.go rename to internal/backendauth/anthropicapikey_test.go index 26e0e2d32..fdbf48f6b 100644 --- a/internal/extproc/backendauth/anthropicapikey_test.go +++ b/internal/backendauth/anthropicapikey_test.go @@ -9,7 +9,6 @@ import ( "context" "testing" - extprocv3 
"github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/internal/filterapi" @@ -21,18 +20,17 @@ func TestAnthropicAPIKeyHandler(t *testing.T) { require.NoError(t, err) headers := make(map[string]string) - headerMut := &extprocv3.HeaderMutation{} - err = handler.Do(context.Background(), headers, headerMut, nil) + hders, err := handler.Do(context.Background(), headers, nil) require.NoError(t, err) // Verify header in map require.Equal(t, "test-azure-key", headers["x-api-key"]) // Verify header in mutation - require.Len(t, headerMut.SetHeaders, 1) - require.Equal(t, "x-api-key", headerMut.SetHeaders[0].Header.Key) - require.Equal(t, "test-azure-key", string(headerMut.SetHeaders[0].Header.RawValue)) + require.Len(t, hders, 1) + require.Equal(t, "x-api-key", hders[0][0]) + require.Equal(t, "test-azure-key", hders[0][1]) }) t.Run("trims whitespace", func(t *testing.T) { @@ -40,11 +38,13 @@ func TestAnthropicAPIKeyHandler(t *testing.T) { require.NoError(t, err) headers := make(map[string]string) - headerMut := &extprocv3.HeaderMutation{} - err = handler.Do(context.Background(), headers, headerMut, nil) + hdrs, err := handler.Do(context.Background(), headers, nil) require.NoError(t, err) require.Equal(t, "key-with-spaces", headers["x-api-key"]) + require.Len(t, hdrs, 1) + require.Equal(t, "x-api-key", hdrs[0][0]) + require.Equal(t, "key-with-spaces", hdrs[0][1]) }) } diff --git a/internal/extproc/backendauth/api_key.go b/internal/backendauth/api_key.go similarity index 64% rename from internal/extproc/backendauth/api_key.go rename to internal/backendauth/api_key.go index 67cc9aaf4..64b9f4ff5 100644 --- a/internal/extproc/backendauth/api_key.go +++ b/internal/backendauth/api_key.go @@ -10,10 +10,8 @@ import ( "fmt" "strings" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - 
"github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) // apiKeyHandler implements [Handler] for api key authz. @@ -28,10 +26,7 @@ func newAPIKeyHandler(auth *filterapi.APIKeyAuth) (Handler, error) { // Do implements [Handler.Do]. // // Extracts the api key from the local file and set it as an authorization header. -func (a *apiKeyHandler) Do(_ context.Context, requestHeaders map[string]string, headerMut *extprocv3.HeaderMutation, _ *extprocv3.BodyMutation) error { +func (a *apiKeyHandler) Do(_ context.Context, requestHeaders map[string]string, _ []byte) ([]internalapi.Header, error) { requestHeaders["Authorization"] = fmt.Sprintf("Bearer %s", a.apiKey) - headerMut.SetHeaders = append(headerMut.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: "Authorization", RawValue: []byte(requestHeaders["Authorization"])}, - }) - return nil + return []internalapi.Header{{"Authorization", fmt.Sprintf("Bearer %s", a.apiKey)}}, nil } diff --git a/internal/extproc/backendauth/api_key_test.go b/internal/backendauth/api_key_test.go similarity index 51% rename from internal/extproc/backendauth/api_key_test.go rename to internal/backendauth/api_key_test.go index cf1bed2a1..023d47d26 100644 --- a/internal/extproc/backendauth/api_key_test.go +++ b/internal/backendauth/api_key_test.go @@ -8,8 +8,6 @@ package backendauth import ( "testing" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/internal/filterapi" @@ -30,28 +28,15 @@ func TestApiKeyHandler_Do(t *testing.T) { require.NoError(t, err) require.NotNil(t, handler) - requestHeaders := map[string]string{":method": "POST"} - headerMut := &extprocv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - {Header: &corev3.HeaderValue{ - Key: ":path", - Value: 
"/model/some-random-model/converse", - }}, - }, - } - bodyMut := &extprocv3.BodyMutation{ - Mutation: &extprocv3.BodyMutation_Body{ - Body: []byte(`{"messages": [{"role": "user", "content": [{"text": "Say this is a test!"}]}]}`), - }, - } - err = handler.Do(t.Context(), requestHeaders, headerMut, bodyMut) + requestHeaders := map[string]string{":method": "POST", ":path": "/model/some-random-model/converse"} + hdrs, err := handler.Do(t.Context(), requestHeaders, nil) require.NoError(t, err) bearerToken, ok := requestHeaders["Authorization"] require.True(t, ok) require.Equal(t, "Bearer test", bearerToken) - require.Len(t, headerMut.SetHeaders, 2) - require.Equal(t, "Authorization", headerMut.SetHeaders[1].Header.Key) - require.Equal(t, []byte("Bearer test"), headerMut.SetHeaders[1].Header.GetRawValue()) + require.Len(t, hdrs, 1) + require.Equal(t, "Authorization", hdrs[0][0]) + require.Equal(t, "Bearer test", hdrs[0][1]) } diff --git a/internal/extproc/backendauth/auth.go b/internal/backendauth/auth.go similarity index 80% rename from internal/extproc/backendauth/auth.go rename to internal/backendauth/auth.go index 90e0192c8..8c34ee32b 100644 --- a/internal/extproc/backendauth/auth.go +++ b/internal/backendauth/auth.go @@ -9,17 +9,17 @@ import ( "context" "errors" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) // Handler is the interface that deals with the backend auth for a specific backend. // // TODO: maybe this can be just "post-transformation" handler, as it is not really only about auth. type Handler interface { - // Do performs the backend auth, and make changes to the request headers and body mutations. 
- Do(ctx context.Context, requestHeaders map[string]string, headerMut *extprocv3.HeaderMutation, bodyMut *extprocv3.BodyMutation) error + // Do performs the backend auth, and make changes to the request headers passed in as `requestHeaders`. + // It also returns a list of headers that were added or modified as a slice of key-value pairs. + Do(ctx context.Context, requestHeaders map[string]string, mutatedBody []byte) ([]internalapi.Header, error) } // NewHandler returns a new implementation of [Handler] based on the configuration. diff --git a/internal/extproc/backendauth/auth_test.go b/internal/backendauth/auth_test.go similarity index 100% rename from internal/extproc/backendauth/auth_test.go rename to internal/backendauth/auth_test.go diff --git a/internal/extproc/backendauth/aws.go b/internal/backendauth/aws.go similarity index 78% rename from internal/extproc/backendauth/aws.go rename to internal/backendauth/aws.go index 02c119e60..06ace0c59 100644 --- a/internal/extproc/backendauth/aws.go +++ b/internal/backendauth/aws.go @@ -15,15 +15,13 @@ import ( "os" "strings" "time" - "unsafe" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/config" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) // awsHandler implements [Handler] for AWS Bedrock authz. @@ -84,26 +82,13 @@ func newAWSHandler(ctx context.Context, awsAuth *filterapi.AWSAuth) (Handler, er // // This assumes that during the transformation, the path is set in the header mutation as well as // the body in the body mutation. 
-func (a *awsHandler) Do(ctx context.Context, requestHeaders map[string]string, headerMut *extprocv3.HeaderMutation, bodyMut *extprocv3.BodyMutation) error { +func (a *awsHandler) Do(ctx context.Context, requestHeaders map[string]string, mutatedBody []byte) ([]internalapi.Header, error) { method := requestHeaders[":method"] - path := "" - if headerMut.SetHeaders != nil { - for _, h := range headerMut.SetHeaders { - if h.Header.Key == ":path" { - if len(h.Header.Value) > 0 { - path = h.Header.Value - } else { - rv := h.Header.RawValue - path = unsafe.String(&rv[0], len(rv)) - } - break - } - } - } + path := requestHeaders[":path"] var body []byte - if _body := bodyMut.GetBody(); len(_body) > 0 { - body = _body + if len(mutatedBody) > 0 { + body = mutatedBody } payloadHash := sha256.Sum256(body) @@ -111,7 +96,7 @@ func (a *awsHandler) Do(ctx context.Context, requestHeaders map[string]string, h fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com%s", a.region, path), bytes.NewReader(body)) if err != nil { - return fmt.Errorf("cannot create request: %w", err) + return nil, fmt.Errorf("cannot create request: %w", err) } // By setting the content length to -1, we can avoid the inclusion of the `Content-Length` header in the signature. 
// https://github.com/aws/aws-sdk-go-v2/blob/755839b2eebb246c7eec79b65404aee105196d5b/aws/signer/v4/v4.go#L427-L431 @@ -124,21 +109,21 @@ func (a *awsHandler) Do(ctx context.Context, requestHeaders map[string]string, h credentials, err := a.credentialsProvider.Retrieve(ctx) if err != nil { - return fmt.Errorf("cannot retrieve AWS credentials: %w", err) + return nil, fmt.Errorf("cannot retrieve AWS credentials: %w", err) } err = a.signer.SignHTTP(ctx, credentials, req, hex.EncodeToString(payloadHash[:]), "bedrock", a.region, time.Now()) if err != nil { - return fmt.Errorf("cannot sign request: %w", err) + return nil, fmt.Errorf("cannot sign request: %w", err) } + var headers []internalapi.Header for key, hdr := range req.Header { if key == "Authorization" || strings.HasPrefix(key, "X-Amz-") { - headerMut.SetHeaders = append(headerMut.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: key, RawValue: []byte(hdr[0])}, // Assume aws-go-sdk always returns a single value. 
- }) + headers = append(headers, internalapi.Header{key, hdr[0]}) + requestHeaders[key] = hdr[0] } } - return nil + return headers, nil } diff --git a/internal/extproc/backendauth/aws_test.go b/internal/backendauth/aws_test.go similarity index 71% rename from internal/extproc/backendauth/aws_test.go rename to internal/backendauth/aws_test.go index fa90db595..a08855235 100644 --- a/internal/extproc/backendauth/aws_test.go +++ b/internal/backendauth/aws_test.go @@ -9,39 +9,18 @@ import ( "sync" "testing" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) -// Test helper to extract headers from HeaderMutation -func extractHeaders(headerMut *extprocv3.HeaderMutation) map[string]string { - headers := map[string]string{} - for _, h := range headerMut.SetHeaders { - value := h.Header.Value - if value == "" && len(h.Header.RawValue) > 0 { - value = string(h.Header.RawValue) - } - headers[h.Header.Key] = value - } - return headers -} - -// Test helper to create a test request -func createTestRequest(method, path string, body []byte) (map[string]string, *extprocv3.HeaderMutation, *extprocv3.BodyMutation) { - requestHeaders := map[string]string{":method": method} - headerMut := &extprocv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - {Header: &corev3.HeaderValue{Key: ":path", Value: path}}, - }, - } - bodyMut := &extprocv3.BodyMutation{} - if len(body) > 0 { - bodyMut.Mutation = &extprocv3.BodyMutation_Body{Body: body} +func stringPairsToMap(pairs []internalapi.Header) map[string]string { + result := make(map[string]string) + for _, h := range pairs { + result[h.Key()] = h.Value() } - return requestHeaders, headerMut, bodyMut + return result } func TestNewAWSHandler(t *testing.T) { @@ -111,8 +90,7 @@ func 
TestNewAWSHandler(t *testing.T) { require.NotNil(t, handler) // But calling Do() should fail when no credentials are available - reqHeaders, headerMut, bodyMut := createTestRequest("POST", "/model/test/converse", []byte(`{"test": "data"}`)) - err = handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) + _, err = handler.Do(t.Context(), map[string]string{":method": "POST", ":path": "/model/test/converse"}, []byte(`{"test": "data"}`)) require.Error(t, err) require.Contains(t, err.Error(), "cannot retrieve AWS credentials") }) @@ -140,15 +118,10 @@ func TestAWSHandler_Do(t *testing.T) { for range 100 { go func() { defer wg.Done() - reqHeaders, headerMut, bodyMut := createTestRequest( - "POST", - "/model/some-random-model/converse", - []byte(`{"messages": [{"role": "user", "content": [{"text": "Say this is a test!"}]}]}`), - ) - err := handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) + hdrs, err := handler.Do(t.Context(), map[string]string{":method": "POST", ":path": "/model/some-random-model/converse"}, []byte(`{"messages": [{"role": "user", "content": [{"text": "Say this is a test!"}]}]}`)) require.NoError(t, err) - headers := extractHeaders(headerMut) + headers := stringPairsToMap(hdrs) require.Contains(t, headers, "X-Amz-Date") require.Contains(t, headers, "Authorization") }() @@ -167,16 +140,12 @@ func TestAWSHandler_Do(t *testing.T) { }) require.NoError(t, err) - reqHeaders, headerMut, bodyMut := createTestRequest( - "POST", - "/model/amazon.titan-text-express-v1/invoke", - []byte(`{"inputText": "Hello from default chain"}`), - ) - - err = handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) + hdrs, err := handler.Do(t.Context(), map[string]string{ + ":method": "POST", ":path": "/model/amazon.titan-text-express-v1/invoke", + }, []byte(`{"inputText": "Hello from default chain"}`)) require.NoError(t, err) - headers := extractHeaders(headerMut) + headers := stringPairsToMap(hdrs) require.Contains(t, headers, "X-Amz-Date") require.Contains(t, headers, 
"Authorization") // Verify the authorization header contains the access key ID @@ -194,16 +163,12 @@ func TestAWSHandler_Do(t *testing.T) { }) require.NoError(t, err) - reqHeaders, headerMut, bodyMut := createTestRequest( - "POST", - "/model/anthropic.claude-v2/converse", - []byte(`{"messages": []}`), - ) - - err = handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) + hdrs, err := handler.Do(t.Context(), map[string]string{ + ":method": "POST", ":path": "/model/anthropic.claude-v2/converse", + }, []byte(`{"inputText": "Hello from default chain"}`)) require.NoError(t, err) - headers := extractHeaders(headerMut) + headers := stringPairsToMap(hdrs) require.Contains(t, headers, "X-Amz-Date") require.Contains(t, headers, "Authorization") require.Contains(t, headers, "X-Amz-Security-Token") @@ -220,15 +185,12 @@ func TestAWSHandler_Do(t *testing.T) { methods := []string{"POST", "GET", "PUT"} for _, method := range methods { - reqHeaders, headerMut, bodyMut := createTestRequest( - method, - "/model/test-model/invoke", - []byte(`{"test": "data"}`), - ) - err := handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) - require.NoError(t, err, "Failed for method: %s", method) - - headers := extractHeaders(headerMut) + hdrs, err := handler.Do(t.Context(), map[string]string{ + ":method": method, ":path": "/model/test-model/invoke", + }, []byte(`{"test": "data"}`)) + require.NoError(t, err) + + headers := stringPairsToMap(hdrs) require.Contains(t, headers, "Authorization", "Missing Authorization for method: %s", method) } }) @@ -241,11 +203,12 @@ func TestAWSHandler_Do(t *testing.T) { }) require.NoError(t, err) - reqHeaders, headerMut, bodyMut := createTestRequest("GET", "/models", nil) - err = handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) + hdrs, err := handler.Do(t.Context(), map[string]string{ + ":method": "GET", ":path": "/model/test-model/invoke", + }, nil) require.NoError(t, err) - headers := extractHeaders(headerMut) + headers := stringPairsToMap(hdrs) 
require.Contains(t, headers, "Authorization") require.Contains(t, headers, "X-Amz-Date") }) @@ -261,12 +224,9 @@ func TestAWSHandler_Do(t *testing.T) { }) require.NoError(t, err) - reqHeaders, headerMut, bodyMut := createTestRequest( - "POST", - "/model/test/converse", - []byte(`{"test": "data"}`), - ) - err = handler.Do(t.Context(), reqHeaders, headerMut, bodyMut) + _, err = handler.Do(t.Context(), map[string]string{ + ":method": "POST", ":path": "/model/test/converse", + }, nil) require.NoError(t, err, "Failed for region: %s", region) } }) diff --git a/internal/extproc/backendauth/azure.go b/internal/backendauth/azure.go similarity index 64% rename from internal/extproc/backendauth/azure.go rename to internal/backendauth/azure.go index 0e5e3e480..25c2c5a73 100644 --- a/internal/extproc/backendauth/azure.go +++ b/internal/backendauth/azure.go @@ -10,10 +10,8 @@ import ( "fmt" "strings" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) type azureHandler struct { @@ -27,10 +25,7 @@ func newAzureHandler(auth *filterapi.AzureAuth) (Handler, error) { // Do implements [Handler.Do]. // // Extracts the azure access token from the local file and set it as an authorization header. 
-func (a *azureHandler) Do(_ context.Context, requestHeaders map[string]string, headerMut *extprocv3.HeaderMutation, _ *extprocv3.BodyMutation) error { +func (a *azureHandler) Do(_ context.Context, requestHeaders map[string]string, _ []byte) ([]internalapi.Header, error) { requestHeaders["Authorization"] = fmt.Sprintf("Bearer %s", a.azureAccessToken) - headerMut.SetHeaders = append(headerMut.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: "Authorization", RawValue: []byte(requestHeaders["Authorization"])}, - }) - return nil + return []internalapi.Header{{"Authorization", fmt.Sprintf("Bearer %s", a.azureAccessToken)}}, nil } diff --git a/internal/extproc/backendauth/azure_test.go b/internal/backendauth/azure_test.go similarity index 52% rename from internal/extproc/backendauth/azure_test.go rename to internal/backendauth/azure_test.go index baec2e6e6..4ea2a2634 100644 --- a/internal/extproc/backendauth/azure_test.go +++ b/internal/backendauth/azure_test.go @@ -8,8 +8,6 @@ package backendauth import ( "testing" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/internal/filterapi" @@ -30,31 +28,15 @@ func TestNewAzureHandler_Do(t *testing.T) { require.NoError(t, err) require.NotNil(t, handler) - requestHeaders := map[string]string{":method": "POST"} - headerMut := &extprocv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - { - Header: &corev3.HeaderValue{ - Key: ":path", - Value: "/model/some-random-model/chat/completion", - }, - }, - }, - } - bodyMut := &extprocv3.BodyMutation{ - Mutation: &extprocv3.BodyMutation_Body{ - Body: []byte(`{"messages": [{"role": "user", "content": [{"text": "Say this is a test!"}]}]}`), - }, - } - - err = handler.Do(t.Context(), requestHeaders, headerMut, bodyMut) + requestHeaders := map[string]string{":method": "POST", ":path": 
"/model/some-random-model/chat/completion"} + headers, err := handler.Do(t.Context(), requestHeaders, []byte(`{"messages": [{"role": "user", "content": [{"text": "Say this is a test!"}]}]}`)) require.NoError(t, err) bearerToken, ok := requestHeaders["Authorization"] require.True(t, ok) require.Equal(t, "Bearer some-access-token", bearerToken) - require.Len(t, headerMut.SetHeaders, 2) - require.Equal(t, "Authorization", headerMut.SetHeaders[1].Header.Key) - require.Equal(t, []byte("Bearer some-access-token"), headerMut.SetHeaders[1].Header.GetRawValue()) + require.Len(t, headers, 1) + require.Equal(t, "Authorization", headers[0][0]) + require.Equal(t, "Bearer some-access-token", headers[0][1]) } diff --git a/internal/extproc/backendauth/azureapikey.go b/internal/backendauth/azureapikey.go similarity index 65% rename from internal/extproc/backendauth/azureapikey.go rename to internal/backendauth/azureapikey.go index 7bad3bd7c..9eed9181b 100644 --- a/internal/extproc/backendauth/azureapikey.go +++ b/internal/backendauth/azureapikey.go @@ -10,10 +10,8 @@ import ( "fmt" "strings" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) type azureAPIKeyHandler struct { @@ -29,13 +27,7 @@ func newAzureAPIKeyHandler(auth *filterapi.AzureAPIKeyAuth) (Handler, error) { // Do sets the api-key header for Azure OpenAI authentication. // Azure OpenAI uses "api-key" header instead of "Authorization: Bearer". 
-func (a *azureAPIKeyHandler) Do(_ context.Context, requestHeaders map[string]string, headerMut *extprocv3.HeaderMutation, _ *extprocv3.BodyMutation) error { +func (a *azureAPIKeyHandler) Do(_ context.Context, requestHeaders map[string]string, _ []byte) ([]internalapi.Header, error) { requestHeaders["api-key"] = a.apiKey - headerMut.SetHeaders = append(headerMut.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{ - Key: "api-key", - RawValue: []byte(a.apiKey), - }, - }) - return nil + return []internalapi.Header{{"api-key", a.apiKey}}, nil } diff --git a/internal/extproc/backendauth/azureapikey_test.go b/internal/backendauth/azureapikey_test.go similarity index 71% rename from internal/extproc/backendauth/azureapikey_test.go rename to internal/backendauth/azureapikey_test.go index a86841a00..4e7f4d4a4 100644 --- a/internal/extproc/backendauth/azureapikey_test.go +++ b/internal/backendauth/azureapikey_test.go @@ -6,10 +6,8 @@ package backendauth import ( - "context" "testing" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/internal/filterapi" @@ -21,18 +19,17 @@ func TestAzureAPIKeyHandler(t *testing.T) { require.NoError(t, err) headers := make(map[string]string) - headerMut := &extprocv3.HeaderMutation{} - err = handler.Do(context.Background(), headers, headerMut, nil) + hdrs, err := handler.Do(t.Context(), headers, nil) require.NoError(t, err) // Verify header in map require.Equal(t, "test-azure-key", headers["api-key"]) // Verify header in mutation - require.Len(t, headerMut.SetHeaders, 1) - require.Equal(t, "api-key", headerMut.SetHeaders[0].Header.Key) - require.Equal(t, "test-azure-key", string(headerMut.SetHeaders[0].Header.RawValue)) + require.Len(t, hdrs, 1) + require.Equal(t, "api-key", hdrs[0][0]) + require.Equal(t, "test-azure-key", hdrs[0][1]) }) t.Run("trims whitespace", func(t *testing.T) { @@ -40,12 +37,13 @@ func 
TestAzureAPIKeyHandler(t *testing.T) { require.NoError(t, err) headers := make(map[string]string) - headerMut := &extprocv3.HeaderMutation{} - - err = handler.Do(context.Background(), headers, headerMut, nil) + hdrs, err := handler.Do(t.Context(), headers, nil) require.NoError(t, err) require.Equal(t, "key-with-spaces", headers["api-key"]) + require.Len(t, hdrs, 1) + require.Equal(t, "api-key", hdrs[0][0]) + require.Equal(t, "key-with-spaces", hdrs[0][1]) }) t.Run("requires non-empty key", func(t *testing.T) { diff --git a/internal/extproc/backendauth/gcp.go b/internal/backendauth/gcp.go similarity index 58% rename from internal/extproc/backendauth/gcp.go rename to internal/backendauth/gcp.go index 7259f99d3..857b99b36 100644 --- a/internal/extproc/backendauth/gcp.go +++ b/internal/backendauth/gcp.go @@ -9,10 +9,8 @@ import ( "context" "fmt" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/internalapi" ) type gcpHandler struct { @@ -45,45 +43,21 @@ func newGCPHandler(gcpAuth *filterapi.GCPAuth) (Handler, error) { // // The ":path" header is expected to contain the API-specific suffix, which is injected by translator.requestBody. // The suffix is combined with the generated prefix to form the complete path for the GCP API call. -func (g *gcpHandler) Do(_ context.Context, _ map[string]string, headerMut *extprocv3.HeaderMutation, _ *extprocv3.BodyMutation) error { - var pathHeaderFound bool - +func (g *gcpHandler) Do(_ context.Context, requestHeaders map[string]string, _ []byte) ([]internalapi.Header, error) { // Build the GCP URL prefix using the configured region and project name. 
prefixPath := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s", g.region, g.projectName, g.region) // Find and update the ":path" header by prepending the prefix. - for _, hdr := range headerMut.SetHeaders { - if hdr.Header != nil && hdr.Header.Key == ":path" { - pathHeaderFound = true - // Update the string value if present. - if len(hdr.Header.Value) > 0 { - suffixPath := hdr.Header.Value - hdr.Header.Value = fmt.Sprintf("%s/%s", prefixPath, suffixPath) - } - // Update the raw byte value if present. - if len(hdr.Header.RawValue) > 0 { - suffixPath := string(hdr.Header.RawValue) - path := fmt.Sprintf("%s/%s", prefixPath, suffixPath) - hdr.Header.RawValue = []byte(path) - } - break - } - } + path := requestHeaders[":path"] + // Update the raw byte value if present. + newPath := fmt.Sprintf("%s/%s", prefixPath, path) - if !pathHeaderFound { - return fmt.Errorf("missing ':path' header in the request") + if path == "" { + return nil, fmt.Errorf("missing ':path' header in the request") } // Add the Authorization header with the GCP access token. - headerMut.SetHeaders = append( - headerMut.SetHeaders, - &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{ - Key: "Authorization", - RawValue: fmt.Appendf(nil, "Bearer %s", g.gcpAccessToken), - }, - }, - ) - - return nil + requestHeaders[":path"] = newPath + requestHeaders["Authorization"] = fmt.Sprintf("Bearer %s", g.gcpAccessToken) + return []internalapi.Header{{":path", newPath}, {"Authorization", fmt.Sprintf("Bearer %s", g.gcpAccessToken)}}, nil } diff --git a/internal/backendauth/gcp_test.go b/internal/backendauth/gcp_test.go new file mode 100644 index 000000000..13c83503d --- /dev/null +++ b/internal/backendauth/gcp_test.go @@ -0,0 +1,116 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package backendauth + +import ( + "context" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/filterapi" +) + +func TestNewGCPHandler(t *testing.T) { + testCases := []struct { + name string + gcpAuth *filterapi.GCPAuth + wantHandler *gcpHandler + wantErrorMsg string + }{ + { + name: "valid config", + gcpAuth: &filterapi.GCPAuth{ + AccessToken: "test-token", + Region: "us-central1", + ProjectName: "test-project", + }, + wantHandler: &gcpHandler{ + gcpAccessToken: "test-token", + region: "us-central1", + projectName: "test-project", + }, + }, + { + name: "missing auth token", + gcpAuth: &filterapi.GCPAuth{ + AccessToken: "", + Region: "us-central1", + ProjectName: "test-project", + }, + wantHandler: nil, + wantErrorMsg: "GCP access token cannot be empty", + }, + { + name: "nil config", + gcpAuth: nil, + wantHandler: nil, + wantErrorMsg: "GCP auth configuration cannot be nil", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handler, err := newGCPHandler(tc.gcpAuth) + if tc.wantErrorMsg != "" { + require.ErrorContains(t, err, tc.wantErrorMsg) + } else { + require.NoError(t, err) + require.NotNil(t, handler) + + if d := cmp.Diff(tc.wantHandler, handler, cmp.AllowUnexported(gcpHandler{})); d != "" { + t.Errorf("Handler mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +func TestGCPHandler_Do(t *testing.T) { + handler := &gcpHandler{ + gcpAccessToken: "test-token", + region: "us-central1", + projectName: "test-project", + } + testCases := []struct { + name string + handler *gcpHandler + requestHeaders map[string]string + wantPathValue string + wantPathRawValue []byte + }{ + { + name: "basic headers update", + handler: handler, + requestHeaders: map[string]string{ + ":path": "publishers/google/models/gemini-pro:generateContent", + }, + wantPathValue: 
"https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + hdrs, err := tc.handler.Do(ctx, tc.requestHeaders, nil) + require.NoError(t, err) + + expectedAuthHeader := fmt.Sprintf("Bearer %s", tc.handler.gcpAccessToken) + + hdrsMap := stringPairsToMap(hdrs) + authValue, ok := hdrsMap["Authorization"] + require.True(t, ok, "Authorization header not found in returned headers") + require.Equal(t, expectedAuthHeader, authValue, "Authorization header value mismatch") + + pathValue, ok := hdrsMap[":path"] + require.True(t, ok, ":path header not found in returned headers") + require.Equal(t, tc.wantPathValue, pathValue, ":path header value mismatch") + }) + } +} diff --git a/internal/extproc/backendauth/aws_bench_test.go b/internal/extproc/backendauth/aws_bench_test.go deleted file mode 100644 index 64a501530..000000000 --- a/internal/extproc/backendauth/aws_bench_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. - -package backendauth - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/envoyproxy/ai-gateway/internal/filterapi" -) - -// BenchmarkAWSHandler_Do benchmarks the Do method with different credential sources. 
-// Run with: go test -bench=BenchmarkAWSHandler_Do -benchmem -tags=benchmark ./internal/extproc/backendauth/ -func BenchmarkAWSHandler_Do(b *testing.B) { - b.Run("file_credentials", func(b *testing.B) { - awsFileBody := "[default]\naws_access_key_id=AKIAIOSFODNN7EXAMPLE\naws_secret_access_key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY\n" //nolint:gosec - handler, err := newAWSHandler(b.Context(), &filterapi.AWSAuth{ - CredentialFileLiteral: awsFileBody, - Region: "us-east-1", - }) - require.NoError(b, err) - - reqHeaders, headerMut, bodyMut := createTestRequest( - "POST", - "/model/anthropic.claude-v2/converse", - []byte(`{"messages": [{"role": "user", "content": [{"text": "Hello"}]}]}`), - ) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := handler.Do(b.Context(), reqHeaders, headerMut, bodyMut) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run("default_chain_env_credentials", func(b *testing.B) { - b.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") - b.Setenv("AWS_SECRET_ACCESS_KEY", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY") - - handler, err := newAWSHandler(b.Context(), &filterapi.AWSAuth{ - Region: "us-east-1", - }) - require.NoError(b, err) - - reqHeaders, headerMut, bodyMut := createTestRequest( - "POST", - "/model/anthropic.claude-v2/converse", - []byte(`{"messages": [{"role": "user", "content": [{"text": "Hello"}]}]}`), - ) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := handler.Do(b.Context(), reqHeaders, headerMut, bodyMut) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run("concurrent", func(b *testing.B) { - awsFileBody := "[default]\naws_access_key_id=AKIAIOSFODNN7EXAMPLE\naws_secret_access_key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY\n" //nolint:gosec - handler, err := newAWSHandler(b.Context(), &filterapi.AWSAuth{ - CredentialFileLiteral: awsFileBody, - Region: "us-east-1", - }) - require.NoError(b, err) - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - reqHeaders, headerMut, bodyMut := createTestRequest( - 
"POST", - "/model/anthropic.claude-v2/converse", - []byte(`{"messages": [{"role": "user", "content": [{"text": "Hello"}]}]}`), - ) - err := handler.Do(b.Context(), reqHeaders, headerMut, bodyMut) - if err != nil { - b.Fatal(err) - } - } - }) - }) - - b.Run("just_credential_retrieve", func(b *testing.B) { - awsFileBody := "[default]\naws_access_key_id=AKIAIOSFODNN7EXAMPLE\naws_secret_access_key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY\n" //nolint:gosec - handler, err := newAWSHandler(b.Context(), &filterapi.AWSAuth{ - CredentialFileLiteral: awsFileBody, - Region: "us-east-1", - }) - require.NoError(b, err) - - awsH := handler.(*awsHandler) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := awsH.credentialsProvider.Retrieve(b.Context()) - if err != nil { - b.Fatal(err) - } - } - }) -} diff --git a/internal/extproc/backendauth/gcp_test.go b/internal/extproc/backendauth/gcp_test.go deleted file mode 100644 index b3e85e3e2..000000000 --- a/internal/extproc/backendauth/gcp_test.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. 
- -package backendauth - -import ( - "bytes" - "context" - "fmt" - "testing" - - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/require" - - "github.com/envoyproxy/ai-gateway/internal/filterapi" -) - -func TestNewGCPHandler(t *testing.T) { - testCases := []struct { - name string - gcpAuth *filterapi.GCPAuth - wantHandler *gcpHandler - wantErrorMsg string - }{ - { - name: "valid config", - gcpAuth: &filterapi.GCPAuth{ - AccessToken: "test-token", - Region: "us-central1", - ProjectName: "test-project", - }, - wantHandler: &gcpHandler{ - gcpAccessToken: "test-token", - region: "us-central1", - projectName: "test-project", - }, - }, - { - name: "missing auth token", - gcpAuth: &filterapi.GCPAuth{ - AccessToken: "", - Region: "us-central1", - ProjectName: "test-project", - }, - wantHandler: nil, - wantErrorMsg: "GCP access token cannot be empty", - }, - { - name: "nil config", - gcpAuth: nil, - wantHandler: nil, - wantErrorMsg: "GCP auth configuration cannot be nil", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - handler, err := newGCPHandler(tc.gcpAuth) - if tc.wantErrorMsg != "" { - require.ErrorContains(t, err, tc.wantErrorMsg) - } else { - require.NoError(t, err) - require.NotNil(t, handler) - - if d := cmp.Diff(tc.wantHandler, handler, cmp.AllowUnexported(gcpHandler{})); d != "" { - t.Errorf("Handler mismatch (-want +got):\n%s", d) - } - } - }) - } -} - -func TestGCPHandler_Do(t *testing.T) { - handler := &gcpHandler{ - gcpAccessToken: "test-token", - region: "us-central1", - projectName: "test-project", - } - testCases := []struct { - name string - handler *gcpHandler - requestHeaders map[string]string - headerMut *extprocv3.HeaderMutation - bodyMut *extprocv3.BodyMutation - wantPathValue string - wantPathRawValue []byte - wantErrorMsg string - }{ - { - name: "basic 
headers update with string value", - handler: handler, - headerMut: &extprocv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - { - Header: &corev3.HeaderValue{ - Key: ":path", - Value: "publishers/google/models/gemini-pro:generateContent", - }, - }, - }, - }, - bodyMut: &extprocv3.BodyMutation{}, - wantPathValue: "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent", - }, - { - name: "basic headers update with raw value", - handler: handler, - headerMut: &extprocv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - { - Header: &corev3.HeaderValue{ - Key: ":path", - RawValue: []byte("publishers/google/models/gemini-pro:generateContent"), - }, - }, - }, - }, - bodyMut: &extprocv3.BodyMutation{}, - wantPathRawValue: []byte("https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"), - }, - { - name: "no path header", - handler: handler, - headerMut: &extprocv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - { - Header: &corev3.HeaderValue{ - Key: "Content-Type", - Value: "application/json", - }, - }, - }, - }, - bodyMut: &extprocv3.BodyMutation{}, - wantErrorMsg: "missing ':path' header in the request", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - err := tc.handler.Do(ctx, nil, tc.headerMut, tc.bodyMut) - - if tc.wantErrorMsg != "" { - require.ErrorContains(t, err, tc.wantErrorMsg, "Expected error message not found") - } else { - require.NoError(t, err) - - // Check Authorization header. - authHeaderFound := false - expectedAuthHeader := fmt.Sprintf("Bearer %s", tc.handler.gcpAccessToken) - - // Check path header if expected. 
- pathHeaderUpdated := false - - for _, header := range tc.headerMut.SetHeaders { - if header.Header.Key == "Authorization" { - authHeaderFound = true - require.Equal(t, []byte(expectedAuthHeader), header.Header.RawValue) - } - - if header.Header.Key == ":path" { - pathHeaderUpdated = true - if len(tc.wantPathValue) > 0 { - require.Equal(t, tc.wantPathValue, header.Header.Value) - } - if len(tc.wantPathRawValue) > 0 { - require.True(t, bytes.Equal(tc.wantPathRawValue, header.Header.RawValue)) - } - } - } - - // Authorization header should always be added. - require.True(t, authHeaderFound, "Authorization header not found") - - // Only check path header if we had expectations for it. - if len(tc.wantPathValue) > 0 || len(tc.wantPathRawValue) > 0 { - require.True(t, pathHeaderUpdated, "Path header not updated as expected") - } - } - }) - } -} diff --git a/internal/extproc/chatcompletion_processor.go b/internal/extproc/chatcompletion_processor.go index db36fd536..8ea39adda 100644 --- a/internal/extproc/chatcompletion_processor.go +++ b/internal/extproc/chatcompletion_processor.go @@ -20,10 +20,10 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" "github.com/envoyproxy/ai-gateway/internal/metrics" @@ -255,9 +255,16 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx contex // Apply header mutations from the route and also restore original headers on retry. 
if h := c.headerMutator; h != nil { - if hm := c.headerMutator.Mutate(c.requestHeaders, c.onRetry); hm != nil { - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) - headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) + sets, removes := c.headerMutator.Mutate(c.requestHeaders, c.onRetry) + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removes...) + for _, hdr := range sets { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{ + Key: hdr.Key(), + RawValue: []byte(hdr.Value()), + }, + }) } } @@ -266,9 +273,17 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx contex } if h := c.handler; h != nil { - if err = h.Do(ctx, c.requestHeaders, headerMutation, bodyMutation); err != nil { + var hdrs []internalapi.Header + hdrs, err = h.Do(ctx, c.requestHeaders, bodyMutation.GetBody()) + if err != nil { return nil, fmt.Errorf("failed to do auth request: %w", err) } + for _, h := range hdrs { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{Key: h.Key(), RawValue: []byte(h.Value())}, + }) + } } var dm *structpb.Struct diff --git a/internal/extproc/chatcompletion_processor_test.go b/internal/extproc/chatcompletion_processor_test.go index 42039c62f..2c91415e8 100644 --- a/internal/extproc/chatcompletion_processor_test.go +++ b/internal/extproc/chatcompletion_processor_test.go @@ -20,9 +20,9 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" 
+ "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" "github.com/envoyproxy/ai-gateway/internal/metrics" diff --git a/internal/extproc/completions_processor.go b/internal/extproc/completions_processor.go index eaa0f7d42..2cd1cc494 100644 --- a/internal/extproc/completions_processor.go +++ b/internal/extproc/completions_processor.go @@ -19,10 +19,10 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" @@ -237,9 +237,16 @@ func (c *completionsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.C // Apply header mutations from the route and also restore original headers on retry. if h := c.headerMutator; h != nil { - if hm := c.headerMutator.Mutate(c.requestHeaders, c.onRetry); hm != nil { - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) - headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) + sets, removes := c.headerMutator.Mutate(c.requestHeaders, c.onRetry) + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removes...) 
+ for _, hdr := range sets { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{ + Key: hdr.Key(), + RawValue: []byte(hdr.Value()), + }, + }) } } @@ -247,9 +254,17 @@ func (c *completionsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.C c.requestHeaders[h.Header.Key] = string(h.Header.RawValue) } if h := c.handler; h != nil { - if err = h.Do(ctx, c.requestHeaders, headerMutation, bodyMutation); err != nil { + var hdrs []internalapi.Header + hdrs, err = h.Do(ctx, c.requestHeaders, bodyMutation.GetBody()) + if err != nil { return nil, fmt.Errorf("failed to do auth request: %w", err) } + for _, h := range hdrs { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{Key: h.Key(), RawValue: []byte(h.Value())}, + }) + } } var dm *structpb.Struct diff --git a/internal/extproc/completions_processor_test.go b/internal/extproc/completions_processor_test.go index 874fecc11..797a9728f 100644 --- a/internal/extproc/completions_processor_test.go +++ b/internal/extproc/completions_processor_test.go @@ -19,9 +19,9 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" "github.com/envoyproxy/ai-gateway/internal/metrics" diff --git a/internal/extproc/embeddings_processor.go b/internal/extproc/embeddings_processor.go index 9814083c5..ea75cf21a 100644 --- 
a/internal/extproc/embeddings_processor.go +++ b/internal/extproc/embeddings_processor.go @@ -18,10 +18,10 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" @@ -211,9 +211,16 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co // Apply header mutations from the route and also restore original headers on retry. if h := e.headerMutator; h != nil { - if hm := e.headerMutator.Mutate(e.requestHeaders, e.onRetry); hm != nil { - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) - headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) + sets, removes := e.headerMutator.Mutate(e.requestHeaders, e.onRetry) + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removes...) 
+ for _, hdr := range sets { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{ + Key: hdr.Key(), + RawValue: []byte(hdr.Value()), + }, + }) } } @@ -221,9 +228,17 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Co e.requestHeaders[h.Header.Key] = string(h.Header.RawValue) } if h := e.handler; h != nil { - if err = h.Do(ctx, e.requestHeaders, headerMutation, bodyMutation); err != nil { + var hdrs []internalapi.Header + hdrs, err = h.Do(ctx, e.requestHeaders, bodyMutation.GetBody()) + if err != nil { return nil, fmt.Errorf("failed to do auth request: %w", err) } + for _, h := range hdrs { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{Key: h.Key(), RawValue: []byte(h.Value())}, + }) + } } var dm *structpb.Struct diff --git a/internal/extproc/embeddings_processor_test.go b/internal/extproc/embeddings_processor_test.go index 0fe3ea022..dd659369c 100644 --- a/internal/extproc/embeddings_processor_test.go +++ b/internal/extproc/embeddings_processor_test.go @@ -18,9 +18,9 @@ import ( "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" "github.com/envoyproxy/ai-gateway/internal/metrics" @@ -543,9 +543,11 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutat // Check that header mutations were applied. 
require.NotNil(t, commonRes.HeaderMutation) require.ElementsMatch(t, []string{"authorization", "x-api-key"}, commonRes.HeaderMutation.RemoveHeaders) - require.Len(t, commonRes.HeaderMutation.SetHeaders, 1) + require.Len(t, commonRes.HeaderMutation.SetHeaders, 2) require.Equal(t, "x-new-header", commonRes.HeaderMutation.SetHeaders[0].Header.Key) require.Equal(t, []byte("new-value"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue) + require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[1].Header.Key) + require.Equal(t, []byte("mock-auth-handler"), commonRes.HeaderMutation.SetHeaders[1].Header.RawValue) // Check that headers were modified in the request headers. require.Equal(t, "new-value", headers["x-new-header"]) @@ -589,7 +591,9 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutat // Check that no header mutations were applied. require.NotNil(t, commonRes.HeaderMutation) require.Empty(t, commonRes.HeaderMutation.RemoveHeaders) - require.Empty(t, commonRes.HeaderMutation.SetHeaders) + require.Len(t, commonRes.HeaderMutation.SetHeaders, 1) + require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key) + require.Equal(t, []byte("mock-auth-handler"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue) // Check that original headers remain unchanged. require.Equal(t, "bearer token123", headers["authorization"]) @@ -626,19 +630,6 @@ func TestEmbeddingsProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *tes // Verify header mutator was created. require.NotNil(t, p.headerMutator) - - // Test that the header mutator works correctly. - testHeaders := map[string]string{ - "x-sensitive": "secret", - "x-existing": "value", - } - mutation := p.headerMutator.Mutate(testHeaders, false) // onRetry = false. 
- - require.NotNil(t, mutation) - require.ElementsMatch(t, []string{"x-sensitive"}, mutation.RemoveHeaders) - require.Len(t, mutation.SetHeaders, 1) - require.Equal(t, "x-backend", mutation.SetHeaders[0].Header.Key) - require.Equal(t, []byte("backend-value"), mutation.SetHeaders[0].Header.RawValue) }) t.Run("header mutator with original headers", func(t *testing.T) { @@ -674,32 +665,5 @@ func TestEmbeddingsProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *tes HeaderMutation: headerMutations, }, nil, rp) require.NoError(t, err) - - // Verify header mutator was created with original headers. - require.NotNil(t, p.headerMutator) - - // Test retry scenario - original headers should be restored. - testHeaders := map[string]string{ - "x-existing": "previously-set-value", - } - mutation := p.headerMutator.Mutate(testHeaders, true) // onRetry = true. - - require.NotNil(t, mutation) - // RemoveHeaders should be empty because authorization doesn't exist in testHeaders. - require.Empty(t, mutation.RemoveHeaders) - - // Should restore x-custom header (not being removed and not already present). - var restoredHeader *corev3.HeaderValueOption - for _, h := range mutation.SetHeaders { - if h.Header.Key == "x-custom" { - restoredHeader = h - break - } - } - require.NotNil(t, restoredHeader) - require.Equal(t, []byte("original-value"), restoredHeader.Header.RawValue) - require.Equal(t, "original-value", testHeaders["x-custom"]) - // x-existing should be equal to existing-value from original headers. 
- require.Equal(t, "existing-value", testHeaders["x-existing"]) }) } diff --git a/internal/extproc/imagegeneration_processor.go b/internal/extproc/imagegeneration_processor.go index 41bc2a437..c023c76e6 100644 --- a/internal/extproc/imagegeneration_processor.go +++ b/internal/extproc/imagegeneration_processor.go @@ -19,10 +19,10 @@ import ( openaisdk "github.com/openai/openai-go/v2" "google.golang.org/protobuf/types/known/structpb" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" @@ -219,9 +219,16 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessRequestHeaders(ctx conte // Apply header mutations from the route and also restore original headers on retry. if h := i.headerMutator; h != nil { - if hm := i.headerMutator.Mutate(i.requestHeaders, i.onRetry); hm != nil { - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) - headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) + sets, removes := i.headerMutator.Mutate(i.requestHeaders, i.onRetry) + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removes...) 
+ for _, hdr := range sets { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{ + Key: hdr.Key(), + RawValue: []byte(hdr.Value()), + }, + }) } } @@ -230,9 +237,17 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessRequestHeaders(ctx conte } if h := i.handler; h != nil { - if err = h.Do(ctx, i.requestHeaders, headerMutation, bodyMutation); err != nil { + var hdrs []internalapi.Header + hdrs, err = h.Do(ctx, i.requestHeaders, bodyMutation.GetBody()) + if err != nil { return nil, fmt.Errorf("failed to do auth request: %w", err) } + for _, h := range hdrs { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{Key: h.Key(), RawValue: []byte(h.Value())}, + }) + } } var dm *structpb.Struct diff --git a/internal/extproc/messages_processor.go b/internal/extproc/messages_processor.go index 7c3911f1d..28e5e81d4 100644 --- a/internal/extproc/messages_processor.go +++ b/internal/extproc/messages_processor.go @@ -19,10 +19,10 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" @@ -199,9 +199,16 @@ func (c *messagesProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Cont // Apply header 
mutations from the route and also restore original headers on retry. if h := c.headerMutator; h != nil { - if hm := c.headerMutator.Mutate(c.requestHeaders, c.onRetry); hm != nil { - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) - headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) + sets, removes := c.headerMutator.Mutate(c.requestHeaders, c.onRetry) + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removes...) + for _, hdr := range sets { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{ + Key: hdr.Key(), + RawValue: []byte(hdr.Value()), + }, + }) } } @@ -209,9 +216,17 @@ func (c *messagesProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Cont c.requestHeaders[h.Header.Key] = string(h.Header.RawValue) } if h := c.handler; h != nil { - if err = h.Do(ctx, c.requestHeaders, headerMutation, bodyMutation); err != nil { + var hdrs []internalapi.Header + hdrs, err = h.Do(ctx, c.requestHeaders, bodyMutation.GetBody()) + if err != nil { return nil, fmt.Errorf("failed to do auth request: %w", err) } + for _, h := range hdrs { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{Key: h.Key(), RawValue: []byte(h.Value())}, + }) + } } var dm *structpb.Struct diff --git a/internal/extproc/messages_processor_test.go b/internal/extproc/messages_processor_test.go index 20ff8824d..e1c74edef 100644 --- a/internal/extproc/messages_processor_test.go +++ b/internal/extproc/messages_processor_test.go @@ -19,9 +19,9 @@ import ( "google.golang.org/protobuf/types/known/structpb" anthropicschema "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic" - 
"github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" @@ -656,9 +656,11 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutatio // Check that header mutations were applied. require.NotNil(t, commonRes.HeaderMutation) require.ElementsMatch(t, []string{"authorization", "x-api-key"}, commonRes.HeaderMutation.RemoveHeaders) - require.Len(t, commonRes.HeaderMutation.SetHeaders, 1) + require.Len(t, commonRes.HeaderMutation.SetHeaders, 2) require.Equal(t, "x-new-header", commonRes.HeaderMutation.SetHeaders[0].Header.Key) require.Equal(t, []byte("new-value"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue) + require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[1].Header.Key) + require.Equal(t, []byte("mock-auth-handler"), commonRes.HeaderMutation.SetHeaders[1].Header.RawValue) // Check that headers were modified in the request headers. require.Equal(t, "new-value", headers["x-new-header"]) @@ -717,7 +719,9 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutatio // Check that no header mutations were applied. require.NotNil(t, commonRes.HeaderMutation) require.Empty(t, commonRes.HeaderMutation.RemoveHeaders) - require.Empty(t, commonRes.HeaderMutation.SetHeaders) + require.Len(t, commonRes.HeaderMutation.SetHeaders, 1) + require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key) + require.Equal(t, []byte("mock-auth-handler"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue) // Check that original headers remain unchanged. 
require.Equal(t, "bearer token123", headers["authorization"]) @@ -766,19 +770,6 @@ func TestMessagesProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *testi // Verify header mutator was created. require.NotNil(t, p.headerMutator) - - // Test that the header mutator works correctly. - testHeaders := map[string]string{ - "x-sensitive": "current-secret", - "x-existing": "current-value", - } - mutation := p.headerMutator.Mutate(testHeaders, false) - - require.NotNil(t, mutation) - require.ElementsMatch(t, []string{"x-sensitive"}, mutation.RemoveHeaders) - require.Len(t, mutation.SetHeaders, 1) - require.Equal(t, "x-backend", mutation.SetHeaders[0].Header.Key) - require.Equal(t, []byte("backend-value"), mutation.SetHeaders[0].Header.RawValue) }) t.Run("header mutator with original headers", func(t *testing.T) { @@ -822,29 +813,5 @@ func TestMessagesProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *testi // Verify header mutator was created with original headers. require.NotNil(t, p.headerMutator) - - // Test retry scenario - original headers should be restored. - testHeaders := map[string]string{ - "x-existing": "previously-set-value", - } - mutation := p.headerMutator.Mutate(testHeaders, true) // onRetry = true. - - require.NotNil(t, mutation) - // RemoveHeaders should be empty because authorization doesn't exist in testHeaders. - require.Empty(t, mutation.RemoveHeaders) - - // Should restore x-custom header (not being removed and not already present). - var restoredHeader *corev3.HeaderValueOption - for _, h := range mutation.SetHeaders { - if h.Header.Key == "x-custom" { - restoredHeader = h - break - } - } - require.NotNil(t, restoredHeader) - require.Equal(t, []byte("original-value"), restoredHeader.Header.RawValue) - require.Equal(t, "original-value", testHeaders["x-custom"]) - // x-existing should be equal to existing-value from original headers. 
- require.Equal(t, "existing-value", testHeaders["x-existing"]) }) } diff --git a/internal/extproc/mocks_test.go b/internal/extproc/mocks_test.go index 1c4dda246..be0cd5b68 100644 --- a/internal/extproc/mocks_test.go +++ b/internal/extproc/mocks_test.go @@ -18,7 +18,7 @@ import ( "google.golang.org/grpc/metadata" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/internalapi" @@ -508,15 +508,15 @@ var _ metrics.CompletionMetrics = &mockCompletionMetrics{} type mockBackendAuthHandler struct{} // Do implements [backendauth.Handler.Do]. -func (m *mockBackendAuthHandler) Do(context.Context, map[string]string, *extprocv3.HeaderMutation, *extprocv3.BodyMutation) error { - return nil +func (m *mockBackendAuthHandler) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) { + return []internalapi.Header{{"foo", "mock-auth-handler"}}, nil } // mockBackendAuthHandlerError returns error on Do. type mockBackendAuthHandlerError struct{} -func (m *mockBackendAuthHandlerError) Do(context.Context, map[string]string, *extprocv3.HeaderMutation, *extprocv3.BodyMutation) error { - return io.EOF +func (m *mockBackendAuthHandlerError) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) { + return nil, io.EOF } // mockImageGenerationMetrics implements [metrics.ImageGenerationMetrics] for testing. 
diff --git a/internal/extproc/processor.go b/internal/extproc/processor.go index b37fc7d50..6318a75a8 100644 --- a/internal/extproc/processor.go +++ b/internal/extproc/processor.go @@ -13,7 +13,7 @@ import ( extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/google/cel-go/cel" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/filterapi" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) diff --git a/internal/extproc/rerank_processor.go b/internal/extproc/rerank_processor.go index 75eed0b32..6aa245f84 100644 --- a/internal/extproc/rerank_processor.go +++ b/internal/extproc/rerank_processor.go @@ -18,10 +18,10 @@ import ( "google.golang.org/protobuf/types/known/structpb" cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" @@ -193,9 +193,16 @@ func (r *rerankProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Contex // Apply header mutations from the route and also restore original headers on retry. if h := r.headerMutator; h != nil { - if hm := r.headerMutator.Mutate(r.requestHeaders, r.onRetry); hm != nil { - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, hm.RemoveHeaders...) - headerMutation.SetHeaders = append(headerMutation.SetHeaders, hm.SetHeaders...) 
+ sets, removes := r.headerMutator.Mutate(r.requestHeaders, r.onRetry) + headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removes...) + for _, hdr := range sets { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{ + Key: hdr.Key(), + RawValue: []byte(hdr.Value()), + }, + }) } } @@ -203,9 +210,17 @@ func (r *rerankProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Contex r.requestHeaders[h.Header.Key] = string(h.Header.RawValue) } if h := r.handler; h != nil { - if err = h.Do(ctx, r.requestHeaders, headerMutation, bodyMutation); err != nil { + var hdrs []internalapi.Header + hdrs, err = h.Do(ctx, r.requestHeaders, bodyMutation.GetBody()) + if err != nil { return nil, fmt.Errorf("failed to do auth request: %w", err) } + for _, h := range hdrs { + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + AppendAction: corev3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + Header: &corev3.HeaderValue{Key: h.Key(), RawValue: []byte(h.Value())}, + }) + } } var dm *structpb.Struct diff --git a/internal/extproc/rerank_processor_test.go b/internal/extproc/rerank_processor_test.go index 9e2f52403..b9f29f198 100644 --- a/internal/extproc/rerank_processor_test.go +++ b/internal/extproc/rerank_processor_test.go @@ -20,9 +20,9 @@ import ( "github.com/stretchr/testify/require" cohere "github.com/envoyproxy/ai-gateway/internal/apischema/cohere" - "github.com/envoyproxy/ai-gateway/internal/extproc/headermutator" "github.com/envoyproxy/ai-gateway/internal/extproc/translator" "github.com/envoyproxy/ai-gateway/internal/filterapi" + "github.com/envoyproxy/ai-gateway/internal/headermutator" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" "github.com/envoyproxy/ai-gateway/internal/metrics" diff --git a/internal/extproc/server.go 
b/internal/extproc/server.go index fd606d42a..650d4042f 100644 --- a/internal/extproc/server.go +++ b/internal/extproc/server.go @@ -26,7 +26,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/prototext" - "github.com/envoyproxy/ai-gateway/internal/extproc/backendauth" + "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" diff --git a/internal/extproc/headermutator/header_mutator.go b/internal/headermutator/header_mutator.go similarity index 80% rename from internal/extproc/headermutator/header_mutator.go rename to internal/headermutator/header_mutator.go index b2ad4999f..580ff0a19 100644 --- a/internal/extproc/headermutator/header_mutator.go +++ b/internal/headermutator/header_mutator.go @@ -8,9 +8,6 @@ package headermutator import ( "strings" - corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/internalapi" ) @@ -31,11 +28,10 @@ func NewHeaderMutator(headerMutations *filterapi.HTTPHeaderMutation, originalHea } // Mutate mutates the headers based on the header mutations and restores original headers if mutated previously. -func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extprocv3.HeaderMutation { +func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) (sets []internalapi.Header, removes []string) { skipRemove := h.headerMutations == nil || len(h.headerMutations.Remove) == 0 skipSet := h.headerMutations == nil || len(h.headerMutations.Set) == 0 - headerMutation := &extprocv3.HeaderMutation{} // Removes sensitive headers before sending to backend. 
removedHeadersSet := make(map[string]struct{}) if !skipRemove { @@ -47,7 +43,7 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc if _, ok := headers[key]; ok { // Do NOT delete from the local headers map so metrics can still read it. // Instead, always instruct Envoy to remove it before forwarding upstream. - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, key) + removes = append(removes, key) } } } @@ -62,9 +58,7 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc } setHeadersSet[key] = struct{}{} headers[key] = h.Value - headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: h.Name, RawValue: []byte(h.Value)}, - }) + sets = append(sets, internalapi.Header{key, h.Value}) } } @@ -83,9 +77,7 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc headers[key] = v if !isRemoved { setHeadersSet[key] = struct{}{} - headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: key, RawValue: []byte(v)}, - }) + sets = append(sets, internalapi.Header{key, v}) } } } @@ -103,17 +95,15 @@ func (h *HeaderMutator) Mutate(headers map[string]string, onRetry bool) *extproc originalValue, exists := h.originalHeaders[key] if !exists { delete(headers, key) - headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, key) + removes = append(removes, key) } else { // Restore original value. headers[key] = originalValue - headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: key, RawValue: []byte(originalValue)}, - }) + sets = append(sets, internalapi.Header{key, originalValue}) } } } - return headerMutation + return } // shouldIgnoreHeader returns true if the header key should be ignored for mutation. 
diff --git a/internal/extproc/headermutator/header_mutator_test.go b/internal/headermutator/header_mutator_test.go similarity index 84% rename from internal/extproc/headermutator/header_mutator_test.go rename to internal/headermutator/header_mutator_test.go index aac0473eb..9526749c2 100644 --- a/internal/extproc/headermutator/header_mutator_test.go +++ b/internal/headermutator/header_mutator_test.go @@ -26,13 +26,12 @@ func TestHeaderMutator_Mutate(t *testing.T) { Set: []filterapi.HTTPHeader{{Name: "x-new-header", Value: "newval"}}, } mutator := NewHeaderMutator(mutations, nil) - mutation := mutator.Mutate(headers, false) + sets, removes := mutator.Mutate(headers, false) - require.NotNil(t, mutation) - require.ElementsMatch(t, []string{"authorization", "x-api-key"}, mutation.RemoveHeaders) - require.Len(t, mutation.SetHeaders, 1) - require.Equal(t, "x-new-header", mutation.SetHeaders[0].Header.Key) - require.Equal(t, []byte("newval"), mutation.SetHeaders[0].Header.RawValue) + require.ElementsMatch(t, []string{"authorization", "x-api-key"}, removes) + require.Len(t, sets, 1) + require.Equal(t, "x-new-header", sets[0][0]) + require.Equal(t, "newval", sets[0][1]) // Sensitive headers remain locally for metrics, but will be stripped upstream by Envoy. 
require.Equal(t, "secret", headers["authorization"]) require.Equal(t, "key123", headers["x-api-key"]) @@ -65,14 +64,14 @@ func TestHeaderMutator_Mutate(t *testing.T) { Set: []filterapi.HTTPHeader{}, } mutator := NewHeaderMutator(mutations, originalHeaders) - mutation := mutator.Mutate(headers, true) + sets, removes := mutator.Mutate(headers, true) - require.NotNil(t, mutation) - require.ElementsMatch(t, []string{"authorization", "only-set-previously"}, mutation.RemoveHeaders) - require.Len(t, mutation.SetHeaders, 4) + require.ElementsMatch(t, []string{"authorization", "only-set-previously"}, removes) + require.Len(t, sets, 4) setHeadersMap := make(map[string]string) - for _, h := range mutation.SetHeaders { - setHeadersMap[h.Header.Key] = string(h.Header.RawValue) + for _, h := range sets { + key, value := h[0], h[1] + setHeadersMap[key] = value } require.Equal(t, "key123", setHeadersMap["x-api-key"]) require.Equal(t, "value", setHeadersMap["other"]) diff --git a/internal/internalapi/headers.go b/internal/internalapi/headers.go new file mode 100644 index 000000000..36878bd79 --- /dev/null +++ b/internal/internalapi/headers.go @@ -0,0 +1,19 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package internalapi + +// Header represents a single HTTP header as a key-value pair. +type Header [2]string + +// Key returns the header key. +func (h Header) Key() string { + return h[0] +} + +// Value returns the header value. 
+func (h Header) Value() string { + return h[1] +} From 66eaa8cc7efe67f3f6ed269742fe50ede5dd00d2 Mon Sep 17 00:00:00 2001 From: yxia216 Date: Thu, 6 Nov 2025 10:29:43 -0500 Subject: [PATCH 6/6] fix-tests Signed-off-by: yxia216 --- tests/extproc/testupstream_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/extproc/testupstream_test.go b/tests/extproc/testupstream_test.go index 75d8abc8a..b044254c8 100644 --- a/tests/extproc/testupstream_test.go +++ b/tests/extproc/testupstream_test.go @@ -563,7 +563,9 @@ data: {"choices":[{"index":0,"delta":{"content":" you","role":"assistant"}}],"ob data: {"choices":[{"index":0,"delta":{"content":" today","role":"assistant"}}],"object":"chat.completion.chunk"} -data: {"choices":[{"index":0,"delta":{"content":"?","role":"assistant"},"finish_reason":"stop"}],"object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":7,"total_tokens":17,"completion_tokens_details":{},"prompt_tokens_details":{}}} +data: {"choices":[{"index":0,"delta":{"content":"?","role":"assistant"},"finish_reason":"stop"}],"object":"chat.completion.chunk"} + +data: {"object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":7,"total_tokens":17,"completion_tokens_details":{},"prompt_tokens_details":{}}} data: [DONE] `,