diff --git a/internal/extproc/translator/openai_gcpanthropic.go b/internal/extproc/translator/openai_gcpanthropic.go index 79c9ea562..911de7260 100644 --- a/internal/extproc/translator/openai_gcpanthropic.go +++ b/internal/extproc/translator/openai_gcpanthropic.go @@ -34,6 +34,14 @@ const ( tempNotSupportedError = "temperature %.2f is not supported by Anthropic (must be between 0.0 and 1.0)" ) +// anthropicInputSchemaKeysToSkip defines the keys from an OpenAI function parameter map +// that are handled explicitly and should not go into the ExtraFields map. +var anthropicInputSchemaKeysToSkip = map[string]struct{}{ + "required": {}, + "type": {}, + "properties": {}, +} + // NewChatCompletionOpenAIToGCPAnthropicTranslator implements [Factory] for OpenAI to GCP Anthropic translation. // This translator converts OpenAI ChatCompletion API requests to GCP Anthropic API format. func NewChatCompletionOpenAIToGCPAnthropicTranslator(apiVersion string, modelNameOverride internalapi.ModelNameOverride) OpenAIChatCompletionTranslator { @@ -163,16 +171,6 @@ func translateOpenAItoAnthropicTools(openAITools []openai.Tool, openAIToolChoice inputSchema := anthropic.ToolInputSchemaParam{} - // Dereference json schema - // If the paramsMap contains $refs we need to dereference them - var dereferencedParamsMap any - if dereferencedParamsMap, err = jsonSchemaDereference(paramsMap); err != nil { - return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("failed to dereference tool parameters: %w", err) - } - if paramsMap, ok = dereferencedParamsMap.(map[string]any); !ok { - return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("failed to cast dereferenced tool parameters to map[string]interface{}") - } - var typeVal string if typeVal, ok = paramsMap["type"].(string); ok { inputSchema.Type = constant.Object(typeVal) @@ -194,6 +192,21 @@ func translateOpenAItoAnthropicTools(openAITools []openai.Tool, openAIToolChoice inputSchema.Required = requiredSlice } + // ExtraFieldsMap to construct + ExtraFieldsMap := make(map[string]any) + + // Iterate over the original map from openai + for key, value := range paramsMap { + // Check if the current key should be skipped + if _, found := anthropicInputSchemaKeysToSkip[key]; found { + continue + } + + // If not skipped, add the key-value pair to extra field map + ExtraFieldsMap[key] = value + } + inputSchema.ExtraFields = ExtraFieldsMap + toolParam.InputSchema = inputSchema } diff --git a/internal/extproc/translator/openai_gcpanthropic_test.go b/internal/extproc/translator/openai_gcpanthropic_test.go index faf49ed92..f410f6cd1 100644 --- a/internal/extproc/translator/openai_gcpanthropic_test.go +++ b/internal/extproc/translator/openai_gcpanthropic_test.go @@ -1248,6 +1248,77 @@ func TestTranslateOpenAItoAnthropicTools(t *testing.T) { }, expectErr: true, }, + { + name: "nested schema in tool's defintions", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get the weather without type", + Parameters: map[string]any{ + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + "required": []any{"location"}, + "$defs": map[string]any{ + "ReferencePassage": map[string]any{ + "properties": map[string]any{ + "url": map[string]any{ + "title": "Url", + "type": "string", + }, + "passage_id": map[string]any{ + "title": "Passage Id", + "type": "string", + }, + }, + "required": []string{"url", "passage_id"}, + "title": "ReferencePassage", + "type": "object", + }, + }, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.String("Get the weather without type"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "", + Properties: map[string]any{ + "location": map[string]any{"type": "string"}, + }, + Required: []string{"location"}, + ExtraFields: map[string]any{ + "$defs": map[string]any{ + "ReferencePassage": map[string]any{ + "properties": map[string]any{ + "url": map[string]any{ + "title": "Url", + "type": "string", + }, + "passage_id": map[string]any{ + "title": "Passage Id", + "type": "string", + }, + }, + "required": []string{"url", "passage_id"}, + "title": "ReferencePassage", + "type": "object", + }, + }, + }, + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -1318,273 +1389,6 @@ func TestFinishReasonTranslation(t *testing.T) { } } -// TestToolParameterDereferencing tests the JSON schema dereferencing functionality -// for tool parameters when translating from OpenAI to GCP Anthropic. -func TestToolParameterDereferencing(t *testing.T) { - tests := []struct { - name string - openAIReq *openai.ChatCompletionRequest - expectedTools []anthropic.ToolUnionParam - expectedToolChoice anthropic.ToolChoiceUnionParam - expectErr bool - expectedErrMsg string - }{ - { - name: "tool with complex nested $ref - successful dereferencing", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "complex_tool", - Description: "Tool with complex nested references", - Parameters: map[string]any{ - "type": "object", - "$defs": map[string]any{ - "BaseType": map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - }, - "required": []any{"id"}, - }, - }, - "NestedType": map[string]any{ - "allOf": []any{ - map[string]any{"$ref": "#/$defs/BaseType"}, - map[string]any{ - "properties": map[string]any{ - "name": map[string]any{ - "type": "string", - }, - }, - }, - }, - }, - }, - "properties": map[string]any{ - "nested": map[string]any{ - "$ref": "#/$defs/NestedType", - }, - }, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "complex_tool", - Description: anthropic.String("Tool with complex nested references"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "nested": map[string]any{ - "allOf": []any{ - map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - }, - "required": []any{"id"}, - }, - }, - map[string]any{ - "properties": map[string]any{ - "name": map[string]any{ - "type": "string", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - { - name: "tool with invalid $ref - dereferencing error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "invalid_ref_tool", - Description: "Tool with invalid reference", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{ - "$ref": "#/$defs/NonExistent", - }, - }, - }, - }, - }, - }, - }, - expectErr: true, - expectedErrMsg: "failed to dereference tool parameters", - }, - { - name: "tool with circular $ref - dereferencing error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "circular_ref_tool", - Description: "Tool with circular reference", - Parameters: map[string]any{ - "type": "object", - "$defs": map[string]any{ - "A": map[string]any{ - "type": "object", - "properties": map[string]any{ - "b": map[string]any{ - "$ref": "#/$defs/B", - }, - }, - }, - "B": map[string]any{ - "type": "object", - "properties": map[string]any{ - "a": map[string]any{ - "$ref": "#/$defs/A", - }, - }, - }, - }, - "properties": map[string]any{ - "circular": map[string]any{ - "$ref": "#/$defs/A", - }, - }, - }, - }, - }, - }, - }, - expectErr: true, - expectedErrMsg: "failed to dereference tool parameters", - }, - { - name: "tool without $ref - no dereferencing needed", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "simple_tool", - Description: "Simple tool without references", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{ - "type": "string", - }, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "simple_tool", - Description: anthropic.String("Simple tool without references"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{ - "type": "string", - }, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool parameter dereferencing returns non-map type - casting error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "problematic_tool", - Description: "Tool with parameters that can't be properly dereferenced to map", - // This creates a scenario where jsonSchemaDereference might return a non-map type - // though this is a contrived example since normally the function should return map[string]any - Parameters: map[string]any{ - "$ref": "#/$defs/StringType", // This would resolve to a string, not a map - "$defs": map[string]any{ - "StringType": "not-a-map", // This would cause the casting to fail - }, - }, - }, - }, - }, - }, - expectErr: true, - expectedErrMsg: "failed to cast dereferenced tool parameters", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) - - if tt.expectErr { - require.Error(t, err) - if tt.expectedErrMsg != "" { - require.Contains(t, err.Error(), tt.expectedErrMsg) - } - return - } - - require.NoError(t, err) - - if tt.openAIReq.Tools != nil { - require.NotNil(t, tools) - require.Len(t, tools, len(tt.expectedTools)) - - for i, expectedTool := range tt.expectedTools { - actualTool := tools[i] - require.Equal(t, expectedTool.GetName(), actualTool.GetName()) - require.Equal(t, expectedTool.GetType(), actualTool.GetType()) - require.Equal(t, expectedTool.GetDescription(), actualTool.GetDescription()) - - expectedSchema := expectedTool.GetInputSchema() - actualSchema := actualTool.GetInputSchema() - - require.Equal(t, expectedSchema.Type, actualSchema.Type) - require.Equal(t, expectedSchema.Required, actualSchema.Required) - - // For properties, we'll do a deep comparison to verify dereferencing worked - if expectedSchema.Properties != nil { - require.NotNil(t, actualSchema.Properties) - require.Equal(t, expectedSchema.Properties, actualSchema.Properties) - } - } - } - - if tt.openAIReq.ToolChoice != nil { - require.NotNil(t, toolChoice) - require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) - } - }) - } -} - // TestContentTranslationCoverage adds specific coverage for the openAIToAnthropicContent helper. func TestContentTranslationCoverage(t *testing.T) { tests := []struct {