mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
224 lines
4.9 KiB
Go
224 lines
4.9 KiB
Go
package utils
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"chat/globals"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"runtime/debug"
|
|
"strings"
|
|
)
|
|
|
|
// EventScannerProps describes the streaming HTTP request that EventScanner
// issues and how the events in its response are delivered.
type EventScannerProps struct {
	// Method and Uri are handed to http.NewRequest as the request method and URL.
	Method, Uri string

	// Headers are applied to the outgoing request through fillHeaders.
	Headers map[string]string

	// Body is turned into the request body through ConvertBody.
	Body interface{}

	// Callback receives every event read from the response stream; returning
	// a non-nil error stops the scan.
	Callback func(string) error

	// FullSSE selects the full event format (optional "event:" line plus the
	// "data:" line) instead of passing only the payload of each "data:" line.
	FullSSE bool
}
|
|
|
|
// EventScannerError is returned by EventScanner when a request or its event
// stream could not be completed.
type EventScannerError struct {
	// Error carries the underlying error describing the failure.
	Error error

	// Body holds the response body when the server answered with a status
	// code of 400 or above; it is left empty for every other failure.
	Body string
}
|
|
|
|
// getErrorBody reads the whole body of resp and returns it as a string.
// It yields "" when resp is nil or when reading the body fails (any partially
// read content is discarded in that case).
func getErrorBody(resp *http.Response) string {
	if resp != nil {
		content, readErr := io.ReadAll(resp.Body)
		if readErr == nil {
			return string(content)
		}
	}
	return ""
}
|
|
|
|
// EventScanner sends the HTTP request described by props and consumes the
// response as a server-sent event stream, handing each event to
// props.Callback.
//
// When props.FullSSE is true the response is processed by processFullSSE,
// otherwise by processLegacySSE. The optional config values are forwarded to
// newClient (presumably proxy settings — confirm against newClient).
//
// It returns nil on success and a non-nil *EventScannerError when props is
// nil, the request cannot be built or sent, the server answers with a status
// code >= 400 (Body then holds the response body), the stream processor
// reports an error, or a panic occurs while requesting or streaming.
func EventScanner(props *EventScannerProps, config ...globals.ProxyConfig) (result *EventScannerError) {
	if props == nil {
		return &EventScannerError{Error: fmt.Errorf("event scanner: props must not be nil")}
	}

	// Panic recovery. The callback runs synchronously on this goroutine, so
	// panics raised inside it are caught here too. The named result is set so
	// the caller sees a failure instead of a nil (success) return.
	defer func() {
		if r := recover(); r != nil {
			stack := debug.Stack()
			globals.Warn(fmt.Sprintf("event source panic: %v (uri: %s, method: %s)\n%s", r, props.Uri, props.Method, stack))
			result = &EventScannerError{Error: fmt.Errorf("event source panic: %v", r)}
		}
	}()

	if globals.DebugMode {
		globals.Debug(fmt.Sprintf("[sse] event source: %s %s\nheaders: %v\nbody: %v", props.Method, props.Uri, Marshal(props.Headers), Marshal(props.Body)))
	}

	client := newClient(config)
	req, err := http.NewRequest(props.Method, props.Uri, ConvertBody(props.Body))
	if err != nil {
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] failed to create request: %s", err))
		}
		return &EventScannerError{Error: err}
	}

	fillHeaders(req, props.Headers)

	resp, err := client.Do(req)
	if err != nil {
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] failed to send request: %s", err))
		}
		return &EventScannerError{Error: err}
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		// Error responses are not event streams; capture the body so the
		// caller can surface the upstream error message.
		body := getErrorBody(resp)
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] request failed with status: %s\nresponse: %s", resp.Status, body))
		}
		return &EventScannerError{
			Error: fmt.Errorf("request failed with status code: %d", resp.StatusCode),
			Body:  body,
		}
	}

	if props.FullSSE {
		return processFullSSE(resp.Body, props.Callback)
	}
	return processLegacySSE(resp.Body, props.Callback)
}
|
|
|
|
// processFullSSE reads a server-sent event stream from body and forwards each
// complete event to callback, re-serialized as "event: <type>\ndata: <data>".
// The "event:" line is omitted when the event has no type.
//
// Events are terminated by blank lines. Only the "event:" and "data:" fields
// are recognized; other lines are ignored. If an event carries several
// "data:" lines, only the last one is kept. Events whose data starts with the
// "[DONE]" sentinel are dropped. A final event that is not followed by a blank
// line is still delivered once the stream ends cleanly.
//
// If callback returns an error, body is closed to abort the stream and that
// error is returned. A read error on the stream is returned as well, in which
// case any pending (possibly truncated) event is not delivered.
func processFullSSE(body io.ReadCloser, callback func(string) error) *EventScannerError {
	// Allow lines larger than bufio.Scanner's default 64 KiB token limit;
	// longer lines would otherwise stop the scan with bufio.ErrTooLong.
	const maxLineSize = 16 * 1024 * 1024

	scanner := bufio.NewScanner(body)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// Fields of the event currently being assembled.
	var eventType, eventData string

	// dispatch delivers the pending event (if it has data) and clears the
	// pending fields either way. label only affects the debug log line.
	dispatch := func(label string) *EventScannerError {
		typ, data := eventType, eventData
		eventType, eventData = "", ""
		if data == "" {
			return nil
		}

		event := "data: " + data
		if typ != "" {
			event = "event: " + typ + "\n" + event
		}

		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse-full] %s: %s", label, event))
		}

		if cbErr := callback(event); cbErr != nil {
			// Abort the stream; keep the callback error as the result.
			if closeErr := body.Close(); closeErr != nil {
				globals.Debug(fmt.Sprintf("[sse] event source close error: %s", closeErr.Error()))
			}
			return &EventScannerError{Error: cbErr}
		}
		return nil
	}

	for scanner.Scan() {
		line := scanner.Text()

		switch {
		case strings.TrimSpace(line) == "":
			// A blank line terminates the current event.
			if err := dispatch("event"); err != nil {
				return err
			}
		case strings.HasPrefix(line, "event:"):
			eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
		case strings.HasPrefix(line, "data:"):
			eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			if strings.HasPrefix(eventData, "[DONE]") {
				// End-of-stream sentinel: drop it together with its event type.
				eventType, eventData = "", ""
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return &EventScannerError{Error: err}
	}

	// Deliver a trailing event that was not followed by a blank line.
	return dispatch("last event")
}
|
|
|
|
func processLegacySSE(body io.ReadCloser, callback func(string) error) *EventScannerError {
|
|
scanner := bufio.NewScanner(body)
|
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
if atEOF && len(data) == 0 {
|
|
// when EOF and empty data
|
|
return 0, nil, nil
|
|
}
|
|
|
|
if idx := bytes.Index(data, []byte("\n")); idx >= 0 {
|
|
// when found new line
|
|
return idx + 1, data[:idx], nil
|
|
}
|
|
|
|
if atEOF {
|
|
// when EOF and no new line
|
|
return len(data), data, nil
|
|
}
|
|
|
|
// when need more data
|
|
return 0, nil, nil
|
|
})
|
|
|
|
for scanner.Scan() {
|
|
raw := scanner.Text()
|
|
|
|
if len(raw) <= 5 || !strings.HasPrefix(raw, "data:") {
|
|
// for only `data:` partial raw or unexpected chunk
|
|
continue
|
|
}
|
|
|
|
if globals.DebugMode {
|
|
globals.Debug(fmt.Sprintf("[sse] chunk: %s", raw))
|
|
}
|
|
|
|
chunk := strings.TrimSpace(strings.TrimPrefix(raw, "data:"))
|
|
if chunk == "[DONE]" || strings.HasPrefix(chunk, "[DONE]") {
|
|
// for done signal
|
|
continue
|
|
}
|
|
|
|
// callback chunk
|
|
if err := callback(chunk); err != nil {
|
|
// break connection on callback error
|
|
err := body.Close()
|
|
if err != nil {
|
|
globals.Debug(fmt.Sprintf("[sse] event source close error: %s", err.Error()))
|
|
}
|
|
|
|
return &EventScannerError{Error: err}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|