From 9db6bd205d46f0e99096ea8057dd63c939f5924e Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Tue, 12 Mar 2024 17:25:51 +0800 Subject: [PATCH] feat: support grop cloud (#108) --- README.md | 1 + README_zh-CN.md | 1 + adapter/adapter.go | 1 + app/src/admin/channel.ts | 32 ++++- app/src/admin/colors.ts | 4 + app/src/admin/datasets/charge.ts | 5 + .../admin/assemblies/ChannelTable.tsx | 4 +- globals/constant.go | 1 + utils/sse.go | 117 ------------------ 9 files changed, 46 insertions(+), 120 deletions(-) diff --git a/README.md b/README.md index e3ad16a..1d4a0e3 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ English | [简体中文](https://github.com/Deeptrain-Community/chatnio/blob/mas - [x] Tencent Hunyuan - [x] Baichuan AI - [x] Moonshot AI +- [x] Groq Cloud AI - [x] ByteDance Skylark (support *function_calling*) - [x] 360 GPT - [x] LocalAI (Stable Diffusion, RWKV, LLaMa 2, Baichuan 7b, Mixtral, ...) _*requires local deployment_ diff --git a/README_zh-CN.md b/README_zh-CN.md index 8594275..d45de58 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -108,6 +108,7 @@ _🚀 **Next Generation AI One-Stop Solution**_ - [x] Tencent Hunyuan - [x] Baichuan AI - [x] Moonshot AI +- [x] Groq Cloud AI - [x] ByteDance Skylark (support *function_calling*) - [x] 360 GPT - [x] LocalAI (Stable Diffusion, RWKV, LLaMa 2, Baichuan 7b, Mixtral, ...) _*requires local deployment_ diff --git a/adapter/adapter.go b/adapter/adapter.go index 4c76a05..9dd10ee 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -37,6 +37,7 @@ var channelFactories = map[string]adaptercommon.FactoryCreator{ globals.MidjourneyChannelType: midjourney.NewChatInstanceFromConfig, globals.MoonshotChannelType: openai.NewChatInstanceFromConfig, // openai format + globals.GroqChannelType: openai.NewChatInstanceFromConfig, // openai format } func createChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps, hook globals.Hook) error { diff --git a/app/src/admin/channel.ts b/app/src/admin/channel.ts index 2826ce3..3c52230 100644 --- a/app/src/admin/channel.ts +++ b/app/src/admin/channel.ts @@ -43,10 +43,30 @@ export const ChannelTypes: Record = { zhinao: "360智脑 360GLM", baichuan: "百川大模型 BaichuanAI", skylark: "云雀大模型 SkylarkLLM", + groq: "Groq Cloud", bing: "New Bing", slack: "Slack Claude", }; +export const ShortChannelTypes: Record = { + openai: "OpenAI", + azure: "Azure", + claude: "Claude", + palm: "Gemini", + midjourney: "Midjourney", + sparkdesk: "讯飞星火", + chatglm: "ChatGLM", + moonshot: "Moonshot", + qwen: "通义千问", + hunyuan: "腾讯混元", + zhinao: "360 智脑", + baichuan: "百川 AI", + skylark: "火山方舟", + groq: "Groq", + bing: "Bing", + slack: "Slack", +} + export const ChannelInfos: Record = { openai: { endpoint: "https://api.openai.com", @@ -212,7 +232,12 @@ export const ChannelInfos: Record = { endpoint: "https://api.moonshot.cn", format: "", models: ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"], - } + }, + groq: { + endpoint: "https://api.groq.com/openai", + format: "", + models: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"] + } }; export const defaultChannelModels: string[] = getUniqueList( @@ -236,3 +261,8 @@ export function getChannelType(type?: string): string { if (type && type in ChannelTypes) return ChannelTypes[type]; return ChannelTypes.openai; } + +export function getShortChannelType(type?: string): string { + if (type && type in ShortChannelTypes) return ShortChannelTypes[type]; + return ShortChannelTypes.openai; +} diff --git a/app/src/admin/colors.ts b/app/src/admin/colors.ts index 901a98a..a3acac2 100644 --- a/app/src/admin/colors.ts +++ b/app/src/admin/colors.ts @@ -50,6 +50,10 @@ export const modelColorMapper: Record = { "moonshot-v1-32k": "black-500", "moonshot-v1-128k": "black-500", + "llama2-70b-4096": "red-500", + "mixtral-8x7b-32768": "red-500", + "gemma-7b-it": "red-500", + "chat-bison-001": "red-500", "gemini-pro": "red-500", "gemini-pro-vision": "red-500", diff --git a/app/src/admin/datasets/charge.ts b/app/src/admin/datasets/charge.ts index 34c6138..a24739e 100644 --- a/app/src/admin/datasets/charge.ts +++ b/app/src/admin/datasets/charge.ts @@ -216,6 +216,11 @@ export const pricing: PricingDataset = [ output: 0.011, currency: Currency.CNY, }, + { + models: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"], + output: 0.001, // free marked as $0.001 + currency: Currency.USD, + } ]; const countPricing = ( diff --git a/app/src/components/admin/assemblies/ChannelTable.tsx b/app/src/components/admin/assemblies/ChannelTable.tsx index 00d0b68..50eb457 100644 --- a/app/src/components/admin/assemblies/ChannelTable.tsx +++ b/app/src/components/admin/assemblies/ChannelTable.tsx @@ -18,7 +18,7 @@ import { import { Button } from "@/components/ui/button.tsx"; import OperationAction from "@/components/OperationAction.tsx"; import { Dispatch, useEffect, useMemo, useState } from "react"; -import { Channel, getChannelType } from "@/admin/channel.ts"; +import { Channel, getShortChannelType } from "@/admin/channel.ts"; import { toastState } from "@/api/common.ts"; import { useTranslation } from "react-i18next"; import { useEffectAsync } from "@/utils/hook.ts"; @@ -55,7 +55,7 @@ type TypeBadgeProps = { }; function TypeBadge({ type }: TypeBadgeProps) { - const content = useMemo(() => getChannelType(type), [type]); + const content = useMemo(() => getShortChannelType(type), [type]); return ( diff --git a/globals/constant.go b/globals/constant.go index a056913..d9b0e0c 100644 --- a/globals/constant.go +++ b/globals/constant.go @@ -24,6 +24,7 @@ const ( PalmChannelType = "palm" MidjourneyChannelType = "midjourney" MoonshotChannelType = "moonshot" + GroqChannelType = "groq" ) const ( diff --git a/utils/sse.go b/utils/sse.go index 56c2715..69e1f6f 100644 --- a/utils/sse.go +++ b/utils/sse.go @@ -1,10 +1,6 @@ package utils import ( - "bufio" - "bytes" - "chat/globals" - "crypto/tls" "fmt" "io" "net/http" @@ -82,116 +78,3 @@ func NewEndEvent() StreamEvent { Data: "data: [DONE]", } } - -func SSEClient(method string, uri string, headers map[string]string, body interface{}, callback func(string) error) error { - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - - client := newClient() - req, err := http.NewRequest(method, uri, ConvertBody(body)) - if err != nil { - return nil - } - for key, value := range headers { - req.Header.Set(key, value) - } - - res, err := client.Do(req) - if err != nil { - return err - } - - defer res.Body.Close() - - if res.StatusCode >= 400 { - return fmt.Errorf("request failed with status: %s", res.Status) - } - - events, err := CreateSSEInstance(res) - if err != nil { - return err - } - - select { - case ev := <-events: - if err := callback(ev.Data); err != nil { - return err - } - } - - return nil -} - -// Event represents a Server-Sent Event -type Event struct { - Name string - ID string - Data string -} - -func CreateSSEInstance(resp *http.Response) (chan Event, error) { - events := make(chan Event) - reader := bufio.NewReader(resp.Body) - - go loop(reader, events) - - return events, nil -} - -func loop(reader *bufio.Reader, events chan Event) { - ev := Event{} - - var buf bytes.Buffer - - for { - line, err := reader.ReadBytes('\n') - if err != nil { - globals.Info(fmt.Sprintf("[sse] error during read response body: %s", err)) - close(events) - } - - switch { - case ioPrefix(line, ":"): - // Comment, do nothing - case ioPrefix(line, "retry:"): - // Retry, do nothing for now - - // id of event - case ioPrefix(line, "id: "): - ev.ID = string(line[4:]) - case ioPrefix(line, "id:"): - ev.ID = string(line[3:]) - - // name of event - case ioPrefix(line, "event: "): - ev.Name = string(line[7 : len(line)-1]) - case ioPrefix(line, "event:"): - ev.Name = string(line[6 : len(line)-1]) - - // event data - case ioPrefix(line, "data: "): - buf.Write(line[6:]) - case ioPrefix(line, "data:"): - buf.Write(line[5:]) - - // end of event - case bytes.Equal(line, []byte("\n")): - b := buf.Bytes() - - if ioPrefix(b, "{") { - if err == nil { - ev.Data = string(b) - buf.Reset() - events <- ev - ev = Event{} - } - } - - default: - globals.Info(fmt.Sprintf("[sse] unknown line: %s", line)) - } - } -} - -func ioPrefix(s []byte, prefix string) bool { - return bytes.HasPrefix(s, []byte(prefix)) -}