feat: add azure models and fix dashscope models

This commit is contained in:
Zhang Minghan 2023-12-29 10:31:13 +08:00
parent 61c917765b
commit 7ab1629955
8 changed files with 165 additions and 27 deletions

View File

@ -25,20 +25,31 @@ func (c *ChatInstance) GetHeader() map[string]string {
}
}
func (c *ChatInstance) FormatMessages(message []globals.Message) []Message {
var messages []Message
for _, v := range message {
if v.Role == globals.Tool {
continue
}
messages = append(messages, Message{
Role: v.Role,
Content: v.Content,
})
}
return messages
}
func (c *ChatInstance) GetChatBody(props *ChatProps) ChatRequest {
if props.Token <= 0 || props.Token > 1500 {
props.Token = 1500
}
return ChatRequest{
Model: strings.TrimSuffix(props.Model, "-net"),
Input: ChatInput{
Messages: utils.EachNotNil(props.Message, func(message globals.Message) *globals.Message {
if message.Role == globals.Tool {
return nil
}
return &message
}),
Messages: c.FormatMessages(props.Message),
},
Parameters: ChatParam{
MaxTokens: props.Token,

View File

@ -1,7 +1,5 @@
package dashscope
import "chat/globals"
// ChatRequest is the request body for dashscope
type ChatRequest struct {
Model string `json:"model"`
@ -9,9 +7,13 @@ type ChatRequest struct {
Parameters ChatParam `json:"parameters"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatInput struct {
Prompt string `json:"prompt"`
Messages []globals.Message `json:"messages"`
Messages []Message `json:"messages"`
}
type ChatParam struct {

View File

@ -91,6 +91,49 @@ export const supportModels: Model[] = [
tag: ["official", "unstable", "image-generation"],
},
{
id: "azure-gpt-3.5-turbo",
name: "Azure GPT-3.5",
free: false,
auth: true,
tag: ["official"],
},
{
id: "azure-gpt-3.5-turbo-16k",
name: "Azure GPT-3.5 16K",
free: false,
auth: true,
tag: ["official"],
},
{
id: "azure-gpt-4",
name: "Azure GPT-4",
free: false,
auth: true,
tag: ["official", "high-quality"],
},
{
id: "azure-gpt-4-1106-preview",
name: "Azure GPT-4 Turbo 128k",
free: false,
auth: true,
tag: ["official", "high-context", "unstable"],
},
{
id: "azure-gpt-4-vision-preview",
name: "Azure GPT-4 Vision 128k",
free: false,
auth: true,
tag: ["official", "high-context", "multi-modal"],
},
{
id: "azure-gpt-4-32k",
name: "Azure GPT-4 32k",
free: false,
auth: true,
tag: ["official", "multi-modal"],
},
// spark desk
{
id: "spark-desk-v3",
@ -352,6 +395,13 @@ export const defaultModels = [
"gpt-4-v",
"gpt-4-dalle",
"azure-gpt-3.5-turbo",
"azure-gpt-3.5-turbo-16k",
"azure-gpt-4",
"azure-gpt-4-1106-preview",
"azure-gpt-4-vision-preview",
"azure-gpt-4-32k",
"claude-1-100k",
"claude-2",
"claude-2.1",
@ -414,6 +464,12 @@ export const modelAvatars: Record<string, string> = {
"gpt-4-32k-0613": "gpt432k.webp",
"gpt-4-v": "gpt4v.png",
"gpt-4-dalle": "gpt4dalle.png",
"azure-gpt-3.5-turbo": "gpt35turbo.png",
"azure-gpt-3.5-turbo-16k": "gpt35turbo16k.webp",
"azure-gpt-4": "gpt4.png",
"azure-gpt-4-1106-preview": "gpt432k.webp",
"azure-gpt-4-vision-preview": "gpt4v.png",
"azure-gpt-4-32k": "gpt432k.webp",
"claude-1-100k": "claude.png",
"claude-2": "claude100k.png",
"claude-2.1": "claude100k.png",

View File

@ -16,7 +16,7 @@ import (
"time"
)
func TranshipmentAPI(c *gin.Context) {
func ChatRelayAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
@ -28,7 +28,7 @@ func TranshipmentAPI(c *gin.Context) {
return
}
var form TranshipmentForm
var form RelayForm
if err := c.ShouldBindJSON(&form); err != nil {
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return
@ -67,7 +67,7 @@ func TranshipmentAPI(c *gin.Context) {
}
}
func GetChatProps(form TranshipmentForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps {
func GetChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps {
return &adapter.ChatProps{
Model: form.Model,
Message: messages,
@ -85,7 +85,7 @@ func GetChatProps(form TranshipmentForm, messages []globals.Message, buffer *uti
}
}
func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
@ -105,7 +105,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []
}
CollectQuota(c, user, buffer, plan, err)
c.JSON(http.StatusOK, TranshipmentResponse{
c.JSON(http.StatusOK, RelayResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion",
Created: created,
@ -126,8 +126,8 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []
})
}
func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool, err error) TranshipmentStreamResponse {
return TranshipmentStreamResponse{
func getStreamTranshipmentForm(id string, created int64, form RelayForm, data string, buffer *utils.Buffer, end bool, err error) RelayStreamResponse {
return RelayStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion.chunk",
Created: created,
@ -152,8 +152,8 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
}
}
func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
partial := make(chan TranshipmentStreamResponse)
func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
partial := make(chan RelayStreamResponse)
db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)

View File

@ -1 +1,57 @@
package manager
import (
"chat/auth"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"strings"
"time"
)
func ImagesRelayAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
return
}
if utils.GetAgentFromContext(c) != "api" {
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid agent"), "authentication_error")
return
}
var form RelayImageForm
if err := c.ShouldBindJSON(&form); err != nil {
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return
}
prompt := strings.TrimSpace(form.Prompt)
if prompt == "" {
sendErrorResponse(c, fmt.Errorf("prompt is required"), "invalid_request_error")
}
db := utils.GetDBFromContext(c)
user := &auth.User{
Username: username,
}
created := time.Now().Unix()
if strings.HasSuffix(form.Model, "-official") {
form.Model = strings.TrimSuffix(form.Model, "-official")
}
check := auth.CanEnableModel(db, user, form.Model)
if !check {
sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error")
return
}
createRelayImageObject(c, form, prompt, created, user, false)
}
func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string, created int64, user *auth.User, plan bool) {
}

View File

@ -22,7 +22,7 @@ func sendErrorResponse(c *gin.Context, err error, types ...string) {
errType = "chatnio_api_error"
}
c.JSON(http.StatusServiceUnavailable, TranshipmentErrorResponse{
c.JSON(http.StatusServiceUnavailable, RelayErrorResponse{
Error: TranshipmentError{
Message: err.Error(),
Type: errType,

View File

@ -11,7 +11,7 @@ func Register(app *gin.RouterGroup) {
app.GET("/v1/charge", ChargeAPI)
app.GET("/dashboard/billing/usage", GetBillingUsage)
app.GET("/dashboard/billing/subscription", GetSubscription)
app.POST("/v1/chat/completions", TranshipmentAPI)
app.POST("/v1/chat/completions", ChatRelayAPI)
broadcast.Register(app)
}

View File

@ -26,7 +26,7 @@ type MessageContent struct {
type MessageContents []MessageContent
type TranshipmentForm struct {
type RelayForm struct {
Model string `json:"model" binding:"required"`
Messages []Message `json:"messages" binding:"required"`
Stream bool `json:"stream"`
@ -54,7 +54,7 @@ type Usage struct {
TotalTokens int `json:"total_tokens"`
}
type TranshipmentResponse struct {
type RelayResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
@ -70,7 +70,7 @@ type ChoiceDelta struct {
FinishReason interface{} `json:"finish_reason"`
}
type TranshipmentStreamResponse struct {
type RelayStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
@ -81,7 +81,7 @@ type TranshipmentStreamResponse struct {
Error error `json:"error,omitempty"`
}
type TranshipmentErrorResponse struct {
type RelayErrorResponse struct {
Error TranshipmentError `json:"error"`
}
@ -90,6 +90,19 @@ type TranshipmentError struct {
Type string `json:"type"`
}
type RelayImageForm struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N *int `json:"n,omitempty"`
}
type RelayImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
} `json:"data"`
}
func transformContent(content interface{}) string {
switch v := content.(type) {
case string: