Skip to content

Commit 312acd9

Browse files
committed
more
Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
1 parent e082155 commit 312acd9

17 files changed

+65
-51
lines changed

internal/extproc/chatcompletion_processor_test.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testin
206206
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
207207
require.NoError(t, err)
208208
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
209-
require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation)
209+
require.Empty(t, commonRes.HeaderMutation.SetHeaders)
210210
mm.RequireRequestNotCompleted(t)
211211
require.Nil(t, res.ModeOverride)
212212
})
@@ -221,7 +221,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testin
221221
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
222222
require.NoError(t, err)
223223
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
224-
require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation)
224+
require.Empty(t, commonRes.HeaderMutation)
225225
require.Equal(t, &extprocv3http.ProcessingMode{ResponseBodyMode: extprocv3http.ProcessingMode_STREAMED}, res.ModeOverride)
226226
})
227227
t.Run("error/streaming", func(t *testing.T) {
@@ -235,7 +235,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testin
235235
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
236236
require.NoError(t, err)
237237
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
238-
require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation)
238+
require.Empty(t, commonRes.HeaderMutation)
239239
require.Nil(t, res.ModeOverride)
240240
})
241241
}
@@ -333,11 +333,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
333333
res, err := p.ProcessResponseBody(t.Context(), inBody)
334334
require.NoError(t, err)
335335
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
336-
require.Equal(t, "error-body", commonRes.BodyMutation.GetBody())
336+
require.Equal(t, "error-body", string(commonRes.BodyMutation.GetBody()))
337+
require.Len(t, commonRes.HeaderMutation.SetHeaders, 1)
338+
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
339+
require.Equal(t, []byte("bar"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue)
337340
require.Len(t, commonRes.HeaderMutation.SetHeaders, 1)
338341
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
339342
require.Equal(t, []byte("bar"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue)
340-
require.Equal(t, expHeadMut, commonRes.HeaderMutation)
341343
mm.RequireRequestFailure(t)
342344
})
343345

@@ -472,7 +474,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing
472474
t.Run("ok", func(t *testing.T) {
473475
someBody := bodyFromModel(t, "some-model", tc.stream, nil)
474476
headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"}
475-
headerMut := []internalapi.Header{{"foo", "bar"}}
477+
headerMut := []internalapi.Header{{"a", "b"}}
476478
bodyMut := []byte("some body")
477479

478480
var expBody openai.ChatCompletionRequest
@@ -504,9 +506,11 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing
504506
require.NotNil(t, resp)
505507
commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response
506508
require.Equal(t, string(bodyMut), string(commonRes.BodyMutation.GetBody()))
507-
require.Len(t, commonRes.HeaderMutation.SetHeaders, 1)
508-
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
509-
require.Equal(t, []byte("bar"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue)
509+
require.Len(t, commonRes.HeaderMutation.SetHeaders, 2)
510+
require.Equal(t, "a", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
511+
require.Equal(t, []byte("b"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue)
512+
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[1].Header.Key)
513+
require.Equal(t, "mock-auth-handler", string(commonRes.HeaderMutation.SetHeaders[1].Header.RawValue))
510514

511515
mm.RequireRequestNotCompleted(t)
512516
// Verify models were set

internal/extproc/completions_processor_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,10 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) {
172172
require.True(t, ok)
173173
require.NotNil(t, re)
174174
require.NotNil(t, re.ResponseBody)
175-
require.Equal(t, mt.resErrorHeaderMutation, re.ResponseBody.GetResponse().GetHeaderMutation())
176-
require.Equal(t, mt.resErrorBodyMutation, re.ResponseBody.GetResponse().GetBodyMutation())
175+
require.Len(t, re.ResponseBody.GetResponse().GetHeaderMutation().SetHeaders, 1)
176+
require.Equal(t, "test", re.ResponseBody.GetResponse().GetHeaderMutation().SetHeaders[0].Header.Key)
177+
require.Equal(t, "error", string(re.ResponseBody.GetResponse().GetHeaderMutation().SetHeaders[0].Header.RawValue))
178+
require.Equal(t, "test error", string(re.ResponseBody.GetResponse().GetBodyMutation().GetBody()))
177179
})
178180

179181
t.Run("successful response with token usage", func(t *testing.T) {
@@ -204,8 +206,10 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) {
204206
require.True(t, ok)
205207
require.NotNil(t, re)
206208
require.NotNil(t, re.ResponseBody)
207-
require.Equal(t, mt.resHeaderMutation, re.ResponseBody.GetResponse().GetHeaderMutation())
208-
require.Equal(t, mt.resBodyMutation, re.ResponseBody.GetResponse().GetBodyMutation())
209+
require.Len(t, re.ResponseBody.GetResponse().GetHeaderMutation().SetHeaders, 1)
210+
require.Equal(t, "test", re.ResponseBody.GetResponse().GetHeaderMutation().SetHeaders[0].Header.Key)
211+
require.Equal(t, "success", string(re.ResponseBody.GetResponse().GetHeaderMutation().SetHeaders[0].Header.RawValue))
212+
require.Equal(t, "response body", string(re.ResponseBody.GetResponse().GetBodyMutation().GetBody()))
209213

210214
// Check that costs were accumulated
211215
require.Equal(t, uint32(10), p.costs.InputTokens)
@@ -564,7 +568,7 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseHeaders_Streaming(t
564568
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
565569
require.NoError(t, err)
566570
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
567-
require.Equal(t, mt.resHeaderMutation, commonRes.HeaderMutation)
571+
require.Empty(t, commonRes.HeaderMutation.SetHeaders)
568572
require.Nil(t, res.ModeOverride)
569573
})
570574

@@ -578,7 +582,7 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseHeaders_Streaming(t
578582
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
579583
require.NoError(t, err)
580584
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
581-
require.Equal(t, mt.resHeaderMutation, commonRes.HeaderMutation)
585+
require.Empty(t, commonRes.HeaderMutation.SetHeaders)
582586
require.Equal(t, &extprocv3http.ProcessingMode{ResponseBodyMode: extprocv3http.ProcessingMode_STREAMED}, res.ModeOverride)
583587
})
584588

@@ -592,7 +596,7 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseHeaders_Streaming(t
592596
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
593597
require.NoError(t, err)
594598
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
595-
require.Equal(t, mt.resHeaderMutation, commonRes.HeaderMutation)
599+
require.Empty(t, commonRes.HeaderMutation.SetHeaders)
596600
require.Nil(t, res.ModeOverride)
597601
})
598602
}
@@ -799,7 +803,7 @@ func Test_completionsProcessorUpstreamFilter_CELCostEvaluation(t *testing.T) {
799803
require.NoError(t, err)
800804
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
801805
require.Equal(t, string(expBody), string(commonRes.BodyMutation.GetBody()))
802-
require.Len(t, commonRes.HeaderMutation, 1)
806+
require.Len(t, commonRes.HeaderMutation.SetHeaders, 1)
803807
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
804808
require.Equal(t, "bar", string(commonRes.HeaderMutation.SetHeaders[0].Header.RawValue))
805809
mm.RequireRequestSuccess(t)

internal/extproc/embeddings_processor_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T)
119119
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
120120
require.NoError(t, err)
121121
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
122-
require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation)
122+
require.Empty(t, commonRes.HeaderMutation.SetHeaders)
123123
mm.RequireRequestNotCompleted(t)
124124
})
125125
}
@@ -329,7 +329,7 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T)
329329
t.Run("ok", func(t *testing.T) {
330330
someBody := embeddingBodyFromModel(t, "some-model")
331331
headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"}
332-
headerMut := []internalapi.Header{{"foo", "bar"}}
332+
headerMut := []internalapi.Header{{"a", "b"}}
333333
bodyMut := []byte("some body")
334334

335335
var expBody openai.EmbeddingRequest
@@ -351,9 +351,11 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T)
351351
require.Equal(t, mt, p.translator)
352352
require.NotNil(t, resp)
353353
commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response
354-
require.Len(t, commonRes.HeaderMutation.SetHeaders, 1)
355-
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
356-
require.Equal(t, []byte("bar"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue)
354+
require.Len(t, commonRes.HeaderMutation.SetHeaders, 2)
355+
require.Equal(t, "a", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
356+
require.Equal(t, []byte("b"), commonRes.HeaderMutation.SetHeaders[0].Header.RawValue)
357+
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[1].Header.Key)
358+
require.Equal(t, []byte("mock-auth-handler"), commonRes.HeaderMutation.SetHeaders[1].Header.RawValue)
357359
require.Equal(t, string(bodyMut), string(commonRes.BodyMutation.GetBody()))
358360

359361
mm.RequireRequestNotCompleted(t)

internal/extproc/imagegeneration_processor_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseHeaders(t *testi
216216
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
217217
require.NoError(t, err)
218218
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
219-
require.Equal(t, mt.retHeaderMutation, commonRes.HeaderMutation)
219+
require.Empty(t, commonRes.HeaderMutation.SetHeaders)
220220
mm.RequireRequestNotCompleted(t)
221221
})
222222
}
@@ -274,7 +274,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing.
274274
res, err := p.ProcessResponseBody(t.Context(), inBody)
275275
require.NoError(t, err)
276276
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
277-
require.Equal(t, nil, commonRes.BodyMutation)
277+
require.Nil(t, commonRes.BodyMutation)
278278
require.Equal(t, &extprocv3.HeaderMutation{}, commonRes.HeaderMutation)
279279
mm.RequireRequestSuccess(t)
280280
require.Equal(t, 124, mm.tokenUsageCount) // 1 input + 123 output
@@ -311,7 +311,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing.
311311
res, err := p.ProcessResponseBody(t.Context(), inBody)
312312
require.NoError(t, err)
313313
commonRes := res.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
314-
require.Equal(t, expBody, commonRes.BodyMutation.GetBody())
314+
require.Equal(t, string(expBody), string(commonRes.BodyMutation.GetBody()))
315315
require.Len(t, commonRes.HeaderMutation.SetHeaders, 1)
316316
require.Equal(t, "foo", commonRes.HeaderMutation.SetHeaders[0].Header.Key)
317317
require.Equal(t, "bar", string(commonRes.HeaderMutation.SetHeaders[0].Header.RawValue))
@@ -352,7 +352,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing.
352352
require.Len(t, reqHM.SetHeaders, 1)
353353
require.Equal(t, "foo", reqHM.SetHeaders[0].Header.Key)
354354
require.Equal(t, "bar", string(reqHM.SetHeaders[0].Header.RawValue))
355-
require.Equal(t, []byte("changed"), commonRes.BodyMutation.GetBody())
355+
require.Equal(t, "changed", string(commonRes.BodyMutation.GetBody()))
356356
mm.RequireRequestSuccess(t)
357357
})
358358
}

internal/extproc/rerank_processor_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,12 @@ func Test_rerankProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) {
128128
require.NoError(t, err)
129129
req := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders)
130130
common := req.RequestHeaders.Response
131-
require.Equal(t, headerMut, common.HeaderMutation)
132-
require.Equal(t, bodyMut, common.BodyMutation)
131+
require.Len(t, common.HeaderMutation.SetHeaders, 2)
132+
require.Equal(t, "foo", common.HeaderMutation.SetHeaders[0].Header.Key)
133+
require.Equal(t, "bar", string(common.HeaderMutation.SetHeaders[0].Header.RawValue))
134+
require.Equal(t, "foo", common.HeaderMutation.SetHeaders[1].Header.Key)
135+
require.Equal(t, "mock-auth-handler", string(common.HeaderMutation.SetHeaders[1].Header.RawValue))
136+
require.Equal(t, "patched", string(common.BodyMutation.GetBody()))
133137
// Not completed yet
134138
mm.RequireRequestNotCompleted(t)
135139
require.Equal(t, "rerank-english-v3", mm.originalModel)
@@ -329,7 +333,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) {
329333
res, err := p.ProcessResponseHeaders(t.Context(), inHeaders)
330334
require.NoError(t, err)
331335
common := res.Response.(*extprocv3.ProcessingResponse_ResponseHeaders).ResponseHeaders.Response
332-
require.Equal(t, mt.retHeaderMutation, common.HeaderMutation)
336+
require.Empty(t, common.HeaderMutation.SetHeaders)
333337
mm.RequireRequestNotCompleted(t)
334338
}
335339

@@ -383,7 +387,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) {
383387
require.NoError(t, err)
384388
common := resp.Response.(*extprocv3.ProcessingResponse_ResponseBody).ResponseBody.Response
385389
require.NotNil(t, common.HeaderMutation)
386-
require.NotNil(t, common.BodyMutation)
390+
require.Nil(t, common.BodyMutation)
387391
mm.RequireTokenUsage(t, 10)
388392
mm.RequireRequestSuccess(t)
389393
// Response model chosen is retResponseModel

internal/extproc/translator/anthropic_gcpanthropic.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type anthropicToGCPAnthropicTranslator struct {
3333

3434
// RequestBody implements [AnthropicMessagesTranslator.RequestBody] for Anthropic to GCP Anthropic translation.
3535
// This handles the transformation from native Anthropic format to GCP Anthropic format.
36-
func (a *anthropicToGCPAnthropicTranslator) RequestBody(raw []byte, body *anthropicschema.MessagesRequest, _ bool) (
36+
func (a *anthropicToGCPAnthropicTranslator) RequestBody(_ []byte, body *anthropicschema.MessagesRequest, _ bool) (
3737
newHeaders []internalapi.Header, newBody []byte, err error,
3838
) {
3939
// Extract model name for GCP endpoint from the parsed request.

internal/extproc/translator/anthropic_gcpanthropic_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func TestAnthropicToGCPAnthropicTranslator_RequestBody_ModelNameOverride(t *test
7474
pathHeader := headerMutation[0]
7575
require.Equal(t, pathHeaderName, pathHeader.Key())
7676
expectedPath := "publishers/anthropic/models/" + tt.expectedInPath + ":rawPredict"
77-
assert.Equal(t, expectedPath, string(pathHeader.Value()))
77+
assert.Equal(t, expectedPath, pathHeader.Value())
7878

7979
// Check that model field is removed from body (since it's in the path).
8080
var modifiedReq map[string]any

internal/extproc/translator/cohere_rerank_v2_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func TestCohereToCohereTranslatorV2Rerank_RequestBody_InvalidJSONCreatesBodyWith
9999
// Verify content-length header is set alongside :path
100100
require.GreaterOrEqual(t, len(headerMutation), 2)
101101
require.Equal(t, pathHeaderName, headerMutation[0].Key())
102-
require.Equal(t, "/v2/rerank", string(headerMutation[0].Value()))
102+
require.Equal(t, "/v2/rerank", headerMutation[0].Value())
103103
require.Equal(t, contentLengthHeaderName, headerMutation[1].Key())
104104
}
105105

internal/extproc/translator/imagegeneration_openai_openai_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestOpenAIToOpenAIImageTranslator_RequestBody_ModelOverrideAndPath(t *testi
2727
require.NotNil(t, hm)
2828
require.Len(t, hm, 2) // path and content-length headers
2929
require.Equal(t, pathHeaderName, hm[0].Key())
30-
require.Equal(t, "/v1/images/generations", string(hm[0].Value()))
30+
require.Equal(t, "/v1/images/generations", hm[0].Value())
3131
require.Equal(t, contentLengthHeaderName, hm[1].Key())
3232

3333
require.NotNil(t, bm)

internal/extproc/translator/openai_awsbedrock_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
12571257
require.NotNil(t, hm)
12581258
require.Len(t, hm, 2)
12591259
require.Equal(t, pathHeaderName, hm[0].Key())
1260-
require.Equal(t, "/model/"+url.PathEscape(modelNameOverride)+"/converse", string(hm[0].Value()))
1260+
require.Equal(t, "/model/"+url.PathEscape(modelNameOverride)+"/converse", hm[0].Value())
12611261
})
12621262
}
12631263

@@ -1408,7 +1408,7 @@ func TestOpenAIToAWSBedrockTranslator_ResponseError(t *testing.T) {
14081408
require.NotNil(t, hm)
14091409
require.Len(t, hm, 2)
14101410
require.Equal(t, contentTypeHeaderName, hm[0].Key())
1411-
require.Equal(t, jsonContentType, hm[0].Value())
1411+
require.JSONEq(t, jsonContentType, hm[0].Value())
14121412
require.Equal(t, contentLengthHeaderName, hm[1].Key())
14131413
require.Equal(t, strconv.Itoa(len(newBody)), hm[1].Value())
14141414

@@ -1669,7 +1669,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
16691669
require.NotNil(t, hm)
16701670
require.Len(t, hm, 1)
16711671
require.Equal(t, contentLengthHeaderName, hm[0].Key())
1672-
require.Equal(t, strconv.Itoa(len(newBody)), string(hm[0].Value()))
1672+
require.Equal(t, strconv.Itoa(len(newBody)), hm[0].Value())
16731673

16741674
expectedBody, err := json.Marshal(tt.output)
16751675
require.NoError(t, err)
@@ -1727,7 +1727,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBodyURLEncoding(t *
17271727
require.NotNil(t, hm)
17281728
require.Len(t, hm, 2)
17291729
require.Equal(t, pathHeaderName, hm[0].Key())
1730-
require.Equal(t, tt.expectedPath, string(hm[0].Value()))
1730+
require.Equal(t, tt.expectedPath, hm[0].Value())
17311731
})
17321732
}
17331733
}

0 commit comments

Comments
 (0)