feat: support claude 3 opus, sonnet and haiku (support vision) (#93)

This commit is contained in:
Zhang Minghan 2024-03-14 23:52:31 +08:00
parent e037276387
commit 14dc10025f
9 changed files with 226 additions and 72 deletions

View File

@ -75,7 +75,7 @@ English | [简体中文](https://github.com/Deeptrain-Community/chatnio/blob/mas
- [x] Chat Completions (support *vision*, *tools_calling* and *function_calling*)
- [x] Image Generation
- [x] Azure OpenAI
- [x] Anthropic Claude (claude-2, claude-2.1, claude-instant)
- [x] Anthropic Claude (support *vision*)
- [x] Slack Claude (deprecated)
- [x] Sparkdesk (support *function_calling*)
- [x] Google Gemini (PaLM2)

View File

@ -4,8 +4,8 @@ import (
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"errors"
"fmt"
"strings"
)
const defaultTokens = 2500
@ -53,9 +53,48 @@ func (c *ChatInstance) GetTokens(props *adaptercommon.ChatProps) int {
return *props.MaxTokens
}
func (c *ChatInstance) GetMessages(props *adaptercommon.ChatProps) []Message {
return utils.Each(props.Message, func(message globals.Message) Message {
if !globals.IsVisionModel(props.Model) || message.Role != globals.User {
return Message{
Role: message.Role,
Content: message.Content,
}
}
content, urls := utils.ExtractImages(message.Content, true)
images := utils.EachNotNil(urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
props.Buffer.AddImage(obj)
if err != nil {
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), utils.Extract(url, 24, "...")))
}
i := utils.NewImageContent(url)
return &MessageContent{
Type: "image",
Source: &MessageImage{
Type: "base64",
MediaType: i.GetType(),
Data: i.ToRawBase64(),
},
}
})
return Message{
Role: message.Role,
Content: utils.Prepend(images, MessageContent{
Type: "text",
Text: &content,
}),
}
})
}
func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) *ChatBody {
messages := c.GetMessages(props)
return &ChatBody{
Messages: props.Message,
Messages: messages,
MaxTokens: c.GetTokens(props),
Model: props.Model,
Stream: stream,
@ -65,69 +104,63 @@ func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool)
}
}
// CreateChatRequest is the request for anthropic claude
func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
data, err := utils.Post(c.GetChatEndpoint(), c.GetChatHeaders(), c.GetChatBody(props, false), props.Proxy)
if err != nil {
return "", fmt.Errorf("claude error: %s", err.Error())
func (c *ChatInstance) ProcessLine(data string) (*globals.Chunk, error) {
if form := processChatResponse(data); form != nil {
return &globals.Chunk{
Content: form.Delta.Text,
}, nil
}
if form := utils.MapToStruct[ChatResponse](data); form != nil {
return form.Completion, nil
if form := processChatErrorResponse(data); form != nil {
return &globals.Chunk{Content: ""}, fmt.Errorf("anthropic error: %s (type: %s)", form.Error.Message, form.Error.Type)
}
return "", fmt.Errorf("claude error: invalid response")
return &globals.Chunk{Content: ""}, nil
}
func (c *ChatInstance) ProcessLine(buf, data string) (string, error) {
// response example:
//
// event:completion
// data:{"completion":"!","stop_reason":null,"model":"claude-2.0","stop":null,"log_id":"f5f659a5807419c94cfac4a9f2f79a66e95733975714ce7f00e30689dd136b02"}
if !strings.HasPrefix(data, "data:") && strings.HasPrefix(data, "event:") {
return "", nil
} else {
data = strings.TrimSpace(strings.TrimPrefix(data, "data:"))
func processChatErrorResponse(data string) *ChatErrorResponse {
if form := utils.UnmarshalForm[ChatErrorResponse](data); form != nil {
return form
}
return nil
}
if len(data) == 0 {
return "", nil
func processChatResponse(data string) *ChatStreamResponse {
if form := utils.UnmarshalForm[ChatStreamResponse](data); form != nil {
return form
}
if form := utils.UnmarshalForm[ChatResponse](data); form != nil {
return form.Completion, nil
}
data = buf + data
if form := utils.UnmarshalForm[ChatResponse](data); form != nil {
return form.Completion, nil
}
globals.Warn(fmt.Sprintf("anthropic error: cannot parse response: %s", data))
return "", fmt.Errorf("claude error: invalid response")
return nil
}
// CreateStreamChatRequest is the stream request for anthropic claude
func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error {
buf := ""
return utils.EventSource(
"POST",
c.GetChatEndpoint(),
c.GetChatHeaders(),
c.GetChatBody(props, true),
func(data string) error {
if resp, err := c.ProcessLine(buf, data); err == nil && len(resp) > 0 {
buf = ""
if err := hook(&globals.Chunk{Content: resp}); err != nil {
return err
}
} else {
buf = buf + data
err := utils.EventScanner(&utils.EventScannerProps{
Method: "POST",
Uri: c.GetChatEndpoint(),
Headers: c.GetChatHeaders(),
Body: c.GetChatBody(props, true),
Callback: func(data string) error {
partial, err := c.ProcessLine(data)
if err != nil {
return err
}
return nil
return hook(partial)
},
},
props.Proxy,
)
if err != nil {
if form := processChatErrorResponse(err.Body); form != nil {
if form.Error.Type == "" && form.Error.Message == "" {
return errors.New(utils.ToMarkdownCode("json", err.Body))
}
return errors.New(fmt.Sprintf("%s (type: %s)", form.Error.Message, form.Error.Type))
}
return fmt.Errorf("%s\n%s", err.Error, errors.New(utils.ToMarkdownCode("json", err.Body)))
}
return nil
}

View File

@ -1,20 +1,46 @@
package claude
import "chat/globals"
// ChatBody is the request body for anthropic claude
type ChatBody struct {
Messages []globals.Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Model string `json:"model"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
type Message struct {
Role string `json:"role"`
Content interface{} `json:"content"`
}
// ChatResponse is the native http request and stream response for anthropic claude
type ChatResponse struct {
Completion string `json:"completion"`
LogId string `json:"log_id"`
type MessageImage struct {
Type string `json:"type"`
MediaType interface{} `json:"media_type"`
Data interface{} `json:"data"`
}
type MessageContent struct {
Type string `json:"type"`
Text *string `json:"text,omitempty"`
Source *MessageImage `json:"source,omitempty"`
}
type ChatBody struct {
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Model string `json:"model"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
}
type ChatStreamResponse struct {
Type string `json:"type"`
Index int `json:"index"`
Delta struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"delta"`
}
type ChatErrorResponse struct {
Error struct {
Type string `json:"type" binding:"required"`
Message string `json:"message"`
} `json:"error"`
}

View File

@ -100,6 +100,24 @@ export const pricing: PricingDataset = [
input: 0.008,
output: 0.024,
},
// claude 3 haiku $0.25/1m tokens input & $1.25/1m tokens output
{
models: ["claude-3-haiku-20240307"],
input: 0.00025,
output: 0.00125,
},
// claude 3 sonnet $3/1m tokens input & $15/1m tokens output
{
models: ["claude-3-sonnet-20240229"],
input: 0.003,
output: 0.015,
},
// claude 3 sonnet $15/1m tokens input & $75/1m tokens output
{
models: ["claude-3-opus-20240229"],
input: 0.015,
output: 0.075,
},
{
models: ["midjourney"],
output: 0.1,

View File

@ -10,7 +10,7 @@ const initialState: Channel = {
models: [],
priority: 0,
weight: 1,
retry: 3,
retry: 1,
secret: "",
endpoint: getChannelInfo().endpoint,
mapper: "",

View File

@ -94,7 +94,7 @@ function SyncDialog({ dispatch, open, setOpen }: SyncDialogProps) {
models: resp.data,
priority: 0,
weight: 1,
retry: 3,
retry: 1,
secret,
endpoint,
mapper: "",

View File

@ -2,7 +2,7 @@ import { parseFile } from "@/components/plugins/file.tsx";
import { parseProgressbar } from "@/components/plugins/progress.tsx";
import { cn } from "@/components/ui/lib/utils.ts";
import { copyClipboard } from "@/utils/dom.ts";
import { Copy } from "lucide-react";
import { Check, Copy } from "lucide-react";
import { LightAsync as SyntaxHighlighter } from "react-syntax-highlighter";
import { atomOneDark as style } from "react-syntax-highlighter/dist/esm/styles/hljs";
import React from "react";
@ -30,6 +30,7 @@ export default function ({
codeStyle,
...props
}: CodeProps) {
const [copied, setCopied] = React.useState(false);
const match = /language-(\w+)/.exec(className || "");
const language = match ? match[1].toLowerCase() : "unknown";
if (language === "file") return parseFile(children);
@ -49,9 +50,14 @@ export default function ({
onClick={async () => {
const text = children?.toString() || "";
await copyClipboard(text);
setCopied(true);
}}
>
<Copy className={`h-3 w-3`} />
{copied ? (
<Check className={`h-3 w-3`} />
) : (
<Copy className={`h-3 w-3`} />
)}
<p>{language}</p>
</div>
<SyntaxHighlighter

View File

@ -16,7 +16,8 @@ import (
)
type Image struct {
Object image.Image
Object image.Image
Content string
}
type Images []Image
@ -37,7 +38,7 @@ func NewImage(url string) (*Image, error) {
return nil, err
}
return &Image{Object: img}, nil
return &Image{Object: img, Content: url}, nil
}
res, err := http.Get(url)
@ -70,7 +71,11 @@ func NewImage(url string) (*Image, error) {
img = ticks.Image[0]
}
return &Image{Object: img}, nil
return &Image{Object: img, Content: url}, nil
}
func NewImageContent(content string) *Image {
return &Image{Content: content}
}
func ConvertToBase64(url string) (string, error) {
@ -130,6 +135,66 @@ func (i *Image) CountTokens(model string) int {
return 0
}
func (i *Image) IsBase64() bool {
return strings.HasPrefix(i.Content, "data:image/")
}
func (i *Image) GetType() string {
// example: image/jpeg, image/png, image/gif
if i.IsBase64() {
t := SafeSplit(i.Content, ";", 2)[0]
return strings.ReplaceAll(t, "data:", "")
}
// example: .jpg, .png, .gif to image/jpeg, image/png, image/gif
switch strings.ToLower(path.Ext(i.Content)) {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".gif":
return "image/gif"
case ".webp":
return "image/webp"
case ".bmp":
return "image/bmp"
default:
return ""
}
}
func (i *Image) ToBase64() string {
if i.IsBase64() {
return i.Content
}
// get url content and convert to base64
data, err := ConvertToBase64(i.Content)
if err != nil {
globals.Warn(fmt.Sprintf("cannot convert image to base64: %s", err.Error()))
return ""
}
return fmt.Sprintf("data:%s;base64,%s", i.GetType(), data)
}
func (i *Image) ToRawBase64() string {
// example: return /9j/...
if i.IsBase64() {
return SafeSplit(i.Content, ",", 2)[1]
}
// get url content and convert to base64
data, err := ConvertToBase64(i.Content)
if err != nil {
globals.Warn(fmt.Sprintf("cannot convert image to base64: %s", err.Error()))
return ""
}
return data
}
func DownloadImage(url string, path string) error {
res, err := http.Get(url)
if err != nil {

View File

@ -104,6 +104,12 @@ func EventScanner(props *EventScannerProps, config ...globals.ProxyConfig) *Even
// callback chunk
if err := props.Callback(chunk); err != nil {
// break connection on callback error
err := resp.Body.Close()
if err != nil {
globals.Debug(fmt.Sprintf("[sse] event source close error: %s", err.Error()))
}
return &EventScannerError{Error: err}
}
}