forked from GithubProxy/ChatGPT-Next-Web
Merge pull request #5996 from ChatGPTNextWeb/feature/cogview
Feature/cogview
This commit is contained in:
commit
9df24e568b
@ -25,12 +25,103 @@ import { getMessageTextContent } from "@/app/utils";
|
|||||||
import { RequestPayload } from "./openai";
|
import { RequestPayload } from "./openai";
|
||||||
import { fetch } from "@/app/utils/stream";
|
import { fetch } from "@/app/utils/stream";
|
||||||
|
|
||||||
|
interface BasePayload {
|
||||||
|
model: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChatPayload extends BasePayload {
|
||||||
|
messages: ChatOptions["messages"];
|
||||||
|
stream?: boolean;
|
||||||
|
temperature?: number;
|
||||||
|
presence_penalty?: number;
|
||||||
|
frequency_penalty?: number;
|
||||||
|
top_p?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ImageGenerationPayload extends BasePayload {
|
||||||
|
prompt: string;
|
||||||
|
size?: string;
|
||||||
|
user_id?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface VideoGenerationPayload extends BasePayload {
|
||||||
|
prompt: string;
|
||||||
|
duration?: number;
|
||||||
|
resolution?: string;
|
||||||
|
user_id?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelType = "chat" | "image" | "video";
|
||||||
|
|
||||||
export class ChatGLMApi implements LLMApi {
|
export class ChatGLMApi implements LLMApi {
|
||||||
private disableListModels = true;
|
private disableListModels = true;
|
||||||
|
|
||||||
|
private getModelType(model: string): ModelType {
|
||||||
|
if (model.startsWith("cogview-")) return "image";
|
||||||
|
if (model.startsWith("cogvideo-")) return "video";
|
||||||
|
return "chat";
|
||||||
|
}
|
||||||
|
|
||||||
|
private getModelPath(type: ModelType): string {
|
||||||
|
switch (type) {
|
||||||
|
case "image":
|
||||||
|
return ChatGLM.ImagePath;
|
||||||
|
case "video":
|
||||||
|
return ChatGLM.VideoPath;
|
||||||
|
default:
|
||||||
|
return ChatGLM.ChatPath;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private createPayload(
|
||||||
|
messages: ChatOptions["messages"],
|
||||||
|
modelConfig: any,
|
||||||
|
options: ChatOptions,
|
||||||
|
): BasePayload {
|
||||||
|
const modelType = this.getModelType(modelConfig.model);
|
||||||
|
const lastMessage = messages[messages.length - 1];
|
||||||
|
const prompt =
|
||||||
|
typeof lastMessage.content === "string"
|
||||||
|
? lastMessage.content
|
||||||
|
: lastMessage.content.map((c) => c.text).join("\n");
|
||||||
|
|
||||||
|
switch (modelType) {
|
||||||
|
case "image":
|
||||||
|
return {
|
||||||
|
model: modelConfig.model,
|
||||||
|
prompt,
|
||||||
|
size: options.config.size,
|
||||||
|
} as ImageGenerationPayload;
|
||||||
|
default:
|
||||||
|
return {
|
||||||
|
messages,
|
||||||
|
stream: options.config.stream,
|
||||||
|
model: modelConfig.model,
|
||||||
|
temperature: modelConfig.temperature,
|
||||||
|
presence_penalty: modelConfig.presence_penalty,
|
||||||
|
frequency_penalty: modelConfig.frequency_penalty,
|
||||||
|
top_p: modelConfig.top_p,
|
||||||
|
} as ChatPayload;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private parseResponse(modelType: ModelType, json: any): string {
|
||||||
|
switch (modelType) {
|
||||||
|
case "image": {
|
||||||
|
const imageUrl = json.data?.[0]?.url;
|
||||||
|
return imageUrl ? `` : "";
|
||||||
|
}
|
||||||
|
case "video": {
|
||||||
|
const videoUrl = json.data?.[0]?.url;
|
||||||
|
return videoUrl ? `<video controls src="${videoUrl}"></video>` : "";
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return this.extractMessage(json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
path(path: string): string {
|
path(path: string): string {
|
||||||
const accessStore = useAccessStore.getState();
|
const accessStore = useAccessStore.getState();
|
||||||
|
|
||||||
let baseUrl = "";
|
let baseUrl = "";
|
||||||
|
|
||||||
if (accessStore.useCustomConfig) {
|
if (accessStore.useCustomConfig) {
|
||||||
@ -51,7 +142,6 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
}
|
}
|
||||||
|
|
||||||
console.log("[Proxy Endpoint] ", baseUrl, path);
|
console.log("[Proxy Endpoint] ", baseUrl, path);
|
||||||
|
|
||||||
return [baseUrl, path].join("/");
|
return [baseUrl, path].join("/");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,24 +169,16 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const requestPayload: RequestPayload = {
|
const modelType = this.getModelType(modelConfig.model);
|
||||||
messages,
|
const requestPayload = this.createPayload(messages, modelConfig, options);
|
||||||
stream: options.config.stream,
|
const path = this.path(this.getModelPath(modelType));
|
||||||
model: modelConfig.model,
|
|
||||||
temperature: modelConfig.temperature,
|
|
||||||
presence_penalty: modelConfig.presence_penalty,
|
|
||||||
frequency_penalty: modelConfig.frequency_penalty,
|
|
||||||
top_p: modelConfig.top_p,
|
|
||||||
};
|
|
||||||
|
|
||||||
console.log("[Request] glm payload: ", requestPayload);
|
console.log(`[Request] glm ${modelType} payload: `, requestPayload);
|
||||||
|
|
||||||
const shouldStream = !!options.config.stream;
|
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
options.onController?.(controller);
|
options.onController?.(controller);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const chatPath = this.path(ChatGLM.ChatPath);
|
|
||||||
const chatPayload = {
|
const chatPayload = {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: JSON.stringify(requestPayload),
|
body: JSON.stringify(requestPayload),
|
||||||
@ -104,12 +186,23 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
headers: getHeaders(),
|
headers: getHeaders(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// make a fetch request
|
|
||||||
const requestTimeoutId = setTimeout(
|
const requestTimeoutId = setTimeout(
|
||||||
() => controller.abort(),
|
() => controller.abort(),
|
||||||
REQUEST_TIMEOUT_MS,
|
REQUEST_TIMEOUT_MS,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (modelType === "image" || modelType === "video") {
|
||||||
|
const res = await fetch(path, chatPayload);
|
||||||
|
clearTimeout(requestTimeoutId);
|
||||||
|
|
||||||
|
const resJson = await res.json();
|
||||||
|
console.log(`[Response] glm ${modelType}:`, resJson);
|
||||||
|
const message = this.parseResponse(modelType, resJson);
|
||||||
|
options.onFinish(message, res);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const shouldStream = !!options.config.stream;
|
||||||
if (shouldStream) {
|
if (shouldStream) {
|
||||||
const [tools, funcs] = usePluginStore
|
const [tools, funcs] = usePluginStore
|
||||||
.getState()
|
.getState()
|
||||||
@ -117,7 +210,7 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
useChatStore.getState().currentSession().mask?.plugin || [],
|
useChatStore.getState().currentSession().mask?.plugin || [],
|
||||||
);
|
);
|
||||||
return stream(
|
return stream(
|
||||||
chatPath,
|
path,
|
||||||
requestPayload,
|
requestPayload,
|
||||||
getHeaders(),
|
getHeaders(),
|
||||||
tools as any,
|
tools as any,
|
||||||
@ -125,7 +218,6 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
controller,
|
controller,
|
||||||
// parseSSE
|
// parseSSE
|
||||||
(text: string, runTools: ChatMessageTool[]) => {
|
(text: string, runTools: ChatMessageTool[]) => {
|
||||||
// console.log("parseSSE", text, runTools);
|
|
||||||
const json = JSON.parse(text);
|
const json = JSON.parse(text);
|
||||||
const choices = json.choices as Array<{
|
const choices = json.choices as Array<{
|
||||||
delta: {
|
delta: {
|
||||||
@ -154,7 +246,7 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
}
|
}
|
||||||
return choices[0]?.delta?.content;
|
return choices[0]?.delta?.content;
|
||||||
},
|
},
|
||||||
// processToolMessage, include tool_calls message and tool call results
|
// processToolMessage
|
||||||
(
|
(
|
||||||
requestPayload: RequestPayload,
|
requestPayload: RequestPayload,
|
||||||
toolCallMessage: any,
|
toolCallMessage: any,
|
||||||
@ -172,7 +264,7 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
options,
|
options,
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
const res = await fetch(chatPath, chatPayload);
|
const res = await fetch(path, chatPayload);
|
||||||
clearTimeout(requestTimeoutId);
|
clearTimeout(requestTimeoutId);
|
||||||
|
|
||||||
const resJson = await res.json();
|
const resJson = await res.json();
|
||||||
@ -184,6 +276,7 @@ export class ChatGLMApi implements LLMApi {
|
|||||||
options.onError?.(e as Error);
|
options.onError?.(e as Error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async usage() {
|
async usage() {
|
||||||
return {
|
return {
|
||||||
used: 0,
|
used: 0,
|
||||||
|
@ -24,7 +24,7 @@ import {
|
|||||||
stream,
|
stream,
|
||||||
} from "@/app/utils/chat";
|
} from "@/app/utils/chat";
|
||||||
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
|
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
|
||||||
import { DalleSize, DalleQuality, DalleStyle } from "@/app/typing";
|
import { ModelSize, DalleQuality, DalleStyle } from "@/app/typing";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ChatOptions,
|
ChatOptions,
|
||||||
@ -73,7 +73,7 @@ export interface DalleRequestPayload {
|
|||||||
prompt: string;
|
prompt: string;
|
||||||
response_format: "url" | "b64_json";
|
response_format: "url" | "b64_json";
|
||||||
n: number;
|
n: number;
|
||||||
size: DalleSize;
|
size: ModelSize;
|
||||||
quality: DalleQuality;
|
quality: DalleQuality;
|
||||||
style: DalleStyle;
|
style: DalleStyle;
|
||||||
}
|
}
|
||||||
|
@ -72,6 +72,8 @@ import {
|
|||||||
isDalle3,
|
isDalle3,
|
||||||
showPlugins,
|
showPlugins,
|
||||||
safeLocalStorage,
|
safeLocalStorage,
|
||||||
|
getModelSizes,
|
||||||
|
supportsCustomSize,
|
||||||
} from "../utils";
|
} from "../utils";
|
||||||
|
|
||||||
import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
||||||
@ -79,7 +81,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
|||||||
import dynamic from "next/dynamic";
|
import dynamic from "next/dynamic";
|
||||||
|
|
||||||
import { ChatControllerPool } from "../client/controller";
|
import { ChatControllerPool } from "../client/controller";
|
||||||
import { DalleSize, DalleQuality, DalleStyle } from "../typing";
|
import { DalleQuality, DalleStyle, ModelSize } from "../typing";
|
||||||
import { Prompt, usePromptStore } from "../store/prompt";
|
import { Prompt, usePromptStore } from "../store/prompt";
|
||||||
import Locale from "../locales";
|
import Locale from "../locales";
|
||||||
|
|
||||||
@ -519,10 +521,11 @@ export function ChatActions(props: {
|
|||||||
const [showSizeSelector, setShowSizeSelector] = useState(false);
|
const [showSizeSelector, setShowSizeSelector] = useState(false);
|
||||||
const [showQualitySelector, setShowQualitySelector] = useState(false);
|
const [showQualitySelector, setShowQualitySelector] = useState(false);
|
||||||
const [showStyleSelector, setShowStyleSelector] = useState(false);
|
const [showStyleSelector, setShowStyleSelector] = useState(false);
|
||||||
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
|
const modelSizes = getModelSizes(currentModel);
|
||||||
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
|
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
|
||||||
const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
|
const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
|
||||||
const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
|
const currentSize =
|
||||||
|
session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize);
|
||||||
const currentQuality = session.mask.modelConfig?.quality ?? "standard";
|
const currentQuality = session.mask.modelConfig?.quality ?? "standard";
|
||||||
const currentStyle = session.mask.modelConfig?.style ?? "vivid";
|
const currentStyle = session.mask.modelConfig?.style ?? "vivid";
|
||||||
|
|
||||||
@ -673,7 +676,7 @@ export function ChatActions(props: {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{isDalle3(currentModel) && (
|
{supportsCustomSize(currentModel) && (
|
||||||
<ChatAction
|
<ChatAction
|
||||||
onClick={() => setShowSizeSelector(true)}
|
onClick={() => setShowSizeSelector(true)}
|
||||||
text={currentSize}
|
text={currentSize}
|
||||||
@ -684,7 +687,7 @@ export function ChatActions(props: {
|
|||||||
{showSizeSelector && (
|
{showSizeSelector && (
|
||||||
<Selector
|
<Selector
|
||||||
defaultSelectedValue={currentSize}
|
defaultSelectedValue={currentSize}
|
||||||
items={dalle3Sizes.map((m) => ({
|
items={modelSizes.map((m) => ({
|
||||||
title: m,
|
title: m,
|
||||||
value: m,
|
value: m,
|
||||||
}))}
|
}))}
|
||||||
|
@ -233,6 +233,8 @@ export const XAI = {
|
|||||||
export const ChatGLM = {
|
export const ChatGLM = {
|
||||||
ExampleEndpoint: CHATGLM_BASE_URL,
|
ExampleEndpoint: CHATGLM_BASE_URL,
|
||||||
ChatPath: "api/paas/v4/chat/completions",
|
ChatPath: "api/paas/v4/chat/completions",
|
||||||
|
ImagePath: "api/paas/v4/images/generations",
|
||||||
|
VideoPath: "api/paas/v4/videos/generations",
|
||||||
};
|
};
|
||||||
|
|
||||||
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
|
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
|
||||||
@ -431,6 +433,15 @@ const chatglmModels = [
|
|||||||
"glm-4-long",
|
"glm-4-long",
|
||||||
"glm-4-flashx",
|
"glm-4-flashx",
|
||||||
"glm-4-flash",
|
"glm-4-flash",
|
||||||
|
"glm-4v-plus",
|
||||||
|
"glm-4v",
|
||||||
|
"glm-4v-flash", // free
|
||||||
|
"cogview-3-plus",
|
||||||
|
"cogview-3",
|
||||||
|
"cogview-3-flash", // free
|
||||||
|
// 目前无法适配轮询任务
|
||||||
|
// "cogvideox",
|
||||||
|
// "cogvideox-flash", // free
|
||||||
];
|
];
|
||||||
|
|
||||||
let seq = 1000; // 内置的模型序号生成器从1000开始
|
let seq = 1000; // 内置的模型序号生成器从1000开始
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { LLMModel } from "../client/api";
|
import { LLMModel } from "../client/api";
|
||||||
import { DalleSize, DalleQuality, DalleStyle } from "../typing";
|
import { DalleQuality, DalleStyle, ModelSize } from "../typing";
|
||||||
import { getClientConfig } from "../config/client";
|
import { getClientConfig } from "../config/client";
|
||||||
import {
|
import {
|
||||||
DEFAULT_INPUT_TEMPLATE,
|
DEFAULT_INPUT_TEMPLATE,
|
||||||
@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = {
|
|||||||
compressProviderName: "",
|
compressProviderName: "",
|
||||||
enableInjectSystemPrompts: true,
|
enableInjectSystemPrompts: true,
|
||||||
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
|
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
|
||||||
size: "1024x1024" as DalleSize,
|
size: "1024x1024" as ModelSize,
|
||||||
quality: "standard" as DalleQuality,
|
quality: "standard" as DalleQuality,
|
||||||
style: "vivid" as DalleStyle,
|
style: "vivid" as DalleStyle,
|
||||||
},
|
},
|
||||||
|
@ -11,3 +11,14 @@ export interface RequestMessage {
|
|||||||
export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
|
export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
|
||||||
export type DalleQuality = "standard" | "hd";
|
export type DalleQuality = "standard" | "hd";
|
||||||
export type DalleStyle = "vivid" | "natural";
|
export type DalleStyle = "vivid" | "natural";
|
||||||
|
|
||||||
|
export type ModelSize =
|
||||||
|
| "1024x1024"
|
||||||
|
| "1792x1024"
|
||||||
|
| "1024x1792"
|
||||||
|
| "768x1344"
|
||||||
|
| "864x1152"
|
||||||
|
| "1344x768"
|
||||||
|
| "1152x864"
|
||||||
|
| "1440x720"
|
||||||
|
| "720x1440";
|
||||||
|
23
app/utils.ts
23
app/utils.ts
@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
|
|||||||
import { fetch as tauriStreamFetch } from "./utils/stream";
|
import { fetch as tauriStreamFetch } from "./utils/stream";
|
||||||
import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
|
import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
|
||||||
import { getClientConfig } from "./config/client";
|
import { getClientConfig } from "./config/client";
|
||||||
|
import { ModelSize } from "./typing";
|
||||||
|
|
||||||
export function trimTopic(topic: string) {
|
export function trimTopic(topic: string) {
|
||||||
// Fix an issue where double quotes still show in the Indonesian language
|
// Fix an issue where double quotes still show in the Indonesian language
|
||||||
@ -271,6 +272,28 @@ export function isDalle3(model: string) {
|
|||||||
return "dall-e-3" === model;
|
return "dall-e-3" === model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getModelSizes(model: string): ModelSize[] {
|
||||||
|
if (isDalle3(model)) {
|
||||||
|
return ["1024x1024", "1792x1024", "1024x1792"];
|
||||||
|
}
|
||||||
|
if (model.toLowerCase().includes("cogview")) {
|
||||||
|
return [
|
||||||
|
"1024x1024",
|
||||||
|
"768x1344",
|
||||||
|
"864x1152",
|
||||||
|
"1344x768",
|
||||||
|
"1152x864",
|
||||||
|
"1440x720",
|
||||||
|
"720x1440",
|
||||||
|
];
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function supportsCustomSize(model: string): boolean {
|
||||||
|
return getModelSizes(model).length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
export function showPlugins(provider: ServiceProvider, model: string) {
|
export function showPlugins(provider: ServiceProvider, model: string) {
|
||||||
if (
|
if (
|
||||||
provider == ServiceProvider.OpenAI ||
|
provider == ServiceProvider.OpenAI ||
|
||||||
|
Loading…
Reference in New Issue
Block a user