add gpt-4-vision-preview support

This commit is contained in:
Zhang Minghan 2023-11-21 23:03:13 +08:00
parent bb783f04b2
commit 4ee2051ca9
7 changed files with 128 additions and 16 deletions

View File

@ -24,16 +24,59 @@ func processFormat(data string) string {
return item
}
func formatMessages(props *ChatProps) []globals.Message {
func formatMessages(props *ChatProps) interface{} {
if props.Model == globals.GPT4Vision {
base := props.Message[len(props.Message)-1].Content
urls := utils.ExtractUrls(base)
urls := utils.ExtractImageUrls(base)
if len(urls) > 0 {
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
}
props.Message[len(props.Message)-1].Content = base
return props.Message
} else if props.Model == globals.GPT41106VisionPreview {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
urls := utils.ExtractImageUrls(message.Content)
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
if err != nil {
return nil
}
props.Buffer.AddImage(obj)
return &MessageContent{
Type: "image_url",
ImageUrl: &ImageUrl{
Url: url,
},
}
})
return Message{
Role: message.Role,
Content: utils.Prepend(images, MessageContent{
Type: "text",
Text: &message.Content,
}),
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
}
}
return Message{
Role: message.Role,
Content: MessageContents{
MessageContent{
Type: "text",
Text: &message.Content,
},
},
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
}
})
}
return props.Message

View File

@ -2,10 +2,30 @@ package chatgpt
import "chat/globals"
type ImageUrl struct {
Url string `json:"url"`
Detail *string `json:"detail,omitempty"`
}
type MessageContent struct {
Type string `json:"type"`
Text *string `json:"text,omitempty"`
ImageUrl *ImageUrl `json:"image_url,omitempty"`
}
type MessageContents []MessageContent
type Message struct {
Role string `json:"role"`
Content MessageContents `json:"content"`
ToolCallId *string `json:"tool_call_id,omitempty"` // only `tool` role
ToolCalls *globals.ToolCalls `json:"tool_calls,omitempty"` // only `assistant` role
}
// ChatRequest is the request body for chatgpt
type ChatRequest struct {
Model string `json:"model"`
Messages []globals.Message `json:"messages"`
Messages interface{} `json:"messages"`
MaxToken *int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
@ -16,7 +36,7 @@ type ChatRequest struct {
ToolChoice *interface{} `json:"tool_choice,omitempty"` // string or object
}
// CompletionRequest ChatRequest is the request body for chatgpt completion
// CompletionRequest is the request body for chatgpt completion
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`

View File

@ -3,6 +3,7 @@ package skylark
import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/volcengine/volc-sdk-golang/service/maas"
"github.com/volcengine/volc-sdk-golang/service/maas/models/api"
)
@ -19,6 +20,7 @@ type ChatProps struct {
TopP *float32
TopK *int
Tools *globals.FunctionTools
Buffer utils.Buffer
}
func getMessages(messages []globals.Message) []*api.Message {
@ -54,6 +56,27 @@ func (c *ChatInstance) CreateRequest(props *ChatProps) *api.ChatReq {
}
}
func getChoice(choice *api.ChatResp, buffer utils.Buffer) string {
if choice == nil {
return ""
}
calls := choice.Choice.Message.FunctionCall
if calls != nil {
buffer.SetToolCalls(&globals.ToolCalls{
globals.ToolCall{
Type: "function",
Id: globals.ToolCallId(fmt.Sprintf("%s-%s", calls.Name, choice.ReqId)),
Function: globals.ToolCallFunction{
Name: calls.Name,
Arguments: calls.Arguments,
},
},
})
}
return choice.Choice.Message.Content
}
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
req := c.CreateRequest(props)
channel, err := c.Instance.StreamChat(req)
@ -66,7 +89,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
return partial.Error
}
if err := callback(partial.Choice.Message.Content); err != nil {
if err := callback(getChoice(partial, props.Buffer)); err != nil {
return err
}
}

View File

@ -72,12 +72,19 @@ func getChoice(form *ChatResponse, buffer utils.Buffer) string {
return ""
}
buffer.SetToolCalls(&globals.ToolCalls{
globals.ToolCall{
Type: "text",
Id: globals.ToolCallId(form.Header.Sid),
},
})
if resp[0].FunctionCall != nil {
buffer.SetToolCalls(&globals.ToolCalls{
globals.ToolCall{
Type: "function",
Id: globals.ToolCallId(fmt.Sprintf("%s-%s", resp[0].FunctionCall.Name, resp[0].FunctionCall.Arguments)),
Function: globals.ToolCallFunction{
Name: resp[0].FunctionCall.Name,
Arguments: resp[0].FunctionCall.Arguments,
},
},
})
}
return resp[0].Content
}

View File

@ -53,6 +53,22 @@ func InsertSlice[T any](arr []T, index int, value []T) []T {
return arr
}
func Append[T any](arr []T, value T) []T {
return append(arr, value)
}
func AppendSlice[T any](arr []T, value []T) []T {
return append(arr, value...)
}
func Prepend[T any](arr []T, value T) []T {
return append([]T{value}, arr...)
}
func PrependSlice[T any](arr []T, value []T) []T {
return append(value, arr...)
}
func Remove[T any](arr []T, index int) []T {
return append(arr[:index], arr[index+1:]...)
}

View File

@ -45,12 +45,10 @@ func (b *Buffer) GetChunk() string {
return b.Latest
}
func (b *Buffer) SetImages(images Images) {
b.Images = images
func (b *Buffer) AddImage(image *Image) {
b.Images = append(b.Images, *image)
b.Quota += Sum(Each(images, func(image Image) float32 {
return float32(image.CountTokens(b.Model)) * 0.7
}))
b.Quota += float32(image.CountTokens(b.Model)) * 0.7
}
func (b *Buffer) GetImages() Images {

View File

@ -137,3 +137,8 @@ func ExtractUrls(data string) []string {
re := regexp.MustCompile(`(https?://\S+)`)
return re.FindAllString(data, -1)
}
func ExtractImageUrls(data string) []string {
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|svg|bmp))`)
return re.FindAllString(data, -1)
}