fix token billing error

This commit is contained in:
Zhang Minghan 2023-12-26 12:51:07 +08:00
parent 6487a486ba
commit 13c12956af
5 changed files with 19 additions and 13 deletions

View File

@ -17,9 +17,6 @@ RUN wget https://nodejs.org/dist/v16.14.0/node-v16.14.0-linux-x64.tar.xz && \
ENV PATH=$PATH:/usr/local/go/bin:/usr/local/node/bin ENV PATH=$PATH:/usr/local/go/bin:/usr/local/node/bin
# Install npm
RUN npm install -g pnpm
# Copy source code # Copy source code
COPY . . COPY . .
@ -40,7 +37,8 @@ RUN go install && \
go build . go build .
# Build frontend # Build frontend
RUN cd /app && \ RUN npm install -g pnpm && \
cd /app && \
pnpm install && \ pnpm install && \
pnpm run build && \ pnpm run build && \
rm -rf node_modules rm -rf node_modules

View File

@ -17,13 +17,17 @@ import (
const defaultMessage = "Sorry, I don't understand. Please try again." const defaultMessage = "Sorry, I don't understand. Please try again."
const defaultQuotaMessage = "You don't have enough quota or you don't have permission to use this model. please [buy](/buy) or [subscribe](/subscribe) to get more." const defaultQuotaMessage = "You don't have enough quota or you don't have permission to use this model. please [buy](/buy) or [subscribe](/subscribe) to get more."
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool) { func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
db := utils.GetDBFromContext(c) db := utils.GetDBFromContext(c)
quota := buffer.GetQuota() quota := buffer.GetQuota()
if buffer.IsEmpty() || buffer.GetCharge().IsBillingType(globals.TimesBilling) { if buffer.IsEmpty() {
return
} else if buffer.GetCharge().IsBillingType(globals.TimesBilling) && err != nil {
// billing type is times, but error occurred
return return
} }
// collect quota for tokens billing (though error occurred) or times billing
if !uncountable && quota > 0 && user != nil { if !uncountable && quota > 0 && user != nil {
user.UseQuota(db, quota) user.UseQuota(db, quota)
} }
@ -115,7 +119,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
auth.RevertSubscriptionUsage(db, cache, user, model) auth.RevertSubscriptionUsage(db, cache, user, model)
CollectQuota(conn.GetCtx(), user, buffer, plan) CollectQuota(conn.GetCtx(), user, buffer, plan, err)
conn.Send(globals.ChatSegmentResponse{ conn.Send(globals.ChatSegmentResponse{
Message: err.Error(), Message: err.Error(),
End: true, End: true,
@ -123,7 +127,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
return err.Error() return err.Error()
} }
CollectQuota(conn.GetCtx(), user, buffer, plan) CollectQuota(conn.GetCtx(), user, buffer, plan, err)
if buffer.IsEmpty() { if buffer.IsEmpty() {
conn.Send(globals.ChatSegmentResponse{ conn.Send(globals.ChatSegmentResponse{

View File

@ -57,11 +57,11 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
admin.AnalysisRequest(model, buffer, err) admin.AnalysisRequest(model, buffer, err)
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, model) auth.RevertSubscriptionUsage(db, cache, user, model)
CollectQuota(c, user, buffer, plan) CollectQuota(c, user, buffer, plan, err)
return err.Error(), 0 return err.Error(), 0
} }
CollectQuota(c, user, buffer, plan) CollectQuota(c, user, buffer, plan, err)
SaveCacheData(c, &CacheProps{ SaveCacheData(c, &CacheProps{
Message: segment, Message: segment,

View File

@ -196,7 +196,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
return return
} }
CollectQuota(c, user, buffer, plan) CollectQuota(c, user, buffer, plan, err)
c.JSON(http.StatusOK, TranshipmentResponse{ c.JSON(http.StatusOK, TranshipmentResponse{
Id: fmt.Sprintf("chatcmpl-%s", id), Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion", Object: "chat.completion",
@ -266,7 +266,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
} }
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil) partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
CollectQuota(c, user, buffer, plan) CollectQuota(c, user, buffer, plan, err)
close(partial) close(partial)
return return
}() }()

View File

@ -78,7 +78,11 @@ func AuthMiddleware() gin.HandlerFunc {
instance := ProcessAuthorization(c) instance := ProcessAuthorization(c)
if viper.GetBool("serve_static") { if viper.GetBool("serve_static") {
path = strings.TrimPrefix(path, "/api") if !strings.HasPrefix(path, "/api") {
return
} else {
path = strings.TrimPrefix(path, "/api")
}
} }
db := utils.GetDBFromContext(c) db := utils.GetDBFromContext(c)