mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 13:30:13 +09:00
update token counter and close channel when buffer end
This commit is contained in:
parent
a8b1082261
commit
83133b526f
@ -139,9 +139,9 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
PromptTokens: int(buffer.CountInputToken()),
|
PromptTokens: buffer.CountInputToken(),
|
||||||
CompletionTokens: int(buffer.CountOutputToken()),
|
CompletionTokens: buffer.CountOutputToken(),
|
||||||
TotalTokens: int(buffer.CountToken()),
|
TotalTokens: buffer.CountToken(),
|
||||||
},
|
},
|
||||||
Quota: buffer.GetQuota(),
|
Quota: buffer.GetQuota(),
|
||||||
})
|
})
|
||||||
@ -164,9 +164,9 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
PromptTokens: int(buffer.CountInputToken()),
|
PromptTokens: buffer.CountInputToken(),
|
||||||
CompletionTokens: int(buffer.CountOutputToken()),
|
CompletionTokens: buffer.CountOutputToken(),
|
||||||
TotalTokens: int(buffer.CountToken()),
|
TotalTokens: buffer.CountToken(),
|
||||||
},
|
},
|
||||||
Quota: buffer.GetQuota(),
|
Quota: buffer.GetQuota(),
|
||||||
}
|
}
|
||||||
@ -183,16 +183,18 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
|
|||||||
Reversible: reversible && globals.IsGPT4Model(form.Model),
|
Reversible: reversible && globals.IsGPT4Model(form.Model),
|
||||||
Token: form.MaxTokens,
|
Token: form.MaxTokens,
|
||||||
}, func(data string) error {
|
}, func(data string) error {
|
||||||
channel <- getStreamTranshipmentForm(id, created, form, data, buffer, false)
|
channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
|
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
|
||||||
CollectQuota(c, user, buffer.GetQuota(), reversible)
|
CollectQuota(c, user, buffer.GetQuota(), reversible)
|
||||||
|
close(channel)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
|
channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
|
||||||
CollectQuota(c, user, buffer.GetQuota(), reversible)
|
CollectQuota(c, user, buffer.GetQuota(), reversible)
|
||||||
|
close(channel)
|
||||||
return
|
return
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -201,6 +203,8 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
|
|||||||
c.SSEvent("message", resp)
|
c.SSEvent("message", resp)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.Write([]byte("[DATA: DONE]"))
|
||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -80,14 +80,14 @@ func (b *Buffer) ReadHistory() []globals.Message {
|
|||||||
return b.History
|
return b.History
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) CountInputToken() float32 {
|
func (b *Buffer) CountInputToken() int {
|
||||||
return CountInputToken(b.Model, b.ReadHistory())
|
return GetWeightByModel(b.Model) * NumTokensFromMessages(b.History, b.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) CountOutputToken() float32 {
|
func (b *Buffer) CountOutputToken() int {
|
||||||
return CountOutputToken(b.Model, b.ReadTimes())
|
return b.ReadTimes() * GetWeightByModel(b.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) CountToken() float32 {
|
func (b *Buffer) CountToken() int {
|
||||||
return b.CountInputToken() + b.CountOutputToken()
|
return b.CountInputToken() + b.CountOutputToken()
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user