mirror of
https://github.com/coaidev/coai.git
synced 2025-05-28 17:30:15 +09:00
add gpt-4-vision-preview support
This commit is contained in:
parent
bb783f04b2
commit
4ee2051ca9
@ -24,16 +24,59 @@ func processFormat(data string) string {
|
||||
return item
|
||||
}
|
||||
|
||||
func formatMessages(props *ChatProps) []globals.Message {
|
||||
func formatMessages(props *ChatProps) interface{} {
|
||||
if props.Model == globals.GPT4Vision {
|
||||
base := props.Message[len(props.Message)-1].Content
|
||||
urls := utils.ExtractUrls(base)
|
||||
urls := utils.ExtractImageUrls(base)
|
||||
|
||||
if len(urls) > 0 {
|
||||
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
|
||||
}
|
||||
props.Message[len(props.Message)-1].Content = base
|
||||
return props.Message
|
||||
} else if props.Model == globals.GPT41106VisionPreview {
|
||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||
if message.Role == globals.User {
|
||||
urls := utils.ExtractImageUrls(message.Content)
|
||||
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
||||
obj, err := utils.NewImage(url)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props.Buffer.AddImage(obj)
|
||||
|
||||
return &MessageContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &ImageUrl{
|
||||
Url: url,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
return Message{
|
||||
Role: message.Role,
|
||||
Content: utils.Prepend(images, MessageContent{
|
||||
Type: "text",
|
||||
Text: &message.Content,
|
||||
}),
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
}
|
||||
}
|
||||
|
||||
return Message{
|
||||
Role: message.Role,
|
||||
Content: MessageContents{
|
||||
MessageContent{
|
||||
Type: "text",
|
||||
Text: &message.Content,
|
||||
},
|
||||
},
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return props.Message
|
||||
|
@ -2,10 +2,30 @@ package chatgpt
|
||||
|
||||
import "chat/globals"
|
||||
|
||||
// ImageUrl is the JSON payload of an image part inside a chat message.
// Url is always serialized; Detail is optional and left out of the JSON when nil.
type ImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail *string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
// MessageContent is a single part of a multi-part message body.
// Type names the kind of part; Text and ImageUrl are pointers so that the
// unused one is omitted from the JSON output (omitempty).
type MessageContent struct {
|
||||
Type string `json:"type"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
ImageUrl *ImageUrl `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// MessageContents is the ordered list of parts that make up one message body.
type MessageContents []MessageContent
|
||||
|
||||
// Message is a chat message whose content is a list of MessageContent parts
// rather than a single string. The tool-call fields are optional and are
// omitted from the JSON when nil.
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content MessageContents `json:"content"`
|
||||
ToolCallId *string `json:"tool_call_id,omitempty"` // only `tool` role
|
||||
ToolCalls *globals.ToolCalls `json:"tool_calls,omitempty"` // only `assistant` role
|
||||
}
|
||||
|
||||
// ChatRequest is the request body for chatgpt
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []globals.Message `json:"messages"`
|
||||
Messages interface{} `json:"messages"`
|
||||
MaxToken *int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
|
||||
@ -16,7 +36,7 @@ type ChatRequest struct {
|
||||
ToolChoice *interface{} `json:"tool_choice,omitempty"` // string or object
|
||||
}
|
||||
|
||||
// CompletionRequest ChatRequest is the request body for chatgpt completion
|
||||
// CompletionRequest is the request body for chatgpt completion
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
|
@ -3,6 +3,7 @@ package skylark
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/volcengine/volc-sdk-golang/service/maas"
|
||||
"github.com/volcengine/volc-sdk-golang/service/maas/models/api"
|
||||
)
|
||||
@ -19,6 +20,7 @@ type ChatProps struct {
|
||||
TopP *float32
|
||||
TopK *int
|
||||
Tools *globals.FunctionTools
|
||||
Buffer utils.Buffer
|
||||
}
|
||||
|
||||
func getMessages(messages []globals.Message) []*api.Message {
|
||||
@ -54,6 +56,27 @@ func (c *ChatInstance) CreateRequest(props *ChatProps) *api.ChatReq {
|
||||
}
|
||||
}
|
||||
|
||||
func getChoice(choice *api.ChatResp, buffer utils.Buffer) string {
|
||||
if choice == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
calls := choice.Choice.Message.FunctionCall
|
||||
if calls != nil {
|
||||
buffer.SetToolCalls(&globals.ToolCalls{
|
||||
globals.ToolCall{
|
||||
Type: "function",
|
||||
Id: globals.ToolCallId(fmt.Sprintf("%s-%s", calls.Name, choice.ReqId)),
|
||||
Function: globals.ToolCallFunction{
|
||||
Name: calls.Name,
|
||||
Arguments: calls.Arguments,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return choice.Choice.Message.Content
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||
req := c.CreateRequest(props)
|
||||
channel, err := c.Instance.StreamChat(req)
|
||||
@ -66,7 +89,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
|
||||
return partial.Error
|
||||
}
|
||||
|
||||
if err := callback(partial.Choice.Message.Content); err != nil {
|
||||
if err := callback(getChoice(partial, props.Buffer)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -72,12 +72,19 @@ func getChoice(form *ChatResponse, buffer utils.Buffer) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
buffer.SetToolCalls(&globals.ToolCalls{
|
||||
globals.ToolCall{
|
||||
Type: "text",
|
||||
Id: globals.ToolCallId(form.Header.Sid),
|
||||
},
|
||||
})
|
||||
if resp[0].FunctionCall != nil {
|
||||
buffer.SetToolCalls(&globals.ToolCalls{
|
||||
globals.ToolCall{
|
||||
Type: "function",
|
||||
Id: globals.ToolCallId(fmt.Sprintf("%s-%s", resp[0].FunctionCall.Name, resp[0].FunctionCall.Arguments)),
|
||||
Function: globals.ToolCallFunction{
|
||||
Name: resp[0].FunctionCall.Name,
|
||||
Arguments: resp[0].FunctionCall.Arguments,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return resp[0].Content
|
||||
}
|
||||
|
||||
|
@ -53,6 +53,22 @@ func InsertSlice[T any](arr []T, index int, value []T) []T {
|
||||
return arr
|
||||
}
|
||||
|
||||
// Append returns arr with value added as its last element.
// It has the semantics of the built-in append: the result may share arr's
// backing array when arr has spare capacity.
func Append[T any](arr []T, value T) []T {
	arr = append(arr, value)
	return arr
}
|
||||
|
||||
// AppendSlice returns arr followed by every element of value, in order.
// It has the semantics of the built-in append: the result may share arr's
// backing array when arr has spare capacity.
func AppendSlice[T any](arr []T, value []T) []T {
	merged := append(arr, value...)
	return merged
}
|
||||
|
||||
// Prepend returns a newly allocated slice holding value followed by every
// element of arr. The input slice arr is not modified.
func Prepend[T any](arr []T, value T) []T {
	result := make([]T, 0, len(arr)+1)
	result = append(result, value)
	result = append(result, arr...)
	return result
}
|
||||
|
||||
// PrependSlice returns the elements of value followed by the elements of arr.
//
// value is capped at its own length before appending, so a non-empty arr
// always forces a fresh allocation. Without the cap, append would write arr
// into value's spare capacity, and two results built from the same value
// prefix would silently overwrite each other.
//
// When arr is empty the result is value itself (no copy is made).
func PrependSlice[T any](arr []T, value []T) []T {
	return append(value[:len(value):len(value)], arr...)
}
|
||||
|
||||
// Remove returns a copy of arr without the element at index.
//
// An index outside [0, len(arr)) leaves nothing to remove, so arr is returned
// unchanged instead of panicking. The result is built in a new backing array;
// shifting elements in place would overwrite the tail of the caller's slice.
func Remove[T any](arr []T, index int) []T {
	if index < 0 || index >= len(arr) {
		return arr
	}

	result := make([]T, 0, len(arr)-1)
	result = append(result, arr[:index]...)
	return append(result, arr[index+1:]...)
}
|
||||
|
@ -45,12 +45,10 @@ func (b *Buffer) GetChunk() string {
|
||||
return b.Latest
|
||||
}
|
||||
|
||||
func (b *Buffer) SetImages(images Images) {
|
||||
b.Images = images
|
||||
func (b *Buffer) AddImage(image *Image) {
|
||||
b.Images = append(b.Images, *image)
|
||||
|
||||
b.Quota += Sum(Each(images, func(image Image) float32 {
|
||||
return float32(image.CountTokens(b.Model)) * 0.7
|
||||
}))
|
||||
b.Quota += float32(image.CountTokens(b.Model)) * 0.7
|
||||
}
|
||||
|
||||
func (b *Buffer) GetImages() Images {
|
||||
|
@ -137,3 +137,8 @@ func ExtractUrls(data string) []string {
|
||||
re := regexp.MustCompile(`(https?://\S+)`)
|
||||
return re.FindAllString(data, -1)
|
||||
}
|
||||
|
||||
// imageURLPattern matches an http(s) URL up to and including an image file
// extension. The extension group is case-insensitive so that links such as
// "photo.PNG" are recognized. It is compiled once at package scope rather
// than on every call.
var imageURLPattern = regexp.MustCompile(`(https?://\S+\.(?i:png|jpg|jpeg|gif|webp|svg|bmp))`)

// ExtractImageUrls returns every image URL found in data, in order of
// appearance. A match ends at the image extension, so anything following it
// in the same token (for example a query string) is not part of the result.
// It returns nil when data contains no image URL.
func ExtractImageUrls(data string) []string {
	return imageURLPattern.FindAllString(data, -1)
}
|
||||
|
Loading…
Reference in New Issue
Block a user