mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 13:30:13 +09:00
Implemented feature: chatgpt api stream real time reception
This commit is contained in:
parent
18722a6567
commit
e6673e1c36
60
api/chat.go
Normal file
60
api/chat.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChatAPI(c *gin.Context) {
|
||||||
|
// websocket connection
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": false,
|
||||||
|
"message": "",
|
||||||
|
"reason": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func(conn *websocket.Conn) {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}(conn)
|
||||||
|
for {
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var form map[string]interface{}
|
||||||
|
if err := json.Unmarshal(message, &form); err == nil {
|
||||||
|
message := form["message"].(string)
|
||||||
|
StreamRequest("gpt-3.5-turbo-16k", []ChatGPTMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: message,
|
||||||
|
},
|
||||||
|
}, 250, func(resp string) {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"message": resp,
|
||||||
|
"end": false,
|
||||||
|
})
|
||||||
|
_ = conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
})
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"message": "",
|
||||||
|
"end": true,
|
||||||
|
})
|
||||||
|
_ = conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -42,7 +42,7 @@ func StreamRequest(model string, messages []ChatGPTMessage, token int, callback
|
|||||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest("POST", viper.GetString("openai.user_endpoint")+"/chat/completions", utils.ConvertBody(ChatGPTRequest{
|
req, err := http.NewRequest("POST", viper.GetString("openai.anonymous_endpoint")+"/chat/completions", utils.ConvertBody(ChatGPTRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
MaxToken: token,
|
MaxToken: token,
|
||||||
@ -53,7 +53,7 @@ func StreamRequest(model string, messages []ChatGPTMessage, token int, callback
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+viper.GetString("openai.user"))
|
req.Header.Set("Authorization", "Bearer "+viper.GetString("openai.anonymous"))
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -67,7 +67,7 @@ func StreamRequest(model string, messages []ChatGPTMessage, token int, callback
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 20480)
|
||||||
n, err := res.Body.Read(buf)
|
n, err := res.Body.Read(buf)
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import Login from "./components/icons/login.vue";
|
import Login from "./components/icons/login.vue";
|
||||||
import {auth} from "./assets/script/auth";
|
import {auth, username} from "./assets/script/auth";
|
||||||
|
|
||||||
function goto() {
|
function goto() {
|
||||||
window.location.href = "https://deeptrain.net/login?app=chatnio";
|
window.location.href = "https://deeptrain.net/login?app=chatnio";
|
||||||
@ -16,7 +16,8 @@ function goto() {
|
|||||||
</div>
|
</div>
|
||||||
<div class="grow" />
|
<div class="grow" />
|
||||||
<div class="user" v-if="auth">
|
<div class="user" v-if="auth">
|
||||||
|
<img class="avatar" src="https://zmh-program.site/avatar/zmh-program.webp" alt="">
|
||||||
|
<span class="username">{{ username }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="login" v-else>
|
<div class="login" v-else>
|
||||||
<button @click="goto">
|
<button @click="goto">
|
||||||
@ -59,6 +60,31 @@ aside {
|
|||||||
width: max-content;
|
width: max-content;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.user {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
margin: 28px auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.avatar {
|
||||||
|
width: 36px;
|
||||||
|
height: 36px;
|
||||||
|
border-radius: 8px;
|
||||||
|
background: var(--card-input);
|
||||||
|
border: 1px solid var(--card-input-border);
|
||||||
|
transition: .5s;
|
||||||
|
flex-shrink: 0;
|
||||||
|
user-select: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.username {
|
||||||
|
user-select: none;
|
||||||
|
font-size: 18px;
|
||||||
|
padding: 4px;
|
||||||
|
margin: 0 4px;
|
||||||
|
color: var(--card-text);
|
||||||
|
}
|
||||||
|
|
||||||
.grow {
|
.grow {
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
}
|
}
|
||||||
@ -130,6 +156,16 @@ aside {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@media screen and (max-width: 340px) {
|
||||||
|
.username {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.avatar {
|
||||||
|
margin-right: 16px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@media screen and (max-width: 600px) {
|
@media screen and (max-width: 600px) {
|
||||||
.card {
|
.card {
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
@ -137,6 +173,10 @@ aside {
|
|||||||
max-height: calc(100% - 24px);
|
max-height: calc(100% - 24px);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.username {
|
||||||
|
margin-right: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
.logo span {
|
.logo span {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ import axios from "axios";
|
|||||||
|
|
||||||
export const auth = ref<boolean | undefined>(undefined);
|
export const auth = ref<boolean | undefined>(undefined);
|
||||||
export const token = ref(localStorage.getItem("token") || "");
|
export const token = ref(localStorage.getItem("token") || "");
|
||||||
|
export const username = ref("");
|
||||||
|
|
||||||
watch(token, () => {
|
watch(token, () => {
|
||||||
localStorage.setItem("token", token.value);
|
localStorage.setItem("token", token.value);
|
||||||
@ -15,6 +16,7 @@ export async function awaitUtilSetup(): Promise<any> {
|
|||||||
if (!token.value) return (auth.value = false);
|
if (!token.value) return (auth.value = false);
|
||||||
try {
|
try {
|
||||||
const resp = await axios.post("/state");
|
const resp = await axios.post("/state");
|
||||||
|
username.value = resp.data.user;
|
||||||
auth.value = resp.data.status;
|
auth.value = resp.data.status;
|
||||||
} catch {
|
} catch {
|
||||||
auth.value = false;
|
auth.value = false;
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import {nextTick, reactive, ref} from "vue";
|
import {nextTick, reactive, ref} from "vue";
|
||||||
import type { Ref } from "vue";
|
import type { Ref } from "vue";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
import {auth} from "./auth";
|
||||||
|
import {ws_api} from "./conf";
|
||||||
|
|
||||||
type Message = {
|
type Message = {
|
||||||
content: string;
|
content: string;
|
||||||
@ -9,12 +11,64 @@ type Message = {
|
|||||||
stamp: number;
|
stamp: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamMessage = {
|
||||||
|
message: string;
|
||||||
|
end: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class Connection {
|
||||||
|
protected connection: WebSocket | undefined;
|
||||||
|
protected callback?: (message: StreamMessage) => void;
|
||||||
|
public state: boolean;
|
||||||
|
|
||||||
|
public constructor() {
|
||||||
|
this.state = false;
|
||||||
|
this.init();
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(): void {
|
||||||
|
this.connection = new WebSocket(ws_api + "/chat");
|
||||||
|
this.state = false;
|
||||||
|
this.connection.onopen = () => {
|
||||||
|
this.state = true;
|
||||||
|
}
|
||||||
|
this.connection.onclose = () => {
|
||||||
|
this.state = false;
|
||||||
|
setTimeout(() => {
|
||||||
|
this.init();
|
||||||
|
}, 3000);
|
||||||
|
}
|
||||||
|
this.connection.onmessage = (event) => {
|
||||||
|
const message = JSON.parse(event.data);
|
||||||
|
this.callback && this.callback(message as StreamMessage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public send(content: Record<string, any>): boolean {
|
||||||
|
if (!this.state || !this.connection) {
|
||||||
|
console.debug("Connection not ready");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
this.connection.send(JSON.stringify(content));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public close(): void {
|
||||||
|
if (!this.connection) return;
|
||||||
|
this.connection.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public setCallback(callback: (message: StreamMessage) => void): void {
|
||||||
|
this.callback = callback;
|
||||||
|
}
|
||||||
|
}
|
||||||
export class Conversation {
|
export class Conversation {
|
||||||
id: number;
|
id: number;
|
||||||
messages: Message[];
|
messages: Message[];
|
||||||
len: Ref<number>;
|
len: Ref<number>;
|
||||||
state: Ref<boolean>;
|
state: Ref<boolean>;
|
||||||
refresh: () => void;
|
refresh: () => void;
|
||||||
|
connection: Connection | undefined;
|
||||||
|
|
||||||
public constructor(id: number, refresh: () => void) {
|
public constructor(id: number, refresh: () => void) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
@ -22,9 +76,32 @@ export class Conversation {
|
|||||||
this.state = ref(false);
|
this.state = ref(false);
|
||||||
this.len = ref(0);
|
this.len = ref(0);
|
||||||
this.refresh = refresh;
|
this.refresh = refresh;
|
||||||
|
if (auth.value) this.connection = new Connection();
|
||||||
}
|
}
|
||||||
|
|
||||||
public async send(content: string): Promise<void> {
|
public async send(content: string): Promise<void> {
|
||||||
|
return await (auth.value ? this.sendAuthenticated(content) : this.sendAnonymous(content));
|
||||||
|
}
|
||||||
|
|
||||||
|
public async sendAuthenticated(content: string): Promise<void> {
|
||||||
|
this.state.value = true;
|
||||||
|
this.addMessageFromUser(content);
|
||||||
|
let message = ref(""), end = ref(false);
|
||||||
|
this.connection?.setCallback((res: StreamMessage) => {
|
||||||
|
message.value += res.message;
|
||||||
|
end.value = res.end;
|
||||||
|
})
|
||||||
|
this.addDynamicMessageFromAI(message, end);
|
||||||
|
const status = this.connection?.send({
|
||||||
|
message: content,
|
||||||
|
});
|
||||||
|
if (!status) {
|
||||||
|
this.addMessageFromAI("网络错误,请稍后再试");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async sendAnonymous(content: string): Promise<void> {
|
||||||
this.state.value = true;
|
this.state.value = true;
|
||||||
this.addMessageFromUser(content);
|
this.addMessageFromUser(content);
|
||||||
try {
|
try {
|
||||||
@ -68,6 +145,16 @@ export class Conversation {
|
|||||||
this.typingEffect(this.len.value - 1, content);
|
this.typingEffect(this.len.value - 1, content);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public addDynamicMessageFromAI(content: Ref<string>, end: Ref<boolean>): void {
|
||||||
|
this.addMessage({
|
||||||
|
content: "",
|
||||||
|
role: "bot",
|
||||||
|
time: new Date().toLocaleTimeString(),
|
||||||
|
stamp: new Date().getTime(),
|
||||||
|
})
|
||||||
|
this.dynamicTypingEffect(this.len.value - 1, content, end);
|
||||||
|
}
|
||||||
|
|
||||||
public typingEffect(index: number, content: string): void {
|
public typingEffect(index: number, content: string): void {
|
||||||
let cursor = 0;
|
let cursor = 0;
|
||||||
const interval = setInterval(() => {
|
const interval = setInterval(() => {
|
||||||
@ -81,6 +168,22 @@ export class Conversation {
|
|||||||
}, 35);
|
}, 35);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public dynamicTypingEffect(index: number, content: Ref<string>, end: Ref<boolean>): void {
|
||||||
|
let cursor = 0;
|
||||||
|
const interval = setInterval(() => {
|
||||||
|
if (end.value && cursor >= content.value.length) {
|
||||||
|
this.messages[index].content = content.value;
|
||||||
|
this.state.value = false;
|
||||||
|
clearInterval(interval);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (cursor >= content.value.length) return;
|
||||||
|
cursor++;
|
||||||
|
this.messages[index].content = content.value.substring(0, cursor);
|
||||||
|
this.refresh();
|
||||||
|
}, 35);
|
||||||
|
}
|
||||||
|
|
||||||
public getMessages(): Message[] {
|
public getMessages(): Message[] {
|
||||||
return this.messages;
|
return this.messages;
|
||||||
}
|
}
|
||||||
|
1
go.mod
1
go.mod
@ -22,6 +22,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/gorilla/websocket v1.5.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -150,6 +150,8 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
|||||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||||
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
|
2
main.go
2
main.go
@ -23,8 +23,10 @@ func main() {
|
|||||||
app.Use(auth.Middleware())
|
app.Use(auth.Middleware())
|
||||||
|
|
||||||
app.POST("/anonymous", api.AnonymousAPI)
|
app.POST("/anonymous", api.AnonymousAPI)
|
||||||
|
app.GET("/chat", api.ChatAPI)
|
||||||
app.POST("/login", auth.LoginAPI)
|
app.POST("/login", auth.LoginAPI)
|
||||||
app.POST("/state", auth.StateAPI)
|
app.POST("/state", auth.StateAPI)
|
||||||
|
|
||||||
}
|
}
|
||||||
if viper.GetBool("debug") {
|
if viper.GetBool("debug") {
|
||||||
gin.SetMode(gin.DebugMode)
|
gin.SetMode(gin.DebugMode)
|
||||||
|
Loading…
Reference in New Issue
Block a user