From 13c12956af473fc2921cf3598562ba27af17f752 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Tue, 26 Dec 2023 12:51:07 +0800 Subject: [PATCH] fix token billing error --- Dockerfile | 6 ++---- manager/chat.go | 12 ++++++++---- manager/completions.go | 4 ++-- manager/transhipment.go | 4 ++-- middleware/auth.go | 6 +++++- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/Dockerfile b/Dockerfile index faaafc4..b1d11e6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,9 +17,6 @@ RUN wget https://nodejs.org/dist/v16.14.0/node-v16.14.0-linux-x64.tar.xz && \ ENV PATH=$PATH:/usr/local/go/bin:/usr/local/node/bin -# Install npm -RUN npm install -g pnpm - # Copy source code COPY . . @@ -40,7 +37,8 @@ RUN go install && \ go build . # Build frontend -RUN cd /app && \ +RUN npm install -g pnpm && \ + cd /app && \ pnpm install && \ pnpm run build && \ rm -rf node_modules diff --git a/manager/chat.go b/manager/chat.go index 2137a48..74d80fc 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -17,13 +17,17 @@ import ( const defaultMessage = "Sorry, I don't understand. Please try again." const defaultQuotaMessage = "You don't have enough quota or you don't have permission to use this model. please [buy](/buy) or [subscribe](/subscribe) to get more." -func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool) { +func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) { db := utils.GetDBFromContext(c) quota := buffer.GetQuota() - if buffer.IsEmpty() || buffer.GetCharge().IsBillingType(globals.TimesBilling) { + if buffer.IsEmpty() { + return + } else if buffer.GetCharge().IsBillingType(globals.TimesBilling) && err != nil { + // billing type is times, but error occurred return } + // collect quota for tokens billing (though error occurred) or times billing if !uncountable && quota > 0 && user != nil { user.UseQuota(db, quota) } @@ -115,7 +119,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) auth.RevertSubscriptionUsage(db, cache, user, model) - CollectQuota(conn.GetCtx(), user, buffer, plan) + CollectQuota(conn.GetCtx(), user, buffer, plan, err) conn.Send(globals.ChatSegmentResponse{ Message: err.Error(), End: true, @@ -123,7 +127,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve return err.Error() } - CollectQuota(conn.GetCtx(), user, buffer, plan) + CollectQuota(conn.GetCtx(), user, buffer, plan, err) if buffer.IsEmpty() { conn.Send(globals.ChatSegmentResponse{ diff --git a/manager/completions.go b/manager/completions.go index f932482..6a03443 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -57,11 +57,11 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] admin.AnalysisRequest(model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, model) - CollectQuota(c, user, buffer, plan) + CollectQuota(c, user, buffer, plan, err) return err.Error(), 0 } - CollectQuota(c, user, buffer, plan) + CollectQuota(c, user, buffer, plan, err) SaveCacheData(c, &CacheProps{ Message: segment, diff --git a/manager/transhipment.go b/manager/transhipment.go index 6b93014..a1831b7 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -196,7 +196,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, return } - CollectQuota(c, user, buffer, plan) + CollectQuota(c, user, buffer, plan, err) c.JSON(http.StatusOK, TranshipmentResponse{ Id: fmt.Sprintf("chatcmpl-%s", id), Object: "chat.completion", @@ -266,7 +266,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st } partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil) - CollectQuota(c, user, buffer, plan) + CollectQuota(c, user, buffer, plan, err) close(partial) return }() diff --git a/middleware/auth.go b/middleware/auth.go index ba0a2a6..085501e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -78,7 +78,11 @@ func AuthMiddleware() gin.HandlerFunc { instance := ProcessAuthorization(c) if viper.GetBool("serve_static") { - path = strings.TrimPrefix(path, "/api") + if !strings.HasPrefix(path, "/api") { + return + } else { + path = strings.TrimPrefix(path, "/api") + } } db := utils.GetDBFromContext(c)