implement online chatgpt feature

This commit is contained in:
Zhang Minghan 2023-08-09 17:54:38 +08:00
parent 47dba52cc3
commit ed8e78f823
11 changed files with 199 additions and 14 deletions

View File

@ -21,14 +21,14 @@ func GetAnonymousResponse(message string) (string, error) {
"Content-Type": "application/json",
"Authorization": "Bearer " + viper.GetString("openai.anonymous"),
}, types.ChatGPTRequest{
Model: "gpt-3.5-turbo-16k",
Messages: []types.ChatGPTMessage{
Model: "gpt-3.5-turbo",
Messages: ChatWithWeb([]types.ChatGPTMessage{
{
Role: "user",
Content: message,
},
},
MaxToken: 250,
}, message),
MaxToken: 1000,
})
if err != nil {
return "", err

View File

@ -63,8 +63,8 @@ func ChatAPI(c *gin.Context) {
if err != nil {
return
}
if _, err := instance.AddMessageFromUserForm(message); err == nil {
StreamRequest("gpt-3.5-turbo", instance.GetMessageSegment(5), 500, func(resp string) {
if msg, err := instance.AddMessageFromUserForm(message); err == nil {
StreamRequest("gpt-3.5-turbo", ChatWithWeb(instance.GetMessageSegment(5), msg), 1500, func(resp string) {
data, _ := json.Marshal(map[string]interface{}{
"message": resp,
"end": false,

View File

@ -5,7 +5,6 @@ import (
"chat/utils"
"crypto/tls"
"encoding/json"
"fmt"
"github.com/spf13/viper"
"io"
"log"
@ -23,13 +22,20 @@ func processLine(buf []byte) []string {
array := strings.Split(data, "\n")
resp := make([]string, 0)
for _, item := range array {
item = fmt.Sprintf("{%s}", strings.TrimSpace(item))
item = strings.TrimSpace(item)
if !strings.HasPrefix(item, "{") {
item = "{" + item
}
if !strings.HasSuffix(item, "}}") {
item = item + "}"
}
if item == "{data: [DONE]}" {
break
}
var form types.ChatGPTStreamResponse
if err := json.Unmarshal([]byte(item), &form); err != nil {
log.Fatal(err)
log.Println(item, err)
}
choices := form.Data.Choices
if len(choices) > 0 {
@ -50,7 +56,7 @@ func StreamRequest(model string, messages []types.ChatGPTMessage, token int, cal
Stream: true,
}))
if err != nil {
log.Fatal(err)
return
}
req.Header.Set("Content-Type", "application/json")

16
api/utils.go Normal file
View File

@ -0,0 +1,16 @@
package api
import (
"chat/types"
"chat/utils"
"chat/web"
"fmt"
"time"
)
func ChatWithWeb(message []types.ChatGPTMessage, current string) []types.ChatGPTMessage {
return utils.Insert(message, 0, types.ChatGPTMessage{
Role: "system",
Content: fmt.Sprintf("你将扮演AI问答助手你的知识库不是截止于2021年9月而是可以实时联网的你可以提供实时联网的知识库。当前时间: %s, 你的知识库:%s", time.Now().Format("2006-01-02 15:04:05"), web.SearchBing(current)),
})
}

View File

@ -74,8 +74,8 @@ function toggle(n: boolean) {
box-shadow: 0 0 16px var(--card-shadow);
width: calc(100% - 32px);
height: 100%;
max-width: 1000px;
max-height: 600px;
max-width: 1200px;
max-height: 650px;
}
aside {

View File

@ -85,7 +85,6 @@ export class Conversation {
}
public notReady(): boolean {
console.log(Boolean(auth.value && !this.connection?.state))
return Boolean(auth.value && !this.connection?.state);
}

View File

@ -3,6 +3,7 @@ package conversation
import (
"chat/types"
"chat/utils"
"errors"
)
type Conversation struct {
@ -75,11 +76,25 @@ func (c *Conversation) AddMessageFromSystem(message string) {
})
}
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
func GetMessage(data []byte) (string, error) {
form, err := utils.Unmarshal[FormMessage](data)
if err != nil {
return "", err
}
if len(form.Message) == 0 {
return "", errors.New("message is empty")
}
return form.Message, nil
}
func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
form, err := utils.Unmarshal[FormMessage](data)
if err != nil {
return "", err
} else if len(form.Message) == 0 {
return "", errors.New("message is empty")
}
c.Message = append(c.Message, types.ChatGPTMessage{
Role: "user",
Content: form.Message,

View File

@ -1,5 +1,7 @@
package utils
import "fmt"
func Contains[T comparable](value T, slice []T) bool {
for _, item := range slice {
if item == value {
@ -8,3 +10,22 @@ func Contains[T comparable](value T, slice []T) bool {
}
return false
}
func TryGet[T any](arr []T, index int) T {
if index >= len(arr) {
return arr[0]
}
return arr[index]
}
func Debug[T any](v T) T {
fmt.Println(v)
return v
}
func Insert[T any](arr []T, index int, value T) []T {
arr = append(arr, value)
copy(arr[index+1:], arr[index:])
arr[index] = value
return arr
}

View File

@ -31,11 +31,42 @@ func Http(uri string, method string, ptr interface{}, headers map[string]string,
return nil
}
func HttpRaw(uri string, method string, headers map[string]string, body io.Reader) (data []byte, err error) {
req, err := http.NewRequest(method, uri, body)
if err != nil {
return nil, err
}
for key, value := range headers {
req.Header.Set(key, value)
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if data, err = io.ReadAll(resp.Body); err != nil {
return nil, err
}
return data, nil
}
func Get(uri string, headers map[string]string) (data interface{}, err error) {
err = Http(uri, http.MethodGet, &data, headers, nil)
return data, err
}
func GetRaw(uri string, headers map[string]string) (data string, err error) {
buffer, err := HttpRaw(uri, http.MethodGet, headers, nil)
if err != nil {
return "", err
}
return string(buffer), nil
}
func Post(uri string, headers map[string]string, body interface{}) (data interface{}, err error) {
err = Http(uri, http.MethodPost, &data, headers, ConvertBody(body))
return data, err

68
web/parser.go Normal file
View File

@ -0,0 +1,68 @@
package web
import (
"chat/utils"
"golang.org/x/net/html"
"regexp"
"strings"
)
var unexpected = []string{
"<cite>",
"<span class=\"sb_count\">",
"<div class=\"ntf_label toggle_label nt_tit\" id=\"ntf_newtabfil_label\">",
"<img role=\"presentation\"",
}
func ParseBing(source string) string {
body := SplitPagination(GetMainBody(source))
res := strings.Join(GetContent(body), " ")
return html.UnescapeString(res)
}
func GetMainBody(html string) string {
suf := utils.TryGet(strings.Split(html, "<main aria-label=\"搜索结果\">"), 1)
return strings.Split(suf, "</main>")[0]
}
func SplitPagination(html string) string {
pre := strings.Split(html, "<li class=\"b_msg b_canvas\">")[0]
return utils.TryGet(strings.Split(pre, "<div class=\"ntf_label toggle_label nt_tit\" id=\"ntf_newtabfil_label\">在新选项卡中打开链接</div>"), 1)
}
func GetContent(html string) []string {
re := regexp.MustCompile(`>([^<]+)<`)
matches := re.FindAllString(html, -1)
return FilterContent(matches)
}
func IsExpected(data string) bool {
if IsLink(data) {
return false
}
for _, str := range unexpected {
if strings.HasPrefix(data, str) {
return false
}
}
return true
}
func IsLink(input string) bool {
re := regexp.MustCompile(`^(https?|ftp):\/\/[^\s/$.?#].\S*$`)
return re.MatchString(input)
}
func FilterContent(matches []string) []string {
res := make([]string, 0)
for _, match := range matches {
source := strings.TrimSpace(match[1 : len(match)-1])
if len(source) > 0 && IsExpected(source) {
res = append(res, source)
}
}
return res
}

29
web/search.go Normal file
View File

@ -0,0 +1,29 @@
package web
import (
"chat/utils"
"net/url"
)
func GetBingUrl(q string) string {
return "https://cn.bing.com/search?q=" + url.QueryEscape(q)
}
func RequestWithUA(url string) string {
data, err := utils.GetRaw(url, map[string]string{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/116.0",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
})
if err != nil {
return ""
}
return data
}
func SearchBing(q string) string {
uri := GetBingUrl(q)
data := RequestWithUA(uri)
return ParseBing(data)
}