fix: fix skylark format issue when the assistant role message is the first message (#155)

Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
Deng Junhai 2024-03-31 13:43:55 +08:00
parent 48a4b6d767
commit abd05d8334

View File

@ -5,6 +5,7 @@ import (
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"fmt" "fmt"
"github.com/volcengine/volc-sdk-golang/service/maas" "github.com/volcengine/volc-sdk-golang/service/maas"
"github.com/volcengine/volc-sdk-golang/service/maas/models/api" "github.com/volcengine/volc-sdk-golang/service/maas/models/api"
) )
@ -12,17 +13,41 @@ import (
const defaultMaxTokens int64 = 1500 const defaultMaxTokens int64 = 1500
func getMessages(messages []globals.Message) []*api.Message { func getMessages(messages []globals.Message) []*api.Message {
return utils.Each[globals.Message, *api.Message](messages, func(message globals.Message) *api.Message { result := make([]*api.Message, 0)
for _, message := range messages {
if message.Role == globals.Tool { if message.Role == globals.Tool {
message.Role = maas.ChatRoleOfFunction message.Role = maas.ChatRoleOfFunction
} }
return &api.Message{ msg := &api.Message{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content,
FunctionCall: getFunctionCall(message.ToolCalls), FunctionCall: getFunctionCall(message.ToolCalls),
} }
})
hasPrevious := len(result) > 0
// a message should not followed by the same role message, merge them
if hasPrevious && result[len(result)-1].Role == message.Role {
prev := result[len(result)-1]
prev.Content += msg.Content
if message.ToolCalls != nil {
prev.FunctionCall = msg.FunctionCall
}
continue
}
// `assistant` message should follow a user or function message, if not has previous message, change the role to `user`
if !hasPrevious && message.Role == maas.ChatRoleOfAssistant {
msg.Role = maas.ChatRoleOfUser
}
result = append(result, msg)
}
return result
} }
func (c *ChatInstance) GetMaxTokens(token *int) int64 { func (c *ChatInstance) GetMaxTokens(token *int) int64 {