coai/utils/scanner.go
2025-04-29 19:12:52 +08:00

224 lines
4.9 KiB
Go

package utils
import (
"bufio"
"bytes"
"chat/globals"
"fmt"
"io"
"net/http"
"runtime/debug"
"strings"
)
// EventScannerProps describes the HTTP request issued by EventScanner and how
// the streamed response is handed back to the caller.
type EventScannerProps struct {
// Method is the HTTP method passed to http.NewRequest.
Method string
// Uri is the request URL passed to http.NewRequest.
Uri string
// Headers are applied to the outgoing request through fillHeaders.
Headers map[string]string
// Body is the request payload; it is converted with ConvertBody before sending.
Body interface{}
// Callback is invoked once per streamed item. In legacy mode it receives the
// trimmed text following "data:"; in full mode it receives the event rendered
// as "event: <type>\ndata: <payload>". A non-nil return value stops the scan.
Callback func(string) error
// FullSSE selects processFullSSE when true and processLegacySSE when false.
FullSSE bool
}
// EventScannerError is the failure value returned by EventScanner; a nil
// pointer means the request and the scan completed without a reported error.
//
// Note that it exposes the cause as a field named Error, so the type itself
// does not implement the built-in error interface.
type EventScannerError struct {
// Error holds the error reported for the failed request or scan.
Error error
// Body is the response body text captured by EventScanner when the response
// status code is >= 400; it is left empty in the other failure paths.
Body string
}
// getErrorBody reads the remaining body of resp and returns it as a string.
//
// It returns "" when resp is nil, when resp.Body is nil, or when reading the
// body fails. The body is not closed here; that stays the caller's job.
func getErrorBody(resp *http.Response) string {
	// A hand-built or already-detached response may carry no body at all;
	// io.ReadAll would panic on the nil reader, so treat it as "no content".
	if resp == nil || resp.Body == nil {
		return ""
	}

	content, err := io.ReadAll(resp.Body)
	if err != nil {
		return ""
	}
	return string(content)
}
// EventScanner sends the HTTP request described by props and streams the
// response body to props.Callback.
//
// config is forwarded to newClient to build the HTTP client. When the response
// status code is >= 400 the body is read and returned in
// EventScannerError.Body. Otherwise the body is handed to processFullSSE when
// props.FullSSE is set, and to processLegacySSE when it is not.
//
// It returns nil on success. A nil props, a request or transport failure, an
// error status, a failed scan, or a panic raised while scanning all produce a
// non-nil *EventScannerError.
func EventScanner(props *EventScannerProps, config ...globals.ProxyConfig) (result *EventScannerError) {
	// Guard before installing the recovery handler: the handler itself reads
	// props, so a nil props would otherwise panic a second time inside it.
	if props == nil {
		return &EventScannerError{Error: fmt.Errorf("event source: props is nil")}
	}

	// Panic recovery. The result is named so that a recovered panic is
	// reported to the caller instead of silently turning into a nil (success)
	// return value.
	defer func() {
		if r := recover(); r != nil {
			stack := debug.Stack()
			// %v rather than %s: the recovered value is not necessarily a string or error.
			globals.Warn(fmt.Sprintf("event source panic: %v (uri: %s, method: %s)\n%s", r, props.Uri, props.Method, stack))
			result = &EventScannerError{Error: fmt.Errorf("event source panic: %v", r)}
		}
	}()

	if globals.DebugMode {
		globals.Debug(fmt.Sprintf("[sse] event source: %s %s\nheaders: %v\nbody: %v", props.Method, props.Uri, Marshal(props.Headers), Marshal(props.Body)))
	}

	client := newClient(config)
	req, err := http.NewRequest(props.Method, props.Uri, ConvertBody(props.Body))
	if err != nil {
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] failed to create request: %s", err))
		}
		return &EventScannerError{Error: err}
	}

	fillHeaders(req, props.Headers)

	resp, err := client.Do(req)
	if err != nil {
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] failed to send request: %s", err))
		}
		return &EventScannerError{Error: err}
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		// Error response: capture the body so the caller can inspect it.
		body := getErrorBody(resp)
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] request failed with status: %s\nresponse: %s", resp.Status, body))
		}
		return &EventScannerError{
			Error: fmt.Errorf("request failed with status code: %d", resp.StatusCode),
			Body:  body,
		}
	}

	if props.FullSSE {
		return processFullSSE(resp.Body, props.Callback)
	}
	return processLegacySSE(resp.Body, props.Callback)
}
// processFullSSE reads a server-sent-event stream from body line by line and
// passes every event that carries data to callback.
//
// An event is rendered as "event: <type>\ndata: <payload>"; the "event: " line
// is omitted when no type was seen. A blank line dispatches the pending event.
// An event without data is not dispatched, and a type seen so far stays
// pending. When several "data:" lines precede the blank line, the last one
// replaces the earlier ones. An event still pending when the stream ends is
// dispatched as well.
//
// It returns nil when the stream was consumed completely. If callback returns
// an error mid-stream the body is closed and that error is returned. A failure
// while reading the stream (including a line that exceeds the scanner's token
// limit) is returned wrapped.
func processFullSSE(body io.ReadCloser, callback func(string) error) *EventScannerError {
	scanner := bufio.NewScanner(body)

	var eventType, eventData string

	// render builds the text handed to callback from the pending event.
	render := func() string {
		if eventType == "" {
			return "data: " + eventData
		}
		return "event: " + eventType + "\ndata: " + eventData
	}

	for scanner.Scan() {
		line := scanner.Text()

		switch {
		case strings.TrimSpace(line) == "":
			// A blank line terminates the current event.
			if eventData == "" {
				continue
			}

			eventStr := render()
			if globals.DebugMode {
				globals.Debug(fmt.Sprintf("[sse-full] event: %s", eventStr))
			}

			if err := callback(eventStr); err != nil {
				// Break the connection on callback error. The close error gets
				// its own name so it cannot shadow the callback error that is
				// returned to the caller.
				if closeErr := body.Close(); closeErr != nil {
					globals.Debug(fmt.Sprintf("[sse] event source close error: %s", closeErr.Error()))
				}
				return &EventScannerError{Error: err}
			}

			eventType, eventData = "", ""

		case strings.HasPrefix(line, "event:"):
			eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))

		case strings.HasPrefix(line, "data:"):
			// NOTE(review): a "[DONE]" payload is stored and dispatched like any
			// other event. The previous code tested for it here but the test had
			// no effect; confirm whether full-SSE callers expect it to be dropped.
			eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		}
	}

	// Dispatch an event that was not terminated by a blank line before the
	// stream ended.
	if eventData != "" {
		eventStr := render()
		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse-full] last event: %s", eventStr))
		}
		if err := callback(eventStr); err != nil {
			return &EventScannerError{Error: err}
		}
	}

	// Scan returns false on both EOF and failure; only Err tells them apart.
	if err := scanner.Err(); err != nil {
		return &EventScannerError{Error: fmt.Errorf("reading event stream: %w", err)}
	}
	return nil
}
// processLegacySSE reads body line by line and passes the payload of every
// "data:" line to callback.
//
// Lines that do not start with "data:" or that are no longer than the prefix
// itself are ignored. The payload is the text after "data:" with surrounding
// whitespace trimmed; payloads starting with "[DONE]" are treated as the end
// marker and are not forwarded.
//
// It returns nil when the stream was consumed completely. If callback returns
// an error the body is closed and that error is returned. A failure while
// reading the stream (including a line that exceeds the scanner's token limit)
// is returned wrapped.
func processLegacySSE(body io.ReadCloser, callback func(string) error) *EventScannerError {
	scanner := bufio.NewScanner(body)

	// Split on "\n" only; a trailing "\r" is left on the token and removed
	// later by TrimSpace.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if idx := bytes.IndexByte(data, '\n'); idx >= 0 {
			// found a complete line
			return idx + 1, data[:idx], nil
		}
		if atEOF && len(data) > 0 {
			// final line without a trailing newline
			return len(data), data, nil
		}
		// need more data (or nothing left at EOF)
		return 0, nil, nil
	})

	for scanner.Scan() {
		raw := scanner.Text()
		if len(raw) <= 5 || !strings.HasPrefix(raw, "data:") {
			// bare `data:` line or an unexpected chunk
			continue
		}

		if globals.DebugMode {
			globals.Debug(fmt.Sprintf("[sse] chunk: %s", raw))
		}

		chunk := strings.TrimSpace(strings.TrimPrefix(raw, "data:"))
		if strings.HasPrefix(chunk, "[DONE]") {
			// done signal
			continue
		}

		if err := callback(chunk); err != nil {
			// Break the connection on callback error. The close error gets its
			// own name so it cannot shadow the callback error that is returned
			// to the caller.
			if closeErr := body.Close(); closeErr != nil {
				globals.Debug(fmt.Sprintf("[sse] event source close error: %s", closeErr.Error()))
			}
			return &EventScannerError{Error: err}
		}
	}

	// Scan returns false on both EOF and failure; only Err tells them apart.
	if err := scanner.Err(); err != nil {
		return &EventScannerError{Error: fmt.Errorf("reading event stream: %w", err)}
	}
	return nil
}