feat: add azure models and fix dashscope models

This commit is contained in:
Zhang Minghan 2023-12-29 10:31:13 +08:00
parent 61c917765b
commit 7ab1629955
8 changed files with 165 additions and 27 deletions

View File

@ -25,20 +25,31 @@ func (c *ChatInstance) GetHeader() map[string]string {
} }
} }
func (c *ChatInstance) FormatMessages(message []globals.Message) []Message {
var messages []Message
for _, v := range message {
if v.Role == globals.Tool {
continue
}
messages = append(messages, Message{
Role: v.Role,
Content: v.Content,
})
}
return messages
}
func (c *ChatInstance) GetChatBody(props *ChatProps) ChatRequest { func (c *ChatInstance) GetChatBody(props *ChatProps) ChatRequest {
if props.Token <= 0 || props.Token > 1500 { if props.Token <= 0 || props.Token > 1500 {
props.Token = 1500 props.Token = 1500
} }
return ChatRequest{ return ChatRequest{
Model: strings.TrimSuffix(props.Model, "-net"), Model: strings.TrimSuffix(props.Model, "-net"),
Input: ChatInput{ Input: ChatInput{
Messages: utils.EachNotNil(props.Message, func(message globals.Message) *globals.Message { Messages: c.FormatMessages(props.Message),
if message.Role == globals.Tool {
return nil
}
return &message
}),
}, },
Parameters: ChatParam{ Parameters: ChatParam{
MaxTokens: props.Token, MaxTokens: props.Token,

View File

@ -1,7 +1,5 @@
package dashscope package dashscope
import "chat/globals"
// ChatRequest is the request body for dashscope // ChatRequest is the request body for dashscope
type ChatRequest struct { type ChatRequest struct {
Model string `json:"model"` Model string `json:"model"`
@ -9,9 +7,13 @@ type ChatRequest struct {
Parameters ChatParam `json:"parameters"` Parameters ChatParam `json:"parameters"`
} }
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatInput struct { type ChatInput struct {
Prompt string `json:"prompt"` Messages []Message `json:"messages"`
Messages []globals.Message `json:"messages"`
} }
type ChatParam struct { type ChatParam struct {

View File

@ -91,6 +91,49 @@ export const supportModels: Model[] = [
tag: ["official", "unstable", "image-generation"], tag: ["official", "unstable", "image-generation"],
}, },
{
id: "azure-gpt-3.5-turbo",
name: "Azure GPT-3.5",
free: false,
auth: true,
tag: ["official"],
},
{
id: "azure-gpt-3.5-turbo-16k",
name: "Azure GPT-3.5 16K",
free: false,
auth: true,
tag: ["official"],
},
{
id: "azure-gpt-4",
name: "Azure GPT-4",
free: false,
auth: true,
tag: ["official", "high-quality"],
},
{
id: "azure-gpt-4-1106-preview",
name: "Azure GPT-4 Turbo 128k",
free: false,
auth: true,
tag: ["official", "high-context", "unstable"],
},
{
id: "azure-gpt-4-vision-preview",
name: "Azure GPT-4 Vision 128k",
free: false,
auth: true,
tag: ["official", "high-context", "multi-modal"],
},
{
id: "azure-gpt-4-32k",
name: "Azure GPT-4 32k",
free: false,
auth: true,
tag: ["official", "multi-modal"],
},
// spark desk // spark desk
{ {
id: "spark-desk-v3", id: "spark-desk-v3",
@ -352,6 +395,13 @@ export const defaultModels = [
"gpt-4-v", "gpt-4-v",
"gpt-4-dalle", "gpt-4-dalle",
"azure-gpt-3.5-turbo",
"azure-gpt-3.5-turbo-16k",
"azure-gpt-4",
"azure-gpt-4-1106-preview",
"azure-gpt-4-vision-preview",
"azure-gpt-4-32k",
"claude-1-100k", "claude-1-100k",
"claude-2", "claude-2",
"claude-2.1", "claude-2.1",
@ -414,6 +464,12 @@ export const modelAvatars: Record<string, string> = {
"gpt-4-32k-0613": "gpt432k.webp", "gpt-4-32k-0613": "gpt432k.webp",
"gpt-4-v": "gpt4v.png", "gpt-4-v": "gpt4v.png",
"gpt-4-dalle": "gpt4dalle.png", "gpt-4-dalle": "gpt4dalle.png",
"azure-gpt-3.5-turbo": "gpt35turbo.png",
"azure-gpt-3.5-turbo-16k": "gpt35turbo16k.webp",
"azure-gpt-4": "gpt4.png",
"azure-gpt-4-1106-preview": "gpt432k.webp",
"azure-gpt-4-vision-preview": "gpt4v.png",
"azure-gpt-4-32k": "gpt432k.webp",
"claude-1-100k": "claude.png", "claude-1-100k": "claude.png",
"claude-2": "claude100k.png", "claude-2": "claude100k.png",
"claude-2.1": "claude100k.png", "claude-2.1": "claude100k.png",

View File

@ -16,7 +16,7 @@ import (
"time" "time"
) )
func TranshipmentAPI(c *gin.Context) { func ChatRelayAPI(c *gin.Context) {
username := utils.GetUserFromContext(c) username := utils.GetUserFromContext(c)
if username == "" { if username == "" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error") abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
@ -28,7 +28,7 @@ func TranshipmentAPI(c *gin.Context) {
return return
} }
var form TranshipmentForm var form RelayForm
if err := c.ShouldBindJSON(&form); err != nil { if err := c.ShouldBindJSON(&form); err != nil {
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error") abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return return
@ -67,7 +67,7 @@ func TranshipmentAPI(c *gin.Context) {
} }
} }
func GetChatProps(form TranshipmentForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps { func GetChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps {
return &adapter.ChatProps{ return &adapter.ChatProps{
Model: form.Model, Model: form.Model,
Message: messages, Message: messages,
@ -85,7 +85,7 @@ func GetChatProps(form TranshipmentForm, messages []globals.Message, buffer *uti
} }
} }
func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) { func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
db := utils.GetDBFromContext(c) db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c) cache := utils.GetCacheFromContext(c)
@ -105,7 +105,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []
} }
CollectQuota(c, user, buffer, plan, err) CollectQuota(c, user, buffer, plan, err)
c.JSON(http.StatusOK, TranshipmentResponse{ c.JSON(http.StatusOK, RelayResponse{
Id: fmt.Sprintf("chatcmpl-%s", id), Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion", Object: "chat.completion",
Created: created, Created: created,
@ -126,8 +126,8 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []
}) })
} }
func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool, err error) TranshipmentStreamResponse { func getStreamTranshipmentForm(id string, created int64, form RelayForm, data string, buffer *utils.Buffer, end bool, err error) RelayStreamResponse {
return TranshipmentStreamResponse{ return RelayStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", id), Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: created, Created: created,
@ -152,8 +152,8 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
} }
} }
func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) { func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
partial := make(chan TranshipmentStreamResponse) partial := make(chan RelayStreamResponse)
db := utils.GetDBFromContext(c) db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c) cache := utils.GetCacheFromContext(c)

View File

@ -1 +1,57 @@
package manager package manager
import (
"chat/auth"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"strings"
"time"
)
func ImagesRelayAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
return
}
if utils.GetAgentFromContext(c) != "api" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid agent"), "authentication_error")
return
}
var form RelayImageForm
if err := c.ShouldBindJSON(&form); err != nil {
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return
}
prompt := strings.TrimSpace(form.Prompt)
if prompt == "" {
sendErrorResponse(c, fmt.Errorf("prompt is required"), "invalid_request_error")
}
db := utils.GetDBFromContext(c)
user := &auth.User{
Username: username,
}
created := time.Now().Unix()
if strings.HasSuffix(form.Model, "-official") {
form.Model = strings.TrimSuffix(form.Model, "-official")
}
check := auth.CanEnableModel(db, user, form.Model)
if !check {
sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error")
return
}
createRelayImageObject(c, form, prompt, created, user, false)
}
func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string, created int64, user *auth.User, plan bool) {
}

View File

@ -22,7 +22,7 @@ func sendErrorResponse(c *gin.Context, err error, types ...string) {
errType = "chatnio_api_error" errType = "chatnio_api_error"
} }
c.JSON(http.StatusServiceUnavailable, TranshipmentErrorResponse{ c.JSON(http.StatusServiceUnavailable, RelayErrorResponse{
Error: TranshipmentError{ Error: TranshipmentError{
Message: err.Error(), Message: err.Error(),
Type: errType, Type: errType,

View File

@ -11,7 +11,7 @@ func Register(app *gin.RouterGroup) {
app.GET("/v1/charge", ChargeAPI) app.GET("/v1/charge", ChargeAPI)
app.GET("/dashboard/billing/usage", GetBillingUsage) app.GET("/dashboard/billing/usage", GetBillingUsage)
app.GET("/dashboard/billing/subscription", GetSubscription) app.GET("/dashboard/billing/subscription", GetSubscription)
app.POST("/v1/chat/completions", TranshipmentAPI) app.POST("/v1/chat/completions", ChatRelayAPI)
broadcast.Register(app) broadcast.Register(app)
} }

View File

@ -26,7 +26,7 @@ type MessageContent struct {
type MessageContents []MessageContent type MessageContents []MessageContent
type TranshipmentForm struct { type RelayForm struct {
Model string `json:"model" binding:"required"` Model string `json:"model" binding:"required"`
Messages []Message `json:"messages" binding:"required"` Messages []Message `json:"messages" binding:"required"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
@ -54,7 +54,7 @@ type Usage struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
} }
type TranshipmentResponse struct { type RelayResponse struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
@ -70,7 +70,7 @@ type ChoiceDelta struct {
FinishReason interface{} `json:"finish_reason"` FinishReason interface{} `json:"finish_reason"`
} }
type TranshipmentStreamResponse struct { type RelayStreamResponse struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
@ -81,7 +81,7 @@ type TranshipmentStreamResponse struct {
Error error `json:"error,omitempty"` Error error `json:"error,omitempty"`
} }
type TranshipmentErrorResponse struct { type RelayErrorResponse struct {
Error TranshipmentError `json:"error"` Error TranshipmentError `json:"error"`
} }
@ -90,6 +90,19 @@ type TranshipmentError struct {
Type string `json:"type"` Type string `json:"type"`
} }
type RelayImageForm struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N *int `json:"n,omitempty"`
}
type RelayImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
} `json:"data"`
}
func transformContent(content interface{}) string { func transformContent(content interface{}) string {
switch v := content.(type) { switch v := content.(type) {
case string: case string:
@ -100,7 +113,7 @@ func transformContent(content interface{}) string {
if data == nil || len(*data) == 0 { if data == nil || len(*data) == 0 {
return "" return ""
} }
for _, v := range *data { for _, v := range *data {
if v.Text != nil { if v.Text != nil {
result += *v.Text result += *v.Text