mirror of
https://github.com/coaidev/coai.git
synced 2025-05-21 22:10:12 +09:00
update gpt-4-32k adapter
This commit is contained in:
parent
5130d806a9
commit
653a90bd08
@ -3,23 +3,24 @@ package api
|
|||||||
import (
|
import (
|
||||||
"chat/auth"
|
"chat/auth"
|
||||||
"chat/types"
|
"chat/types"
|
||||||
"chat/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Buffer struct {
|
type Buffer struct {
|
||||||
Enable bool `json:"enable"`
|
Model string `json:"model"`
|
||||||
Quota float32 `json:"quota"`
|
Quota float32 `json:"quota"`
|
||||||
Data string `json:"data"`
|
Data string `json:"data"`
|
||||||
Cursor int `json:"cursor"`
|
Cursor int `json:"cursor"`
|
||||||
Times int `json:"times"`
|
Times int `json:"times"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer {
|
func NewBuffer(model string, history []types.ChatGPTMessage) *Buffer {
|
||||||
buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable}
|
return &Buffer{
|
||||||
if enable {
|
Data: "",
|
||||||
buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history))
|
Cursor: 0,
|
||||||
|
Times: 0,
|
||||||
|
Model: model,
|
||||||
|
Quota: auth.CountInputToken(model, history),
|
||||||
}
|
}
|
||||||
return buffer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) GetCursor() int {
|
func (b *Buffer) GetCursor() int {
|
||||||
@ -27,10 +28,7 @@ func (b *Buffer) GetCursor() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) GetQuota() float32 {
|
func (b *Buffer) GetQuota() float32 {
|
||||||
if !b.Enable {
|
return b.Quota + auth.CountOutputToken(b.Model, b.ReadTimes())
|
||||||
return 0.
|
|
||||||
}
|
|
||||||
return b.Quota + auth.CountOutputToken(b.ReadTimes())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) Write(data string) string {
|
func (b *Buffer) Write(data string) string {
|
||||||
|
27
api/chat.go
27
api/chat.go
@ -29,15 +29,15 @@ func SendSegmentMessage(conn *websocket.Conn, message interface{}) {
|
|||||||
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
|
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetErrorQuota(isGPT4 bool) float32 {
|
func GetErrorQuota(model string) float32 {
|
||||||
if isGPT4 {
|
if types.IsGPT4Model(model) {
|
||||||
return -0xe // special value for error
|
return -0xe // special value for error
|
||||||
} else {
|
} else {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
|
func GetTextSegment(instance *conversation.Conversation) (string, []types.ChatGPTMessage) {
|
||||||
var keyword string
|
var keyword string
|
||||||
var segment []types.ChatGPTMessage
|
var segment []types.ChatGPTMessage
|
||||||
|
|
||||||
@ -46,11 +46,16 @@ func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.
|
|||||||
} else {
|
} else {
|
||||||
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
|
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
|
||||||
}
|
}
|
||||||
|
return keyword, segment
|
||||||
|
}
|
||||||
|
|
||||||
|
func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
|
||||||
|
keyword, segment := GetTextSegment(instance)
|
||||||
SendSegmentMessage(conn, types.ChatSegmentResponse{Keyword: keyword, End: false})
|
SendSegmentMessage(conn, types.ChatSegmentResponse{Keyword: keyword, End: false})
|
||||||
|
|
||||||
isProPlan := auth.CanEnableSubscription(db, cache, user)
|
model := instance.GetModel()
|
||||||
if instance.IsEnableGPT4() && (!isProPlan) && (!auth.CanEnableGPT4(db, user)) {
|
useReverse := auth.CanEnableSubscription(db, cache, user)
|
||||||
|
if !auth.CanEnableModelWithSubscription(db, user, model, useReverse) {
|
||||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||||
Message: defaultQuotaMessage,
|
Message: defaultQuotaMessage,
|
||||||
Quota: 0,
|
Quota: 0,
|
||||||
@ -59,9 +64,9 @@ func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.
|
|||||||
return defaultQuotaMessage
|
return defaultQuotaMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer := NewBuffer(instance.IsEnableGPT4(), segment)
|
buffer := NewBuffer(model, segment)
|
||||||
StreamRequest(instance.IsEnableGPT4(), isProPlan, segment,
|
StreamRequest(model, useReverse, segment,
|
||||||
utils.Multi(instance.IsEnableGPT4() || isProPlan, -1, 2000),
|
utils.Multi(types.IsGPT4Model(model) || useReverse, -1, 2000),
|
||||||
func(resp string) {
|
func(resp string) {
|
||||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||||
Message: buffer.Write(resp),
|
Message: buffer.Write(resp),
|
||||||
@ -70,19 +75,19 @@ func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
if buffer.IsEmpty() {
|
if buffer.IsEmpty() {
|
||||||
if isProPlan {
|
if useReverse {
|
||||||
auth.DecreaseSubscriptionUsage(cache, user)
|
auth.DecreaseSubscriptionUsage(cache, user)
|
||||||
}
|
}
|
||||||
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||||
Message: defaultErrorMessage,
|
Message: defaultErrorMessage,
|
||||||
Quota: GetErrorQuota(instance.IsEnableGPT4()),
|
Quota: GetErrorQuota(model),
|
||||||
End: true,
|
End: true,
|
||||||
})
|
})
|
||||||
return defaultErrorMessage
|
return defaultErrorMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
// collect quota
|
// collect quota
|
||||||
if !isProPlan {
|
if !useReverse {
|
||||||
user.UseQuota(db, buffer.GetQuota())
|
user.UseQuota(db, buffer.GetQuota())
|
||||||
}
|
}
|
||||||
SendSegmentMessage(conn, types.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
|
SendSegmentMessage(conn, types.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
|
||||||
|
@ -105,14 +105,25 @@ func NativeStreamRequest(model string, endpoint string, apikeys string, messages
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StreamRequest(enableGPT4 bool, isProPlan bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
|
func StreamRequest(model string, enableReverse bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
|
||||||
if enableGPT4 {
|
switch model {
|
||||||
if isProPlan {
|
case types.GPT4,
|
||||||
|
types.GPT40314,
|
||||||
|
types.GPT40613:
|
||||||
|
if enableReverse {
|
||||||
NativeStreamRequest(viper.GetString("openai.reverse"), viper.GetString("openai.pro_endpoint"), viper.GetString("openai.pro"), messages, token, callback)
|
NativeStreamRequest(viper.GetString("openai.reverse"), viper.GetString("openai.pro_endpoint"), viper.GetString("openai.pro"), messages, token, callback)
|
||||||
} else {
|
} else {
|
||||||
NativeStreamRequest("gpt-4", viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
|
NativeStreamRequest(model, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
|
||||||
}
|
}
|
||||||
} else {
|
case types.GPT432k,
|
||||||
NativeStreamRequest("gpt-3.5-turbo-16k-0613", viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
|
types.GPT432k0613,
|
||||||
|
types.GPT432k0314:
|
||||||
|
NativeStreamRequest(model, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
|
||||||
|
case types.GPT3Turbo16k,
|
||||||
|
types.GPT3Turbo16k0301,
|
||||||
|
types.GPT3Turbo16k0613:
|
||||||
|
NativeStreamRequest(types.GPT3Turbo16k, viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
|
||||||
|
default:
|
||||||
|
NativeStreamRequest(types.GPT3Turbo, viper.GetString("openai.anonymous_endpoint"), viper.GetString("openai.anonymous"), messages, token, callback)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,6 @@ const resources = {
|
|||||||
buy: "Buy {{amount}} points",
|
buy: "Buy {{amount}} points",
|
||||||
dalle: "DALL·E Image Generator",
|
dalle: "DALL·E Image Generator",
|
||||||
"dalle-free": "5 free quotas per day",
|
"dalle-free": "5 free quotas per day",
|
||||||
gpt4: "GPT-4",
|
|
||||||
flex: "Flexible Billing",
|
flex: "Flexible Billing",
|
||||||
input: "Input",
|
input: "Input",
|
||||||
output: "Output",
|
output: "Output",
|
||||||
@ -106,7 +105,7 @@ const resources = {
|
|||||||
"free-web": "web searching feature",
|
"free-web": "web searching feature",
|
||||||
"free-conversation": "conversation storage",
|
"free-conversation": "conversation storage",
|
||||||
"free-api": "API calls",
|
"free-api": "API calls",
|
||||||
"pro-gpt4": "GPT-4 10 requests per day",
|
"pro-gpt4": "GPT-4 50 requests per day",
|
||||||
"pro-dalle": "50 quotas per day",
|
"pro-dalle": "50 quotas per day",
|
||||||
"pro-service": "Priority Service Support",
|
"pro-service": "Priority Service Support",
|
||||||
"pro-thread": "Concurrency Increase",
|
"pro-thread": "Concurrency Increase",
|
||||||
@ -207,7 +206,6 @@ const resources = {
|
|||||||
buy: "购买 {{amount}} 点数",
|
buy: "购买 {{amount}} 点数",
|
||||||
dalle: "DALL·E AI 绘图",
|
dalle: "DALL·E AI 绘图",
|
||||||
"dalle-free": "每天 5 次免费绘图配额",
|
"dalle-free": "每天 5 次免费绘图配额",
|
||||||
gpt4: "GPT-4",
|
|
||||||
flex: "灵活计费",
|
flex: "灵活计费",
|
||||||
input: "输入",
|
input: "输入",
|
||||||
output: "输出",
|
output: "输出",
|
||||||
@ -250,7 +248,7 @@ const resources = {
|
|||||||
"free-web": "联网搜索功能",
|
"free-web": "联网搜索功能",
|
||||||
"free-conversation": "对话存储记录",
|
"free-conversation": "对话存储记录",
|
||||||
"free-api": "API 调用",
|
"free-api": "API 调用",
|
||||||
"pro-gpt4": "GPT-4 每日请求 10 次",
|
"pro-gpt4": "GPT-4 每日请求 50 次",
|
||||||
"pro-dalle": "每日 50 次绘图",
|
"pro-dalle": "每日 50 次绘图",
|
||||||
"pro-service": "优先服务支持",
|
"pro-service": "优先服务支持",
|
||||||
"pro-thread": "并发数提升",
|
"pro-thread": "并发数提升",
|
||||||
@ -354,7 +352,6 @@ const resources = {
|
|||||||
buy: "Купить {{amount}} очков",
|
buy: "Купить {{amount}} очков",
|
||||||
dalle: "Генератор изображений DALL·E",
|
dalle: "Генератор изображений DALL·E",
|
||||||
"dalle-free": "5 бесплатных квот в день",
|
"dalle-free": "5 бесплатных квот в день",
|
||||||
gpt4: "GPT-4",
|
|
||||||
flex: "Гибкая тарификация",
|
flex: "Гибкая тарификация",
|
||||||
input: "Вход",
|
input: "Вход",
|
||||||
output: "Выход",
|
output: "Выход",
|
||||||
@ -399,7 +396,7 @@ const resources = {
|
|||||||
"free-web": "веб-поиск",
|
"free-web": "веб-поиск",
|
||||||
"free-conversation": "хранение разговоров",
|
"free-conversation": "хранение разговоров",
|
||||||
"free-api": "API вызовы",
|
"free-api": "API вызовы",
|
||||||
"pro-gpt4": "GPT-4 10 запросов в день",
|
"pro-gpt4": "GPT-4 50 запросов в день",
|
||||||
"pro-dalle": "50 квот в день",
|
"pro-dalle": "50 квот в день",
|
||||||
"pro-service": "Приоритетная служба поддержки",
|
"pro-service": "Приоритетная служба поддержки",
|
||||||
"pro-thread": "Увеличение параллелизма",
|
"pro-thread": "Увеличение параллелизма",
|
||||||
|
@ -248,7 +248,7 @@ function Quota() {
|
|||||||
<Separator orientation={`horizontal`} className={`my-2`} />
|
<Separator orientation={`horizontal`} className={`my-2`} />
|
||||||
<div className={`product-item`}>
|
<div className={`product-item`}>
|
||||||
<div className={`row title`}>
|
<div className={`row title`}>
|
||||||
<div>{t("buy.gpt4")}</div>
|
<div>GPT-4</div>
|
||||||
<div className={`grow`} />
|
<div className={`grow`} />
|
||||||
<div className={`column`}>
|
<div className={`column`}>
|
||||||
<Cloud className={`h-4 w-4`} /> {t("buy.flex")}
|
<Cloud className={`h-4 w-4`} /> {t("buy.flex")}
|
||||||
@ -276,6 +276,37 @@ function Quota() {
|
|||||||
4.3 / 1k token
|
4.3 / 1k token
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className={`product-item`}>
|
||||||
|
<div className={`row title`}>
|
||||||
|
<div>GPT-4-32K</div>
|
||||||
|
<div className={`grow`} />
|
||||||
|
<div className={`column`}>
|
||||||
|
<Cloud className={`h-4 w-4`} /> {t("buy.flex")}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className={`row desc`}>
|
||||||
|
<div className={`column`}>
|
||||||
|
<HardDriveUpload className={`h-4 w-4`} />
|
||||||
|
{t("buy.input")}
|
||||||
|
</div>
|
||||||
|
<div className={`grow`} />
|
||||||
|
<div className={`column`}>
|
||||||
|
<Cloud className={`h-4 w-4`} />
|
||||||
|
4.2 / 1k token
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className={`row desc`}>
|
||||||
|
<div className={`column`}>
|
||||||
|
<HardDriveDownload className={`h-4 w-4`} />
|
||||||
|
{t("buy.output")}
|
||||||
|
</div>
|
||||||
|
<div className={`grow`} />
|
||||||
|
<div className={`column`}>
|
||||||
|
<Cloud className={`h-4 w-4`} />
|
||||||
|
8.6 / 1k token
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<div className={`row desc`}>
|
<div className={`row desc`}>
|
||||||
<div className={`column info`}>
|
<div className={`column info`}>
|
||||||
<Info className={`h-4 w-4`} />
|
<Info className={`h-4 w-4`} />
|
||||||
|
@ -46,6 +46,8 @@ import { buySubscription } from "../conversation/addition.ts";
|
|||||||
|
|
||||||
function calc_prize(month: number): number {
|
function calc_prize(month: number): number {
|
||||||
if (month >= 12) {
|
if (month >= 12) {
|
||||||
|
return 8 * month * 0.8;
|
||||||
|
} else if (month >= 6) {
|
||||||
return 8 * month * 0.9;
|
return 8 * month * 0.9;
|
||||||
}
|
}
|
||||||
return 8 * month;
|
return 8 * month;
|
||||||
@ -101,17 +103,22 @@ function Upgrade({ children }: UpgradeProps) {
|
|||||||
<SelectContent>
|
<SelectContent>
|
||||||
<SelectItem value={"1"}>{t(`sub.time.1`)}</SelectItem>
|
<SelectItem value={"1"}>{t(`sub.time.1`)}</SelectItem>
|
||||||
<SelectItem value={"3"}>{t(`sub.time.3`)}</SelectItem>
|
<SelectItem value={"3"}>{t(`sub.time.3`)}</SelectItem>
|
||||||
<SelectItem value={"6"}>{t(`sub.time.6`)}</SelectItem>
|
<SelectItem value={"6"}>
|
||||||
|
{t(`sub.time.6`)}
|
||||||
|
<Badge className={`ml-2 cent`}>
|
||||||
|
{t(`percent`, { cent: 9 })}
|
||||||
|
</Badge>
|
||||||
|
</SelectItem>
|
||||||
<SelectItem value={"12"}>
|
<SelectItem value={"12"}>
|
||||||
{t(`sub.time.12`)}
|
{t(`sub.time.12`)}
|
||||||
<Badge className={`ml-2 cent`}>
|
<Badge className={`ml-2 cent`}>
|
||||||
{t(`percent`, { cent: 9 })}
|
{t(`percent`, { cent: 8 })}
|
||||||
</Badge>
|
</Badge>
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
<p className={`price`}>
|
<p className={`price`}>
|
||||||
{t("sub.price", { price: calc_prize(month) })}
|
{t("sub.price", { price: calc_prize(month).toFixed(2) })}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<DialogFooter>
|
<DialogFooter>
|
||||||
|
@ -29,7 +29,7 @@ func BuySubscription(db *sql.DB, user *User, month int) bool {
|
|||||||
|
|
||||||
func IncreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
|
func IncreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
|
||||||
today := time.Now().Format("2006-01-02")
|
today := time.Now().Format("2006-01-02")
|
||||||
return utils.IncrWithLimit(cache, fmt.Sprintf(":subscription-usage:%s:%d", today, user.ID), 1, 999, 60*60*24) // 1 day
|
return utils.IncrWithLimit(cache, fmt.Sprintf(":subscription-usage:%s:%d", today, user.ID), 1, 50, 60*60*24) // 1 day
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
|
func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/types"
|
||||||
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,18 +15,46 @@ import (
|
|||||||
// $0.03 / 1K tokens $0.06 / 1K tokens
|
// $0.03 / 1K tokens $0.06 / 1K tokens
|
||||||
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
|
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
|
||||||
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
|
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
|
||||||
|
//
|
||||||
|
// GPT-4 price (32k-context)
|
||||||
|
// Input Output
|
||||||
|
// $0.06 / 1K tokens $0.12 / 1K tokens
|
||||||
|
// ¥0.43 / 1K tokens ¥0.86 / 1K tokens
|
||||||
|
// 4.3 nio / 1K tokens 8.6 nio / 1K tokens
|
||||||
|
|
||||||
// Dalle price (512x512)
|
// Dalle price (512x512)
|
||||||
// $0.018 / per image
|
// $0.018 / per image
|
||||||
// ¥0.13 / per image
|
// ¥0.13 / per image
|
||||||
// 1 nio / per image
|
// 1 nio / per image
|
||||||
|
|
||||||
func CountInputToken(n int) float32 {
|
func CountInputToken(model string, v []types.ChatGPTMessage) float32 {
|
||||||
return float32(n) / 1000 * 2.1
|
switch model {
|
||||||
|
case types.GPT3Turbo:
|
||||||
|
return 0
|
||||||
|
case types.GPT3Turbo16k:
|
||||||
|
return 0
|
||||||
|
case types.GPT4:
|
||||||
|
return float32(utils.CountTokenPrice(v, model)) / 1000 * 2.1
|
||||||
|
case types.GPT432k:
|
||||||
|
return float32(utils.CountTokenPrice(v, model)) / 1000 * 2.1 * 2
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountOutputToken(n int) float32 {
|
func CountOutputToken(model string, t int) float32 {
|
||||||
return float32(n) / 1000 * 4.3
|
switch model {
|
||||||
|
case types.GPT3Turbo:
|
||||||
|
return 0
|
||||||
|
case types.GPT3Turbo16k:
|
||||||
|
return 0
|
||||||
|
case types.GPT4:
|
||||||
|
return float32(t*utils.GetWeightByModel(model)) / 1000 * 4.3
|
||||||
|
case types.GPT432k:
|
||||||
|
return float32(t*utils.GetWeightByModel(model)) / 1000 * 8.6
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReduceDalle(db *sql.DB, user *User) bool {
|
func ReduceDalle(db *sql.DB, user *User) bool {
|
||||||
@ -34,8 +64,24 @@ func ReduceDalle(db *sql.DB, user *User) bool {
|
|||||||
return user.UseQuota(db, 1)
|
return user.UseQuota(db, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CanEnableGPT4(db *sql.DB, user *User) bool {
|
func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
||||||
return user.GetQuota(db) >= 5
|
switch model {
|
||||||
|
case types.GPT4, types.GPT40613, types.GPT40314:
|
||||||
|
return user.GetQuota(db) >= 5
|
||||||
|
case types.GPT432k, types.GPT432k0613, types.GPT432k0314:
|
||||||
|
return user.GetQuota(db) >= 50
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CanEnableModelWithSubscription(db *sql.DB, user *User, model string, useReverse bool) bool {
|
||||||
|
if utils.Contains(model, types.GPT4Array) {
|
||||||
|
if useReverse {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return CanEnableModel(db, user, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuyQuota(db *sql.DB, user *User, quota int) bool {
|
func BuyQuota(db *sql.DB, user *User, quota int) bool {
|
||||||
|
@ -9,42 +9,42 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Id int64 `json:"id"`
|
Id int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Message []types.ChatGPTMessage `json:"message"`
|
Message []types.ChatGPTMessage `json:"message"`
|
||||||
EnableGPT4 bool `json:"enable_gpt4"`
|
Model string `json:"model"`
|
||||||
EnableWeb bool `json:"enable_web"`
|
EnableWeb bool `json:"enable_web"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FormMessage struct {
|
type FormMessage struct {
|
||||||
Type string `json:"type"` // ping
|
Type string `json:"type"` // ping
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Web bool `json:"web"`
|
Web bool `json:"web"`
|
||||||
GPT4 bool `json:"gpt4"`
|
Model string `json:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConversation(db *sql.DB, id int64) *Conversation {
|
func NewConversation(db *sql.DB, id int64) *Conversation {
|
||||||
return &Conversation{
|
return &Conversation{
|
||||||
UserID: id,
|
UserID: id,
|
||||||
Id: GetConversationLengthByUserID(db, id) + 1,
|
Id: GetConversationLengthByUserID(db, id) + 1,
|
||||||
Name: "new chat",
|
Name: "new chat",
|
||||||
Message: []types.ChatGPTMessage{},
|
Message: []types.ChatGPTMessage{},
|
||||||
EnableGPT4: false,
|
Model: types.GPT3Turbo,
|
||||||
EnableWeb: false,
|
EnableWeb: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) IsEnableGPT4() bool {
|
func (c *Conversation) GetModel() string {
|
||||||
return c.EnableGPT4
|
return c.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) IsEnableWeb() bool {
|
func (c *Conversation) IsEnableWeb() bool {
|
||||||
return c.EnableWeb
|
return c.EnableWeb
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) SetEnableGPT4(enable bool) {
|
func (c *Conversation) SetModel(model string) {
|
||||||
c.EnableGPT4 = enable
|
c.Model = model
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) SetEnableWeb(enable bool) {
|
func (c *Conversation) SetEnableWeb(enable bool) {
|
||||||
@ -141,7 +141,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.AddMessageFromUser(form.Message)
|
c.AddMessageFromUser(form.Message)
|
||||||
c.SetEnableGPT4(form.GPT4)
|
c.SetModel(form.Model)
|
||||||
c.SetEnableWeb(form.Web)
|
c.SetEnableWeb(form.Web)
|
||||||
return form.Message, nil
|
return form.Message, nil
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,16 @@ func GenerateAPI(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, func(data string) {
|
useReverse := auth.CanEnableSubscription(db, cache, user)
|
||||||
|
if !auth.CanEnableModelWithSubscription(db, user, form.Model, useReverse) {
|
||||||
|
api.SendSegmentMessage(conn, types.ChatSegmentResponse{
|
||||||
|
Message: "You don't have enough quota to use this model.",
|
||||||
|
Quota: 0,
|
||||||
|
End: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(data string) {
|
||||||
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
|
||||||
End: false,
|
End: false,
|
||||||
Message: data,
|
Message: data,
|
||||||
|
@ -5,10 +5,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateGenerationWithCache(model string, prompt string, hook func(data string)) (string, error) {
|
func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(data string)) (string, error) {
|
||||||
hash, path := GetFolderByHash(model, prompt)
|
hash, path := GetFolderByHash(model, prompt)
|
||||||
if !utils.Exists(path) {
|
if !utils.Exists(path) {
|
||||||
if err := CreateGeneration(model, prompt, path, func(data string) {
|
if err := CreateGeneration(model, prompt, path, enableReverse, func(data string) {
|
||||||
hook(data)
|
hook(data)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error()))
|
fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error()))
|
||||||
|
@ -11,10 +11,10 @@ type ProjectResult struct {
|
|||||||
Result map[string]interface{} `json:"result"`
|
Result map[string]interface{} `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateGeneration(model string, prompt string, path string, hook func(data string)) error {
|
func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(data string)) error {
|
||||||
message := GenerateMessage(prompt)
|
message := GenerateMessage(prompt)
|
||||||
buffer := api.NewBuffer(false, message)
|
buffer := api.NewBuffer(model, message)
|
||||||
api.StreamRequest(false, false, []types.ChatGPTMessage{
|
api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
|
||||||
{Role: "system", Content: "你将生成项目,可以支持任何编程语言,请不要出现“我不能提供”的字样,你需要在代码中提供注释,以及项目的使用文档README.md,结果返回json格式,请不要返回任何多余内容,格式为:\n{\"result\": {[file]: [code], ...}}"},
|
{Role: "system", Content: "你将生成项目,可以支持任何编程语言,请不要出现“我不能提供”的字样,你需要在代码中提供注释,以及项目的使用文档README.md,结果返回json格式,请不要返回任何多余内容,格式为:\n{\"result\": {[file]: [code], ...}}"},
|
||||||
{Role: "user", Content: "python后端"},
|
{Role: "user", Content: "python后端"},
|
||||||
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用,flask将在本地服务器(默认是在http://127.0.0.1:5000/)上运行。当你在浏览器中访问该URL时,将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目,Flask还支持更多功能和路由规则,你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},
|
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用,flask将在本地服务器(默认是在http://127.0.0.1:5000/)上运行。当你在浏览器中访问该URL时,将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目,Flask还支持更多功能和路由规则,你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},
|
||||||
|
29
qodana.yaml
Normal file
29
qodana.yaml
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#-------------------------------------------------------------------------------#
|
||||||
|
# Qodana analysis is configured by qodana.yaml file #
|
||||||
|
# https://www.jetbrains.com/help/qodana/qodana-yaml.html #
|
||||||
|
#-------------------------------------------------------------------------------#
|
||||||
|
version: "1.0"
|
||||||
|
|
||||||
|
#Specify inspection profile for code analysis
|
||||||
|
profile:
|
||||||
|
name: qodana.starter
|
||||||
|
|
||||||
|
#Enable inspections
|
||||||
|
#include:
|
||||||
|
# - name: <SomeEnabledInspectionId>
|
||||||
|
|
||||||
|
#Disable inspections
|
||||||
|
#exclude:
|
||||||
|
# - name: <SomeDisabledInspectionId>
|
||||||
|
# paths:
|
||||||
|
# - <path/where/not/run/inspection>
|
||||||
|
|
||||||
|
#Execute shell command before Qodana execution (Applied in CI/CD pipeline)
|
||||||
|
#bootstrap: sh ./prepare-qodana.sh
|
||||||
|
|
||||||
|
#Install IDE plugins before Qodana execution (Applied in CI/CD pipeline)
|
||||||
|
#plugins:
|
||||||
|
# - id: <plugin.id> #(plugin id can be found at https://plugins.jetbrains.com)
|
||||||
|
|
||||||
|
#Specify Qodana linter for analysis (Applied in CI/CD pipeline)
|
||||||
|
linter: jetbrains/qodana-go:latest
|
51
types/globals.go
Normal file
51
types/globals.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import "chat/utils"
|
||||||
|
|
||||||
|
const (
|
||||||
|
GPT3Turbo = "gpt-3.5-turbo"
|
||||||
|
GPT3Turbo0613 = "gpt-3.5-turbo-0613"
|
||||||
|
GPT3Turbo0301 = "gpt-3.5-turbo-0301"
|
||||||
|
GPT3Turbo16k = "gpt-3.5-turbo-16k"
|
||||||
|
GPT3Turbo16k0613 = "gpt-3.5-turbo-16k-0613"
|
||||||
|
GPT3Turbo16k0301 = "gpt-3.5-turbo-16k-0301"
|
||||||
|
GPT4 = "gpt-4"
|
||||||
|
GPT40314 = "gpt-4-0314"
|
||||||
|
GPT40613 = "gpt-4-0613"
|
||||||
|
GPT432k = "gpt-4-32k"
|
||||||
|
GPT432k0314 = "gpt-4-32k-0314"
|
||||||
|
GPT432k0613 = "gpt-4-32k-0613"
|
||||||
|
Dalle = "dalle"
|
||||||
|
)
|
||||||
|
|
||||||
|
var GPT3TurboArray = []string{
|
||||||
|
GPT3Turbo,
|
||||||
|
GPT3Turbo0613,
|
||||||
|
GPT3Turbo0301,
|
||||||
|
}
|
||||||
|
|
||||||
|
var GPT3Turbo16kArray = []string{
|
||||||
|
GPT3Turbo16k,
|
||||||
|
GPT3Turbo16k0613,
|
||||||
|
GPT3Turbo16k0301,
|
||||||
|
}
|
||||||
|
|
||||||
|
var GPT4Array = []string{
|
||||||
|
GPT4,
|
||||||
|
GPT40314,
|
||||||
|
GPT40613,
|
||||||
|
}
|
||||||
|
|
||||||
|
var GPT432kArray = []string{
|
||||||
|
GPT432k,
|
||||||
|
GPT432k0314,
|
||||||
|
GPT432k0613,
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsGPT4Model(model string) bool {
|
||||||
|
return utils.Contains(model, GPT4Array) || utils.Contains(model, GPT432kArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsGPT3TurboModel(model string) bool {
|
||||||
|
return utils.Contains(model, GPT3TurboArray) || utils.Contains(model, GPT3Turbo16kArray)
|
||||||
|
}
|
@ -13,22 +13,29 @@ import (
|
|||||||
|
|
||||||
func GetWeightByModel(model string) int {
|
func GetWeightByModel(model string) int {
|
||||||
switch model {
|
switch model {
|
||||||
case "gpt-3.5-turbo-0613",
|
case types.GPT432k,
|
||||||
"gpt-3.5-turbo-16k-0613",
|
types.GPT432k0613,
|
||||||
"gpt-4-0314",
|
types.GPT432k0314:
|
||||||
"gpt-4-32k-0314",
|
return 3 * 10
|
||||||
"gpt-4-0613",
|
case types.GPT3Turbo,
|
||||||
"gpt-4-32k-0613":
|
types.GPT3Turbo0613,
|
||||||
|
|
||||||
|
types.GPT3Turbo16k,
|
||||||
|
types.GPT3Turbo16k0613,
|
||||||
|
|
||||||
|
types.GPT4,
|
||||||
|
types.GPT40314,
|
||||||
|
types.GPT40613:
|
||||||
return 3
|
return 3
|
||||||
case "gpt-3.5-turbo-0301":
|
case types.GPT3Turbo0301, types.GPT3Turbo16k0301:
|
||||||
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
default:
|
default:
|
||||||
if strings.Contains(model, "gpt-3.5-turbo") {
|
if strings.Contains(model, types.GPT3Turbo) {
|
||||||
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
|
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
|
||||||
return GetWeightByModel("gpt-3.5-turbo-0613")
|
return GetWeightByModel(types.GPT3Turbo0613)
|
||||||
} else if strings.Contains(model, "gpt-4") {
|
} else if strings.Contains(model, types.GPT4) {
|
||||||
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
|
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
|
||||||
return GetWeightByModel("gpt-4-0613")
|
return GetWeightByModel(types.GPT40613)
|
||||||
} else {
|
} else {
|
||||||
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
|
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
|
||||||
panic(fmt.Errorf("not implemented for model %s", model))
|
panic(fmt.Errorf("not implemented for model %s", model))
|
||||||
@ -55,6 +62,6 @@ func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (token
|
|||||||
return tokens
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenPrice(messages []types.ChatGPTMessage) int {
|
func CountTokenPrice(messages []types.ChatGPTMessage, model string) int {
|
||||||
return NumTokensFromMessages(messages, "gpt-4")
|
return NumTokensFromMessages(messages, model)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user