feat: support grop cloud (#108)

This commit is contained in:
Zhang Minghan 2024-03-12 17:25:51 +08:00
parent 0ecd49e312
commit 9db6bd205d
9 changed files with 46 additions and 120 deletions

View File

@ -88,6 +88,7 @@ English | [简体中文](https://github.com/Deeptrain-Community/chatnio/blob/mas
- [x] Tencent Hunyuan
- [x] Baichuan AI
- [x] Moonshot AI
- [x] Groq Cloud AI
- [x] ByteDance Skylark (support *function_calling*)
- [x] 360 GPT
- [x] LocalAI (Stable Diffusion, RWKV, LLaMa 2, Baichuan 7b, Mixtral, ...) _*requires local deployment_

View File

@ -108,6 +108,7 @@ _🚀 **Next Generation AI One-Stop Solution**_
- [x] Tencent Hunyuan
- [x] Baichuan AI
- [x] Moonshot AI
- [x] Groq Cloud AI
- [x] ByteDance Skylark (support *function_calling*)
- [x] 360 GPT
- [x] LocalAI (Stable Diffusion, RWKV, LLaMa 2, Baichuan 7b, Mixtral, ...) _*requires local deployment_

View File

@ -37,6 +37,7 @@ var channelFactories = map[string]adaptercommon.FactoryCreator{
globals.MidjourneyChannelType: midjourney.NewChatInstanceFromConfig,
globals.MoonshotChannelType: openai.NewChatInstanceFromConfig, // openai format
globals.GroqChannelType: openai.NewChatInstanceFromConfig, // openai format
}
func createChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps, hook globals.Hook) error {

View File

@ -43,10 +43,30 @@ export const ChannelTypes: Record<string, string> = {
zhinao: "360智脑 360GLM",
baichuan: "百川大模型 BaichuanAI",
skylark: "云雀大模型 SkylarkLLM",
groq: "Groq Cloud",
bing: "New Bing",
slack: "Slack Claude",
};
export const ShortChannelTypes: Record<string, string> = {
openai: "OpenAI",
azure: "Azure",
claude: "Claude",
palm: "Gemini",
midjourney: "Midjourney",
sparkdesk: "讯飞星火",
chatglm: "ChatGLM",
moonshot: "Moonshot",
qwen: "通义千问",
hunyuan: "腾讯混元",
zhinao: "360 智脑",
baichuan: "百川 AI",
skylark: "火山方舟",
groq: "Groq",
bing: "Bing",
slack: "Slack",
}
export const ChannelInfos: Record<string, ChannelInfo> = {
openai: {
endpoint: "https://api.openai.com",
@ -212,7 +232,12 @@ export const ChannelInfos: Record<string, ChannelInfo> = {
endpoint: "https://api.moonshot.cn",
format: "<api-key>",
models: ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"],
}
},
groq: {
endpoint: "https://api.groq.com/openai",
format: "<api-key>",
models: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"]
}
};
export const defaultChannelModels: string[] = getUniqueList(
@ -236,3 +261,8 @@ export function getChannelType(type?: string): string {
if (type && type in ChannelTypes) return ChannelTypes[type];
return ChannelTypes.openai;
}
export function getShortChannelType(type?: string): string {
if (type && type in ShortChannelTypes) return ShortChannelTypes[type];
return ShortChannelTypes.openai;
}

View File

@ -50,6 +50,10 @@ export const modelColorMapper: Record<string, string> = {
"moonshot-v1-32k": "black-500",
"moonshot-v1-128k": "black-500",
"llama2-70b-4096": "red-500",
"mixtral-8x7b-32768": "red-500",
"gemma-7b-it": "red-500",
"chat-bison-001": "red-500",
"gemini-pro": "red-500",
"gemini-pro-vision": "red-500",

View File

@ -216,6 +216,11 @@ export const pricing: PricingDataset = [
output: 0.011,
currency: Currency.CNY,
},
{
models: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"],
output: 0.001, // free marked as $0.001
currency: Currency.USD,
}
];
const countPricing = (

View File

@ -18,7 +18,7 @@ import {
import { Button } from "@/components/ui/button.tsx";
import OperationAction from "@/components/OperationAction.tsx";
import { Dispatch, useEffect, useMemo, useState } from "react";
import { Channel, getChannelType } from "@/admin/channel.ts";
import { Channel, getShortChannelType } from "@/admin/channel.ts";
import { toastState } from "@/api/common.ts";
import { useTranslation } from "react-i18next";
import { useEffectAsync } from "@/utils/hook.ts";
@ -55,7 +55,7 @@ type TypeBadgeProps = {
};
function TypeBadge({ type }: TypeBadgeProps) {
const content = useMemo(() => getChannelType(type), [type]);
const content = useMemo(() => getShortChannelType(type), [type]);
return (
<Badge className={`select-none w-max cursor-pointer`}>

View File

@ -24,6 +24,7 @@ const (
PalmChannelType = "palm"
MidjourneyChannelType = "midjourney"
MoonshotChannelType = "moonshot"
GroqChannelType = "groq"
)
const (

View File

@ -1,10 +1,6 @@
package utils
import (
"bufio"
"bytes"
"chat/globals"
"crypto/tls"
"fmt"
"io"
"net/http"
@ -82,116 +78,3 @@ func NewEndEvent() StreamEvent {
Data: "data: [DONE]",
}
}
func SSEClient(method string, uri string, headers map[string]string, body interface{}, callback func(string) error) error {
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
client := newClient()
req, err := http.NewRequest(method, uri, ConvertBody(body))
if err != nil {
return nil
}
for key, value := range headers {
req.Header.Set(key, value)
}
res, err := client.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
return fmt.Errorf("request failed with status: %s", res.Status)
}
events, err := CreateSSEInstance(res)
if err != nil {
return err
}
select {
case ev := <-events:
if err := callback(ev.Data); err != nil {
return err
}
}
return nil
}
// Event represents a Server-Sent Event
type Event struct {
Name string
ID string
Data string
}
func CreateSSEInstance(resp *http.Response) (chan Event, error) {
events := make(chan Event)
reader := bufio.NewReader(resp.Body)
go loop(reader, events)
return events, nil
}
func loop(reader *bufio.Reader, events chan Event) {
ev := Event{}
var buf bytes.Buffer
for {
line, err := reader.ReadBytes('\n')
if err != nil {
globals.Info(fmt.Sprintf("[sse] error during read response body: %s", err))
close(events)
}
switch {
case ioPrefix(line, ":"):
// Comment, do nothing
case ioPrefix(line, "retry:"):
// Retry, do nothing for now
// id of event
case ioPrefix(line, "id: "):
ev.ID = string(line[4:])
case ioPrefix(line, "id:"):
ev.ID = string(line[3:])
// name of event
case ioPrefix(line, "event: "):
ev.Name = string(line[7 : len(line)-1])
case ioPrefix(line, "event:"):
ev.Name = string(line[6 : len(line)-1])
// event data
case ioPrefix(line, "data: "):
buf.Write(line[6:])
case ioPrefix(line, "data:"):
buf.Write(line[5:])
// end of event
case bytes.Equal(line, []byte("\n")):
b := buf.Bytes()
if ioPrefix(b, "{") {
if err == nil {
ev.Data = string(b)
buf.Reset()
events <- ev
ev = Event{}
}
}
default:
globals.Info(fmt.Sprintf("[sse] unknown line: %s", line))
}
}
}
func ioPrefix(s []byte, prefix string) bool {
return bytes.HasPrefix(s, []byte(prefix))
}