From 78ff6a19a4c828ccdf24f1bd9c890e1b90432c68 Mon Sep 17 00:00:00 2001 From: Deng Junhai Date: Sat, 22 Jun 2024 02:36:00 +0800 Subject: [PATCH] fix: fix tool calls `required` omitempty field Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com> --- adapter/skylark/formatter.go | 11 +++++++---- globals/tools.go | 5 +++-- utils/tokenizer.go | 4 ++++ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/adapter/skylark/formatter.go b/adapter/skylark/formatter.go index 6e3cab3..38db695 100644 --- a/adapter/skylark/formatter.go +++ b/adapter/skylark/formatter.go @@ -3,6 +3,7 @@ package skylark import ( "chat/globals" "chat/utils" + structpb "github.com/golang/protobuf/ptypes/struct" "github.com/volcengine/volc-sdk-golang/service/maas/models/api" ) @@ -20,19 +21,21 @@ func getFunctionCall(calls *globals.ToolCalls) *api.FunctionCall { } func getType(p globals.ToolProperty) string { - if p.Type == nil { + t, ok := p["type"] + if !ok { return "string" } - return *p.Type + return t.(string) } func getDescription(p globals.ToolProperty) string { - if p.Description == nil { + desc, ok := p["description"] + if !ok { return "" } - return *p.Description + return desc.(string) } func getValue(p globals.ToolProperty) *structpb.Value { diff --git a/globals/tools.go b/globals/tools.go index 364964f..368361e 100644 --- a/globals/tools.go +++ b/globals/tools.go @@ -16,7 +16,7 @@ type ToolFunction struct { type ToolParameters struct { Type string `json:"type"` Properties ToolProperties `json:"properties"` - Required []string `json:"required"` + Required *[]string `json:"required,omitempty"` } type ToolProperties map[string]ToolProperty @@ -25,7 +25,8 @@ type ToolProperties map[string]ToolProperty type JsonSchemaType any type JSONSchemaDefinition any -type ToolProperty struct { +type ToolProperty map[string]interface{} +type DetailToolProperty struct { Type *string `json:"type,omitempty"` Enum *[]JsonSchemaType `json:"enum,omitempty"` Const *JsonSchemaType `json:"const,omitempty"` diff --git a/utils/tokenizer.go b/utils/tokenizer.go index 7784f7d..5001ae9 100644 --- a/utils/tokenizer.go +++ b/utils/tokenizer.go @@ -83,6 +83,10 @@ func NumTokensFromMessages(messages []globals.Message, model string, responseTyp } func NumTokensFromResponse(response string, model string) int { + if len(response) == 0 { + return 0 + } + return NumTokensFromMessages([]globals.Message{{Content: response}}, model, true) }