diff --git a/admin/analysis.go b/admin/analysis.go index 2712a78..65a4df8 100644 --- a/admin/analysis.go +++ b/admin/analysis.go @@ -30,7 +30,7 @@ func getFormat(t time.Time) string { func GetSubscriptionUsers(db *sql.DB) int64 { var count int64 - err := db.QueryRow(` + err := globals.QueryRowDb(db, ` SELECT COUNT(*) FROM subscription WHERE expired_at > NOW() `).Scan(&count) if err != nil { @@ -117,7 +117,7 @@ func GetUserTypeData(db *sql.DB) (UserTypeForm, error) { var form UserTypeForm // get total users - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT COUNT(*) FROM auth `).Scan(&form.Total); err != nil { return form, err @@ -125,7 +125,7 @@ func GetUserTypeData(db *sql.DB) (UserTypeForm, error) { // get subscription users count (current subscription) // level 1: basic plan, level 2: standard plan, level 3: pro plan - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT (SELECT COUNT(*) FROM subscription WHERE level = 1 AND expired_at > NOW()), (SELECT COUNT(*) FROM subscription WHERE level = 2 AND expired_at > NOW()), @@ -136,7 +136,7 @@ func GetUserTypeData(db *sql.DB) (UserTypeForm, error) { // get normal users count (no subscription in `subscription` table and `quota` + `used` < initial quota in `quota` table) initialQuota := channel.SystemInstance.GetInitialQuota() - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT COUNT(*) FROM auth WHERE id NOT IN (SELECT user_id FROM subscription WHERE total_month > 0) AND id IN (SELECT user_id FROM quota WHERE quota + used <= ?) diff --git a/admin/invitation.go b/admin/invitation.go index b4f9c84..532958d 100644 --- a/admin/invitation.go +++ b/admin/invitation.go @@ -1,6 +1,7 @@ package admin import ( + "chat/globals" "chat/utils" "database/sql" "errors" @@ -12,7 +13,7 @@ import ( func GetInvitationPagination(db *sql.DB, page int64) PaginationForm { var invitations []interface{} var total int64 - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT COUNT(*) FROM invitation `).Scan(&total); err != nil { return PaginationForm{ @@ -21,7 +22,7 @@ func GetInvitationPagination(db *sql.DB, page int64) PaginationForm { } } - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT code, quota, type, used, updated_at FROM invitation ORDER BY id DESC LIMIT ? OFFSET ? `, pagination, page*pagination) @@ -53,7 +54,7 @@ func GetInvitationPagination(db *sql.DB, page int64) PaginationForm { } func NewInvitationCode(db *sql.DB, code string, quota float32, t string) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO invitation (code, quota, type) VALUES (?, ?, ?) `, code, quota, t) diff --git a/admin/redeem.go b/admin/redeem.go index d2b3336..95809e1 100644 --- a/admin/redeem.go +++ b/admin/redeem.go @@ -1,6 +1,7 @@ package admin import ( + "chat/globals" "chat/utils" "database/sql" "fmt" @@ -10,7 +11,7 @@ import ( func GetRedeemData(db *sql.DB) []RedeemData { var data []RedeemData - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT quota, COUNT(*) AS total, SUM(IF(used = 0, 0, 1)) AS used FROM redeem GROUP BY quota @@ -54,7 +55,7 @@ func GenerateRedeemCodes(db *sql.DB, num int, quota float32) RedeemGenerateRespo func CreateRedeemCode(db *sql.DB, quota float32) (string, error) { code := fmt.Sprintf("nio-%s", utils.GenerateChar(32)) - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO redeem (code, quota) VALUES (?, ?) `, code, quota) diff --git a/admin/user.go b/admin/user.go index 9462015..16d19e6 100644 --- a/admin/user.go +++ b/admin/user.go @@ -2,6 +2,7 @@ package admin import ( "chat/channel" + "chat/globals" "chat/utils" "context" "database/sql" @@ -31,7 +32,7 @@ func getUsersForm(db *sql.DB, page int64, search string) PaginationForm { var users []interface{} var total int64 - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT COUNT(*) FROM auth WHERE username LIKE ? `, "%"+search+"%").Scan(&total); err != nil { @@ -41,7 +42,7 @@ func getUsersForm(db *sql.DB, page int64, search string) PaginationForm { } } - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT auth.id, auth.username, auth.email, auth.is_admin, quota.quota, quota.used, @@ -116,7 +117,7 @@ func passwordMigration(db *sql.DB, cache *redis.Client, id int64, password strin return fmt.Errorf("password length must be between 6 and 36") } - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` UPDATE auth SET password = ? WHERE id = ? `, utils.Sha2Encrypt(password), id) @@ -126,7 +127,7 @@ func passwordMigration(db *sql.DB, cache *redis.Client, id int64, password strin } func emailMigration(db *sql.DB, id int64, email string) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` UPDATE auth SET email = ? WHERE id = ? `, email, id) @@ -134,7 +135,7 @@ func emailMigration(db *sql.DB, id int64, email string) error { } func setAdmin(db *sql.DB, id int64, isAdmin bool) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` UPDATE auth SET is_admin = ? WHERE id = ? `, isAdmin, id) @@ -142,7 +143,7 @@ func setAdmin(db *sql.DB, id int64, isAdmin bool) error { } func banUser(db *sql.DB, id int64, isBanned bool) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` UPDATE auth SET is_banned = ? WHERE id = ? `, isBanned, id) @@ -154,7 +155,7 @@ func quotaMigration(db *sql.DB, id int64, quota float32, override bool) error { // if quota is positive, then increase quota if override { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = ? `, id, quota, 0., quota) @@ -162,7 +163,7 @@ func quotaMigration(db *sql.DB, id int64, quota float32, override bool) error { return err } - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota + ? `, id, quota, 0., quota) @@ -176,7 +177,7 @@ func subscriptionMigration(db *sql.DB, id int64, month int64) error { expireAt := time.Now().AddDate(0, int(month), 0) - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO subscription (user_id, total_month, expired_at) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE total_month = total_month + ?, expired_at = DATE_ADD(expired_at, INTERVAL ? MONTH) `, id, month, expireAt, month, month) @@ -189,7 +190,7 @@ func subscriptionLevelMigration(db *sql.DB, id int64, level int64) error { return fmt.Errorf("invalid subscription level") } - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO subscription (user_id, level) VALUES (?, ?) ON DUPLICATE KEY UPDATE level = ? `, id, level, level) @@ -199,7 +200,7 @@ func subscriptionLevelMigration(db *sql.DB, id int64, level int64) error { func releaseUsage(db *sql.DB, cache *redis.Client, id int64) error { var level sql.NullInt64 - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT level FROM subscription WHERE user_id = ? `, id).Scan(&level); err != nil { return err @@ -225,7 +226,7 @@ func UpdateRootPassword(db *sql.DB, cache *redis.Client, password string) error return fmt.Errorf("password length must be between 6 and 36") } - if _, err := db.Exec(` + if _, err := globals.ExecDb(db, ` UPDATE auth SET password = ? WHERE username = 'root' `, utils.Sha2Encrypt(password)); err != nil { return err diff --git a/app/src/api/connection.ts b/app/src/api/connection.ts index e7dd99d..dbf9bd8 100644 --- a/app/src/api/connection.ts +++ b/app/src/api/connection.ts @@ -1,6 +1,7 @@ import { tokenField, websocketEndpoint } from "@/conf/bootstrap.ts"; import { getMemory } from "@/utils/memory.ts"; import { getErrorMessage } from "@/utils/base.ts"; +import { Mask } from "@/masks/types.ts"; export const endpoint = `${websocketEndpoint}/chat`; export const maxRetry = 60; // 30s max websocket retry @@ -31,11 +32,12 @@ export type ChatProps = { repetition_penalty?: number; }; -type StreamCallback = (message: StreamMessage) => void; +type StreamCallback = (id: number, message: StreamMessage) => void; export class Connection { protected connection?: WebSocket; protected callback?: StreamCallback; + protected stack?: string; public id: number; public state: boolean; @@ -66,6 +68,10 @@ export class Connection { const message = JSON.parse(event.data); this.triggerCallback(message as StreamMessage); }; + + this.connection.onclose = (event) => { + this.stack = `websocket connection failed (code: ${event.code}, reason: ${event.reason}, endpoint: ${endpoint})`; + }; } public reconnect(): void { @@ -99,23 +105,56 @@ export class Connection { ); } - const trace = { - message: data.message, - endpoint: endpoint, - }; + const trace = + this.stack || + JSON.stringify( + { + message: data.message, + endpoint: endpoint, + }, + null, + 2, + ); t && this.triggerCallback({ - message: ` -${t("request-failed")} -\`\`\`json -${JSON.stringify(trace, null, 2)} -\`\`\` - `, + message: `${t("request-failed")}\n\`\`\`json\n${trace}\n\`\`\`\n`, end: true, }); } + public sendEvent(t: any, event: string, data?: string) { + this.sendWithRetry(t, { + type: event, + message: data || "", + model: "event", + }); + } + + public sendStopEvent(t: any) { + this.sendEvent(t, "stop"); + } + + public sendRestartEvent(t: any) { + this.sendEvent(t, "restart"); + } + + public sendMaskEvent(t: any, mask: Mask) { + this.sendEvent(t, "mask", JSON.stringify(mask.context)); + } + + public sendEditEvent(t: any, id: number, message: string) { + this.sendEvent(t, "edit", `${id}:${message}`); + } + + public sendRemoveEvent(t: any, id: number) { + this.sendEvent(t, "remove", id.toString()); + } + + public sendShareEvent(t: any, refer: string) { + this.sendEvent(t, "share", refer); + } + public close(): void { if (!this.connection) return; this.connection.close(); @@ -126,13 +165,91 @@ ${JSON.stringify(trace, null, 2)} } protected triggerCallback(message: StreamMessage): void { - if (this.id === -1 && message.conversation) { - this.setId(message.conversation); - } - this.callback && this.callback(message); + this.callback && this.callback(this.id, message); } public setId(id: number): void { this.id = id; } } + +export class ConnectionStack { + protected connections: Connection[]; + protected callback?: StreamCallback; + + public constructor(callback?: StreamCallback) { + this.connections = []; + this.callback = callback; + } + + public getConnection(id: number): Connection | undefined { + return this.connections.find((conn) => conn.id === id); + } + + public addConnection(id: number): Connection { + const conn = new Connection(id, this.triggerCallback.bind(this)); + this.connections.push(conn); + return conn; + } + + public setCallback(callback?: StreamCallback): void { + this.callback = callback; + } + + public sendEvent(id: number, t: any, event: string, data?: string) { + const conn = this.getConnection(id); + conn && conn.sendEvent(t, event, data); + } + + public sendStopEvent(id: number, t: any) { + const conn = this.getConnection(id); + conn && conn.sendStopEvent(t); + } + + public sendRestartEvent(id: number, t: any) { + const conn = this.getConnection(id); + conn && conn.sendRestartEvent(t); + } + + public sendMaskEvent(id: number, t: any, mask: Mask) { + const conn = this.getConnection(id); + conn && conn.sendMaskEvent(t, mask); + } + + public sendEditEvent(id: number, t: any, message: string) { + const conn = this.getConnection(id); + conn && conn.sendEditEvent(t, id, message); + } + + public sendRemoveEvent(id: number, t: any, messageId: number) { + const conn = this.getConnection(id); + conn && conn.sendRemoveEvent(t, messageId); + } + + public sendShareEvent(id: number, t: any, refer: string) { + const conn = this.getConnection(id); + conn && conn.sendShareEvent(t, refer); + } + + public close(id: number): void { + const conn = this.getConnection(id); + conn && conn.close(); + } + + public closeAll(): void { + this.connections.forEach((conn) => conn.close()); + } + + public reconnect(id: number): void { + const conn = this.getConnection(id); + conn && conn.reconnect(); + } + + public reconnectAll(): void { + this.connections.forEach((conn) => conn.reconnect()); + } + + public triggerCallback(id: number, message: StreamMessage): void { + this.callback && this.callback(id, message); + } +} diff --git a/app/src/api/conversation.ts b/app/src/api/conversation.ts index bf3dd47..501a7f5 100644 --- a/app/src/api/conversation.ts +++ b/app/src/api/conversation.ts @@ -9,6 +9,13 @@ import { Mask } from "@/masks/types.ts"; type ConversationCallback = (idx: number, message: Message[]) => boolean; +export type ConversationSerialized = { + model: string; + end: boolean; + mask: Mask | null; + messages: Message[]; +}; + export class Conversation { protected connection?: Connection; protected callback?: ConversationCallback; @@ -100,14 +107,14 @@ export class Conversation { }); } - public sendStopEvent() { - this.sendEvent("stop"); - } - public isValidIndex(idx: number): boolean { return idx >= 0 && idx < this.data.length; } + public sendStopEvent() { + this.sendEvent("stop"); + } + public sendRestartEvent() { this.sendEvent("restart"); } @@ -253,10 +260,6 @@ export class Conversation { }; } - public getSegmentData(length: number): Message[] { - return this.data.slice(this.data.length - length - 1, this.data.length - 1); - } - public send(t: any, props: ChatProps) { if (!this.connection) { this.connection = new Connection(this.id); diff --git a/app/src/assets/main.less b/app/src/assets/main.less index 41f1538..ac34a35 100644 --- a/app/src/assets/main.less +++ b/app/src/assets/main.less @@ -91,7 +91,7 @@ strong { .flex-dialog { border-radius: var(--radius) !important; - max-height: calc(100vh - 2rem) !important; + max-height: calc(95vh - 2rem) !important; overflow-x: hidden; overflow-y: auto; scrollbar-width: none; @@ -128,7 +128,7 @@ strong { .fixed-dialog { border-radius: var(--radius) !important; - max-height: calc(100vh - 2rem) !important; + max-height: calc(95vh - 2rem) !important; min-height: 60vh; overflow-x: hidden; overflow-y: auto; @@ -197,4 +197,4 @@ strong { .chat-logo { border-radius: var(--radius); user-select: none; -} \ No newline at end of file +} diff --git a/app/src/assets/pages/home.less b/app/src/assets/pages/home.less index 62643ef..7061899 100644 --- a/app/src/assets/pages/home.less +++ b/app/src/assets/pages/home.less @@ -497,7 +497,7 @@ font-size: 14px; user-select: none; - &:before { + &:not(.loading):before { content: "#"; font-size: 12px; margin-right: 1px; diff --git a/app/src/components/admin/assemblies/ModelUsageChart.tsx b/app/src/components/admin/assemblies/ModelUsageChart.tsx index 0886999..8ecd6aa 100644 --- a/app/src/components/admin/assemblies/ModelUsageChart.tsx +++ b/app/src/components/admin/assemblies/ModelUsageChart.tsx @@ -65,7 +65,7 @@ function ModelUsageChart({ labels, datasets }: ModelChartProps) {
diff --git a/app/src/components/home/ConversationSegment.tsx b/app/src/components/home/ConversationSegment.tsx index 459d3dc..08d3394 100644 --- a/app/src/components/home/ConversationSegment.tsx +++ b/app/src/components/home/ConversationSegment.tsx @@ -2,7 +2,13 @@ import { toggleConversation } from "@/api/history.ts"; import { mobile } from "@/utils/device.ts"; import { filterMessage } from "@/utils/processor.ts"; import { setMenu } from "@/store/menu.ts"; -import { MessageSquare, MoreHorizontal, Share2, Trash2 } from "lucide-react"; +import { + Loader2, + MessageSquare, + MoreHorizontal, + Share2, + Trash2, +} from "lucide-react"; import { DropdownMenu, DropdownMenuContent, @@ -34,6 +40,8 @@ function ConversationSegment({ const [open, setOpen] = useState(false); const [offset, setOffset] = useState(0); + const loading = conversation.id <= 0; + return (
{filterMessage(conversation.name)}
-
{conversation.id}
+
+ {loading ? ( + + ) : ( + conversation.id + )} +
{ diff --git a/app/src/store/chat.ts b/app/src/store/chat.ts index 13c6bbe..be08559 100644 --- a/app/src/store/chat.ts +++ b/app/src/store/chat.ts @@ -17,10 +17,15 @@ import { } from "@/conf/storage.ts"; import { CustomMask } from "@/masks/types.ts"; import { listMasks } from "@/api/mask.ts"; +import { ConversationSerialized } from "@/api/conversation.ts"; +import { useSelector } from "react-redux"; +import { useMemo } from "react"; +import { ConnectionStack } from "@/api/connection.ts"; type initialStateType = { history: ConversationInstance[]; messages: Message[]; + conversations: Record; model: string; web: boolean; current: number; @@ -60,12 +65,14 @@ export function getModelList( return target; } +export const stack = new ConnectionStack(); const offline = loadPreferenceModels(getOfflineModels()); const chatSlice = createSlice({ name: "chat", initialState: { history: [], messages: [], + conversations: {}, web: getBooleanMemory("web", false), current: -1, model: getModel(offline, getMemory("model")), @@ -214,6 +221,9 @@ export const selectHistory = (state: RootState): ConversationInstance[] => state.chat.history; export const selectMessages = (state: RootState): Message[] => state.chat.messages; +export const selectConversations = ( + state: RootState, +): Record => state.chat.conversations; export const selectModel = (state: RootState): string => state.chat.model; export const selectWeb = (state: RootState): boolean => state.chat.web; export const selectCurrent = (state: RootState): number => state.chat.current; @@ -226,6 +236,23 @@ export const selectCustomMasks = (state: RootState): CustomMask[] => export const selectSupportModels = (state: RootState): Model[] => state.chat.support_models; +export function useConversation(): ConversationSerialized | undefined { + const conversations = useSelector(selectConversations); + const current = useSelector(selectCurrent); + + return useMemo(() => conversations[current], [conversations, current]); +} + +export function useMessages(): Message[] { + const conversations = useSelector(selectConversations); + const current = useSelector(selectCurrent); + + return useMemo( + () => conversations[current]?.messages || [], + [conversations, current], + ); +} + export const updateMasks = async (dispatch: AppDispatch) => { const resp = await listMasks(); resp.data.length > 0 && dispatch(setCustomMasks(resp.data)); diff --git a/auth/apikey.go b/auth/apikey.go index 1f7bd00..4ab9ef5 100644 --- a/auth/apikey.go +++ b/auth/apikey.go @@ -1,6 +1,7 @@ package auth import ( + "chat/globals" "chat/utils" "database/sql" "errors" @@ -10,7 +11,7 @@ import ( func (u *User) CreateApiKey(db *sql.DB) string { salt := utils.Sha2Encrypt(fmt.Sprintf("%s-%s", u.Username, utils.GenerateChar(utils.GetRandomInt(720, 1024)))) key := fmt.Sprintf("sk-%s", salt[:64]) // 64 bytes - if _, err := db.Exec("INSERT INTO apikey (user_id, api_key) VALUES (?, ?)", u.GetID(db), key); err != nil { + if _, err := globals.ExecDb(db, "INSERT INTO apikey (user_id, api_key) VALUES (?, ?)", u.GetID(db), key); err != nil { return "" } return key @@ -18,14 +19,14 @@ func (u *User) CreateApiKey(db *sql.DB) string { func (u *User) GetApiKey(db *sql.DB) string { var key string - if err := db.QueryRow("SELECT api_key FROM apikey WHERE user_id = ?", u.GetID(db)).Scan(&key); err != nil { + if err := globals.QueryRowDb(db, "SELECT api_key FROM apikey WHERE user_id = ?", u.GetID(db)).Scan(&key); err != nil { return u.CreateApiKey(db) } return key } func (u *User) ResetApiKey(db *sql.DB) (string, error) { - if _, err := db.Exec("DELETE FROM apikey WHERE user_id = ?", u.GetID(db)); err != nil && !errors.Is(err, sql.ErrNoRows) { + if _, err := globals.ExecDb(db, "DELETE FROM apikey WHERE user_id = ?", u.GetID(db)); err != nil && !errors.Is(err, sql.ErrNoRows) { return "", err } return u.CreateApiKey(db), nil diff --git a/auth/auth.go b/auth/auth.go index 344a33a..f1c9924 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -47,7 +47,7 @@ func ParseApiKey(c *gin.Context, key string) *User { } var user User - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT auth.id, auth.username, auth.password FROM auth INNER JOIN apikey ON auth.id = apikey.user_id WHERE apikey.api_key = ? @@ -143,7 +143,7 @@ func SignUp(c *gin.Context, form RegisterForm) (string, error) { Token: utils.Sha2Encrypt(email + username), } - if _, err := db.Exec(` + if _, err := globals.ExecDb(db, ` INSERT INTO auth (username, password, email, bind_id, token) VALUES (?, ?, ?, ?, ?) `, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil { @@ -170,7 +170,7 @@ func Login(c *gin.Context, form LoginForm) (string, error) { // get user from db by username (or email) and password var user User - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT auth.id, auth.username, auth.password FROM auth WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ? `, username, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil { @@ -202,7 +202,7 @@ func DeepLogin(c *gin.Context, token string) (string, error) { // register password := utils.GenerateChar(64) - _ = db.QueryRow("INSERT INTO auth (bind_id, username, token, password) VALUES (?, ?, ?, ?)", + _ = globals.QueryRowDb(db, "INSERT INTO auth (bind_id, username, token, password) VALUES (?, ?, ?, ?)", user.ID, user.Username, token, password) u := &User{ Username: user.Username, @@ -214,9 +214,9 @@ func DeepLogin(c *gin.Context, token string) (string, error) { } // login - _ = db.QueryRow("UPDATE auth SET token = ? WHERE username = ?", token, user.Username) + _ = globals.QueryRowDb(db, "UPDATE auth SET token = ? WHERE username = ?", token, user.Username) var password string - err := db.QueryRow("SELECT password FROM auth WHERE username = ?", user.Username).Scan(&password) + err := globals.QueryRowDb(db, "SELECT password FROM auth WHERE username = ?", user.Username).Scan(&password) if err != nil { return "", err } @@ -273,7 +273,7 @@ func Reset(c *gin.Context, form ResetForm) error { func (u *User) UpdatePassword(db *sql.DB, cache *redis.Client, password string) error { hash := utils.Sha2Encrypt(password) - if _, err := db.Exec(` + if _, err := globals.ExecDb(db, ` UPDATE auth SET password = ? WHERE id = ? `, hash, u.ID); err != nil { return err @@ -296,7 +296,7 @@ func (u *User) Validate(c *gin.Context) bool { db := utils.GetDBFromContext(c) var count int - if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ? AND password = ?", u.Username, u.Password).Scan(&count); err != nil || count == 0 { + if err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth WHERE username = ? AND password = ?", u.Username, u.Password).Scan(&count); err != nil || count == 0 { if err != nil { globals.Warn(fmt.Sprintf("validate user error: %s", err.Error())) } @@ -328,13 +328,13 @@ func (u *User) GenerateToken() (string, error) { func (u *User) GenerateTokenSafe(db *sql.DB) (string, error) { if len(u.Username) == 0 { - if err := db.QueryRow("SELECT username FROM auth WHERE id = ?", u.ID).Scan(&u.Username); err != nil { + if err := globals.QueryRowDb(db, "SELECT username FROM auth WHERE id = ?", u.ID).Scan(&u.Username); err != nil { return "", err } } if len(u.Password) == 0 { - if err := db.QueryRow("SELECT password FROM auth WHERE id = ?", u.ID).Scan(&u.Password); err != nil { + if err := globals.QueryRowDb(db, "SELECT password FROM auth WHERE id = ?", u.ID).Scan(&u.Password); err != nil { return "", err } } diff --git a/auth/invitation.go b/auth/invitation.go index e81db5e..78df2d8 100644 --- a/auth/invitation.go +++ b/auth/invitation.go @@ -1,6 +1,7 @@ package auth import ( + "chat/globals" "chat/utils" "database/sql" "errors" @@ -36,7 +37,7 @@ func GenerateInvitations(db *sql.DB, num int, quota float32, t string) ([]string } func CreateInvitationCode(db *sql.DB, code string, quota float32, t string) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO invitation (code, quota, type) VALUES (?, ?, ?) `, code, quota, t) @@ -44,7 +45,7 @@ func CreateInvitationCode(db *sql.DB, code string, quota float32, t string) erro } func GetInvitation(db *sql.DB, code string) (*Invitation, error) { - row := db.QueryRow(` + row := globals.QueryRowDb(db, ` SELECT id, code, quota, type, used, used_id FROM invitation WHERE code = ? @@ -69,7 +70,7 @@ func (i *Invitation) IsUsed() bool { } func (i *Invitation) Use(db *sql.DB, userId int64) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` UPDATE invitation SET used = TRUE, used_id = ? WHERE id = ? `, userId, i.Id) return err diff --git a/auth/package.go b/auth/package.go index 5fdcc93..eb97b4f 100644 --- a/auth/package.go +++ b/auth/package.go @@ -1,6 +1,9 @@ package auth -import "database/sql" +import ( + "chat/globals" + "database/sql" +) type GiftResponse struct { Cert bool `json:"cert"` @@ -9,7 +12,7 @@ type GiftResponse struct { func (u *User) HasPackage(db *sql.DB, _t string) bool { var count int - if err := db.QueryRow(`SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, u.ID, _t).Scan(&count); err != nil { + if err := globals.QueryRowDb(db, `SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, u.ID, _t).Scan(&count); err != nil { return false } @@ -28,7 +31,7 @@ func NewPackage(db *sql.DB, user *User, _t string) bool { id := user.GetID(db) var count int - if err := db.QueryRow(`SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, id, _t).Scan(&count); err != nil { + if err := globals.QueryRowDb(db, `SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, id, _t).Scan(&count); err != nil { return false } @@ -36,7 +39,7 @@ func NewPackage(db *sql.DB, user *User, _t string) bool { return false } - _ = db.QueryRow(`INSERT INTO package (user_id, type) VALUES (?, ?)`, id, _t) + _ = globals.QueryRowDb(db, `INSERT INTO package (user_id, type) VALUES (?, ?)`, id, _t) return true } diff --git a/auth/quota.go b/auth/quota.go index b4123fd..0170e69 100644 --- a/auth/quota.go +++ b/auth/quota.go @@ -2,11 +2,12 @@ package auth import ( "chat/channel" + "chat/globals" "database/sql" ) func (u *User) CreateInitialQuota(db *sql.DB) bool { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) `, u.GetID(db), channel.SystemInstance.GetInitialQuota(), 0.) return err == nil @@ -14,7 +15,7 @@ func (u *User) CreateInitialQuota(db *sql.DB) bool { func (u *User) GetQuota(db *sql.DB) float32 { var quota float32 - if err := db.QueryRow("SELECT quota FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil { + if err := globals.QueryRowDb(db, "SELECT quota FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil { return 0. } return quota @@ -22,44 +23,50 @@ func (u *User) GetQuota(db *sql.DB) float32 { func (u *User) GetUsedQuota(db *sql.DB) float32 { var quota float32 - if err := db.QueryRow("SELECT used FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil { + if err := globals.QueryRowDb(db, "SELECT used FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil { return 0. } return quota } func (u *User) SetQuota(db *sql.DB, quota float32) bool { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = ? `, u.GetID(db), quota, 0., quota) return err == nil } func (u *User) SetUsedQuota(db *sql.DB, used float32) bool { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = ? `, u.GetID(db), 0., used, used) return err == nil } func (u *User) IncreaseQuota(db *sql.DB, quota float32) bool { - _, err := db.Exec(` + _, err := globals.ExecDb(db, globals.MultiSql(` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota + ? - `, u.GetID(db), quota, 0., quota) + `, ` + INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET quota = quota + ? + `), u.GetID(db), quota, 0., quota) return err == nil } func (u *User) IncreaseUsedQuota(db *sql.DB, used float32) bool { - _, err := db.Exec(` + _, err := globals.ExecDb(db, globals.MultiSql(` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = used + ? - `, u.GetID(db), 0., used, used) + `, ` + INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET used = used + ? + `), u.GetID(db), 0., used, used) return err == nil } func (u *User) DecreaseQuota(db *sql.DB, quota float32) bool { - _, err := db.Exec(` + _, err := globals.ExecDb(db, globals.MultiSql(` INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota - ? - `, u.GetID(db), quota, 0., quota) + `, ` + INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET quota = quota - ? + `), u.GetID(db), quota, 0., quota) return err == nil } diff --git a/auth/redeem.go b/auth/redeem.go index 8be8492..b2b2292 100644 --- a/auth/redeem.go +++ b/auth/redeem.go @@ -1,6 +1,7 @@ package auth import ( + "chat/globals" "chat/utils" "database/sql" "errors" @@ -34,14 +35,14 @@ func GenerateRedeemCodes(db *sql.DB, num int, quota float32) ([]string, error) { } func CreateRedeemCode(db *sql.DB, code string, quota float32) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO redeem (code, quota) VALUES (?, ?) `, code, quota) return err } func GetRedeemCode(db *sql.DB, code string) (*Redeem, error) { - row := db.QueryRow(` + row := globals.QueryRowDb(db, ` SELECT id, code, quota, used FROM redeem WHERE code = ? @@ -62,7 +63,7 @@ func (r *Redeem) IsUsed() bool { } func (r *Redeem) Use(db *sql.DB) error { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` UPDATE redeem SET used = TRUE WHERE id = ? AND used = FALSE `, r.Id) return err diff --git a/auth/struct.go b/auth/struct.go index 5a84483..d194c7e 100644 --- a/auth/struct.go +++ b/auth/struct.go @@ -22,7 +22,7 @@ type User struct { func GetUserById(db *sql.DB, id int64) *User { var user User - if err := db.QueryRow("SELECT id, username FROM auth WHERE id = ?", id).Scan(&user.ID, &user.Username); err != nil { + if err := globals.QueryRowDb(db, "SELECT id, username FROM auth WHERE id = ?", id).Scan(&user.ID, &user.Username); err != nil { return nil } return &user @@ -30,7 +30,7 @@ func GetUserById(db *sql.DB, id int64) *User { func GetUserByName(db *sql.DB, username string) *User { var user User - if err := db.QueryRow("SELECT id, username FROM auth WHERE username = ?", username).Scan(&user.ID, &user.Username); err != nil { + if err := globals.QueryRowDb(db, "SELECT id, username FROM auth WHERE username = ?", username).Scan(&user.ID, &user.Username); err != nil { return nil } return &user @@ -38,7 +38,7 @@ func GetUserByName(db *sql.DB, username string) *User { func GetUserByEmail(db *sql.DB, email string) *User { var user User - if err := db.QueryRow("SELECT id, username FROM auth WHERE email = ?", email).Scan(&user.ID, &user.Username); err != nil { + if err := globals.QueryRowDb(db, "SELECT id, username FROM auth WHERE email = ?", email).Scan(&user.ID, &user.Username); err != nil { return nil } return &user @@ -57,7 +57,7 @@ func (u *User) IsBanned(db *sql.DB) bool { } var banned sql.NullBool - if err := db.QueryRow("SELECT is_banned FROM auth WHERE username = ?", u.Username).Scan(&banned); err != nil { + if err := globals.QueryRowDb(db, "SELECT is_banned FROM auth WHERE username = ?", u.Username).Scan(&banned); err != nil { return false } u.Banned = banned.Valid && banned.Bool @@ -71,7 +71,7 @@ func (u *User) IsAdmin(db *sql.DB) bool { } var admin sql.NullBool - if err := db.QueryRow("SELECT is_admin FROM auth WHERE username = ?", u.Username).Scan(&admin); err != nil { + if err := globals.QueryRowDb(db, "SELECT is_admin FROM auth WHERE username = ?", u.Username).Scan(&admin); err != nil { return false } @@ -83,7 +83,7 @@ func (u *User) GetID(db *sql.DB) int64 { if u.ID > 0 { return u.ID } - if err := db.QueryRow("SELECT id FROM auth WHERE username = ?", u.Username).Scan(&u.ID); err != nil { + if err := globals.QueryRowDb(db, "SELECT id FROM auth WHERE username = ?", u.Username).Scan(&u.ID); err != nil { return 0 } return u.ID @@ -99,7 +99,7 @@ func (u *User) GetEmail(db *sql.DB) string { } var email sql.NullString - if err := db.QueryRow("SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil { + if err := globals.QueryRowDb(db, "SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil { return "" } @@ -109,7 +109,7 @@ func (u *User) GetEmail(db *sql.DB) string { func IsUserExist(db *sql.DB, username string) bool { var count int - if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil { + if err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil { return false } return count > 0 @@ -117,7 +117,7 @@ func IsUserExist(db *sql.DB, username string) bool { func IsEmailExist(db *sql.DB, email string) bool { var count int - if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil { + if err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil { return false } return count > 0 @@ -125,7 +125,7 @@ func IsEmailExist(db *sql.DB, email string) bool { func getMaxBindId(db *sql.DB) int64 { var max int64 - if err := db.QueryRow("SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil { + if err := globals.QueryRowDb(db, "SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil { return 0 } return max diff --git a/auth/subscription.go b/auth/subscription.go index e78a92f..5409c9b 100644 --- a/auth/subscription.go +++ b/auth/subscription.go @@ -2,6 +2,7 @@ package auth import ( "chat/channel" + "chat/globals" "chat/utils" "database/sql" "errors" @@ -21,7 +22,7 @@ func (u *User) GetSubscription(db *sql.DB) (time.Time, int) { } var expiredAt []uint8 - if err := db.QueryRow("SELECT expired_at, level FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt, &u.Level); err != nil { + if err := globals.QueryRowDb(db, "SELECT expired_at, level FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt, &u.Level); err != nil { return time.Unix(0, 0), 0 } @@ -62,7 +63,7 @@ func (u *User) IsEnterprise(db *sql.DB) bool { } var enterprise sql.NullBool - if err := db.QueryRow("SELECT enterprise FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&enterprise); err != nil { + if err := globals.QueryRowDb(db, "SELECT enterprise FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&enterprise); err != nil { return false } @@ -81,7 +82,7 @@ func (u *User) AddSubscription(db *sql.DB, month int, level int) bool { } expiredAt := current.AddDate(0, month, 0) date := utils.ConvertSqlTime(expiredAt) - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO subscription (user_id, expired_at, total_month, level) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE expired_at = ?, total_month = total_month + ?, level = ? `, u.GetID(db), date, month, level, date, month, level) @@ -101,7 +102,7 @@ func (u *User) DowngradePlan(db *sql.DB, target int) error { // ceil expired time expiredAt := now.Add(time.Duration(stamp)*time.Second).AddDate(0, 0, -1) date := utils.ConvertSqlTime(expiredAt) - _, err := db.Exec("UPDATE subscription SET level = ?, expired_at = ? WHERE user_id = ?", target, date, u.GetID(db)) + _, err := globals.ExecDb(db, "UPDATE subscription SET level = ?, expired_at = ? WHERE user_id = ?", target, date, u.GetID(db)) return err } @@ -118,7 +119,7 @@ func (u *User) CountUpgradePrice(db *sql.DB, target int) float32 { } func (u *User) SetSubscriptionLevel(db *sql.DB, level int) bool { - _, err := db.Exec("UPDATE subscription SET level = ? WHERE user_id = ?", level, u.GetID(db)) + _, err := globals.ExecDb(db, "UPDATE subscription SET level = ? WHERE user_id = ?", level, u.GetID(db)) return err == nil } diff --git a/chatnio.db b/chatnio.db new file mode 100644 index 0000000..e761b6a Binary files /dev/null and b/chatnio.db differ diff --git a/cli/admin.go b/cli/admin.go index f19859b..b201366 100644 --- a/cli/admin.go +++ b/cli/admin.go @@ -7,7 +7,7 @@ import ( ) func UpdateRootCommand(args []string) { - db := connection.ConnectMySQL() + db := connection.ConnectDatabase() cache := connection.ConnectRedis() if len(args) == 0 { diff --git a/cli/invite.go b/cli/invite.go index 9c6c42d..cfb0213 100644 --- a/cli/invite.go +++ b/cli/invite.go @@ -8,7 +8,7 @@ import ( ) func CreateInvitationCommand(args []string) { - db := connection.ConnectMySQL() + db := connection.ConnectDatabase() var ( t = GetArgString(args, 0) diff --git a/cli/token.go b/cli/token.go index ea0ed1b..80d5709 100644 --- a/cli/token.go +++ b/cli/token.go @@ -8,7 +8,7 @@ import ( ) func CreateTokenCommand(args []string) { - db := connection.ConnectMySQL() + db := connection.ConnectDatabase() id, _ := strconv.Atoi(args[0]) user := auth.GetUserById(db, int64(id)) diff --git a/connection/database.go b/connection/database.go index 31e4612..4ae7c1f 100644 --- a/connection/database.go +++ b/connection/database.go @@ -6,20 +6,32 @@ import ( "database/sql" "fmt" _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" "github.com/spf13/viper" ) var DB *sql.DB func InitMySQLSafe() *sql.DB { - ConnectMySQL() + ConnectDatabase() // using DB as a global variable to point to the latest db connection MysqlWorker(DB) return DB } -func ConnectMySQL() *sql.DB { +func getConn() *sql.DB { + if viper.GetString("mysql.host") == "" { + globals.SqliteEngine = true + globals.Warn("[connection] mysql host is not set, using sqlite (chatnio.db)") + db, err := sql.Open("sqlite3", "chatnio.db") + if err != nil { + panic(err) + } + + return db + } + // connect to MySQL db, err := sql.Open("mysql", fmt.Sprintf( "%s:%s@tcp(%s:%d)/%s", @@ -29,6 +41,7 @@ func ConnectMySQL() *sql.DB { viper.GetInt("mysql.port"), viper.GetString("mysql.db"), )) + if pingErr := db.Ping(); err != nil || pingErr != nil { errMsg := utils.Multi[string](err != nil, utils.GetError(err), utils.GetError(pingErr)) // err.Error() may contain nil pointer globals.Warn( @@ -40,11 +53,16 @@ func ConnectMySQL() *sql.DB { utils.Sleep(5000) db.Close() - return ConnectMySQL() - } else { - globals.Debug(fmt.Sprintf("[connection] connected to mysql server (host: %s)", viper.GetString("mysql.host"))) + return getConn() } + globals.Debug(fmt.Sprintf("[connection] connected to mysql server (host: %s)", viper.GetString("mysql.host"))) + return db +} + +func ConnectDatabase() *sql.DB { + db := getConn() + db.SetMaxOpenConns(512) db.SetMaxIdleConns(64) @@ -72,7 +90,7 @@ func ConnectMySQL() *sql.DB { func InitRootUser(db *sql.DB) { // create root user if totally empty var count int - err := db.QueryRow("SELECT COUNT(*) FROM auth").Scan(&count) + err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth").Scan(&count) if err != nil { globals.Warn(fmt.Sprintf("[service] failed to query user count: %s", err.Error())) return @@ -80,7 +98,7 @@ func InitRootUser(db *sql.DB) { if count == 0 { globals.Debug("[service] no user found, creating root user (username: root, password: chatnio123456, email: root@example.com)") - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` INSERT INTO auth (username, password, email, is_admin, bind_id, token) VALUES (?, ?, ?, ?, ?, ?) `, "root", utils.Sha2Encrypt("chatnio123456"), "root@example.com", true, 0, "root") @@ -93,7 +111,7 @@ func InitRootUser(db *sql.DB) { } func CreateUserTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS auth ( id INT PRIMARY KEY AUTO_INCREMENT, bind_id INT UNIQUE, @@ -113,7 +131,7 @@ func CreateUserTable(db *sql.DB) { } func CreatePackageTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS package ( id INT PRIMARY KEY AUTO_INCREMENT, user_id INT, @@ -129,7 +147,7 @@ func CreatePackageTable(db *sql.DB) { } func CreateQuotaTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS quota ( id INT PRIMARY KEY AUTO_INCREMENT, user_id INT UNIQUE, @@ -146,7 +164,7 @@ func CreateQuotaTable(db *sql.DB) { } func CreateConversationTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS conversation ( id INT PRIMARY KEY AUTO_INCREMENT, user_id INT, @@ -164,7 +182,7 @@ func CreateConversationTable(db *sql.DB) { } func CreateMaskTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS mask ( id INT PRIMARY KEY AUTO_INCREMENT, user_id INT, @@ -184,7 +202,7 @@ func CreateMaskTable(db *sql.DB) { func CreateSharingTable(db *sql.DB) { // refs is an array of message id, separated by comma (-1 means all messages) - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS sharing ( id INT PRIMARY KEY AUTO_INCREMENT, hash CHAR(32) UNIQUE, @@ -201,7 +219,7 @@ func CreateSharingTable(db *sql.DB) { } func CreateSubscriptionTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS subscription ( id INT PRIMARY KEY AUTO_INCREMENT, level INT DEFAULT 1, @@ -220,7 +238,7 @@ func CreateSubscriptionTable(db *sql.DB) { } func CreateApiKeyTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS apikey ( id INT PRIMARY KEY AUTO_INCREMENT, user_id INT UNIQUE, @@ -235,7 +253,7 @@ func CreateApiKeyTable(db *sql.DB) { } func CreateInvitationTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS invitation ( id INT PRIMARY KEY AUTO_INCREMENT, code VARCHAR(255) UNIQUE, @@ -255,7 +273,7 @@ func CreateInvitationTable(db *sql.DB) { } func CreateRedeemTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS redeem ( id INT PRIMARY KEY AUTO_INCREMENT, code VARCHAR(255) UNIQUE, @@ -271,7 +289,7 @@ func CreateRedeemTable(db *sql.DB) { } func CreateBroadcastTable(db *sql.DB) { - _, err := db.Exec(` + _, err := globals.ExecDb(db, ` CREATE TABLE IF NOT EXISTS broadcast ( id INT PRIMARY KEY AUTO_INCREMENT, poster_id INT, diff --git a/connection/db_migration.go b/connection/db_migration.go index 3f8cc5b..2b3a9c6 100644 --- a/connection/db_migration.go +++ b/connection/db_migration.go @@ -1,6 +1,7 @@ package connection import ( + "chat/globals" "database/sql" "strings" ) @@ -27,10 +28,14 @@ func checkSqlError(_ sql.Result, err error) error { } func execSql(db *sql.DB, sql string, args ...interface{}) error { - return checkSqlError(db.Exec(sql, args...)) + return checkSqlError(globals.ExecDb(db, sql, args...)) } func doMigration(db *sql.DB) error { + if globals.SqliteEngine { + return doSqliteMigration(db) + } + // v3.10 migration // update `quota`, `used` field in `quota` table @@ -54,3 +59,9 @@ func doMigration(db *sql.DB) error { return nil } + +func doSqliteMigration(db *sql.DB) error { + // v3.10 added sqlite support, no migration needed before this version + + return nil +} diff --git a/connection/worker.go b/connection/worker.go index a7e5835..94a248b 100644 --- a/connection/worker.go +++ b/connection/worker.go @@ -12,7 +12,7 @@ func MysqlWorker(db *sql.DB) { go func() { for { if db == nil || db.Ping() != nil { - db = ConnectMySQL() + db = ConnectDatabase() } time.Sleep(tick) diff --git a/globals/sql.go b/globals/sql.go new file mode 100644 index 0000000..75d3b73 --- /dev/null +++ b/globals/sql.go @@ -0,0 +1,77 @@ +package globals + +import ( + "database/sql" + "regexp" + "strings" +) + +var SqliteEngine = false + +type batch struct { + Old string + New string + Regex bool +} + +func batchReplace(sql string, batch []batch) string { + for _, item := range batch { + if item.Regex { + sql = regexp.MustCompile(item.Old).ReplaceAllString(sql, item.New) + continue + } + + sql = strings.ReplaceAll(sql, item.Old, item.New) + } + return sql +} + +func MultiSql(mysqlSql string, sqliteSql string) string { + if SqliteEngine { + return sqliteSql + } + return mysqlSql +} + +func PreflightSql(sql string) string { + if SqliteEngine { + sql = batchReplace(sql, []batch{ + // KEYWORD REPLACEMENT + {`INT `, `INTEGER `, false}, + {` AUTO_INCREMENT`, ` AUTOINCREMENT`, false}, + {`DATETIME`, `TEXT`, false}, + {`DECIMAL`, `REAL`, false}, + {`MEDIUMTEXT`, `TEXT`, false}, + {`VARCHAR`, `TEXT`, false}, + + // TEXT(65535) -> TEXT, REAL(10,2) -> REAL + {`TEXT\(\d+\)`, `TEXT`, true}, + {`REAL\(\d+,\d+\)`, `REAL`, true}, + + // UNIQUE KEY -> UNIQUE + {`UNIQUE KEY`, `UNIQUE`, false}, + }) + } + + return sql +} + +func ExecDb(db *sql.DB, sql string, args ...interface{}) (sql.Result, error) { + sql = PreflightSql(sql) + return db.Exec(sql, args...) +} + +func PrepareDb(db *sql.DB, sql string) (*sql.Stmt, error) { + sql = PreflightSql(sql) + return db.Prepare(sql) +} + +func QueryDb(db *sql.DB, sql string, args ...interface{}) (*sql.Rows, error) { + sql = PreflightSql(sql) + return db.Query(sql, args...) +} + +func QueryRowDb(db *sql.DB, sql string, args ...interface{}) *sql.Row { + sql = PreflightSql(sql) + return db.QueryRow(sql, args...) +} diff --git a/go.mod b/go.mod index 1c584fc..46da86a 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( github.com/leodido/go-urn v1.2.4 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/go.sum b/go.sum index d73051b..68aa742 100644 --- a/go.sum +++ b/go.sum @@ -403,6 +403,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= diff --git a/manager/broadcast/manage.go b/manager/broadcast/manage.go index 6436fa9..79737a1 100644 --- a/manager/broadcast/manage.go +++ b/manager/broadcast/manage.go @@ -2,6 +2,7 @@ package broadcast import ( "chat/auth" + "chat/globals" "chat/utils" "context" "github.com/gin-gonic/gin" @@ -11,7 +12,7 @@ func createBroadcast(c *gin.Context, user *auth.User, content string) error { db := utils.GetDBFromContext(c) cache := utils.GetCacheFromContext(c) - if _, err := db.Exec(`INSERT INTO broadcast (poster_id, content) VALUES (?, ?)`, user.GetID(db), content); err != nil { + if _, err := globals.ExecDb(db, `INSERT INTO broadcast (poster_id, content) VALUES (?, ?)`, user.GetID(db), content); err != nil { return err } @@ -24,7 +25,7 @@ func getBroadcastList(c *gin.Context) ([]Info, error) { db := utils.GetDBFromContext(c) var broadcastList []Info - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT broadcast.id, broadcast.content, auth.username, broadcast.created_at FROM broadcast INNER JOIN auth ON broadcast.poster_id = auth.id diff --git a/manager/broadcast/view.go b/manager/broadcast/view.go index d2d7668..672f8f8 100644 --- a/manager/broadcast/view.go +++ b/manager/broadcast/view.go @@ -1,6 +1,7 @@ package broadcast import ( + "chat/globals" "chat/utils" "context" "github.com/gin-gonic/gin" @@ -18,7 +19,7 @@ func getLatestBroadcast(c *gin.Context) *Broadcast { } var broadcast Broadcast - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT id, content FROM broadcast ORDER BY id DESC LIMIT 1; `).Scan(&broadcast.Index, &broadcast.Content); err != nil { return nil diff --git a/manager/conversation/mask.go b/manager/conversation/mask.go index 336af3f..97c6d24 100644 --- a/manager/conversation/mask.go +++ b/manager/conversation/mask.go @@ -26,14 +26,14 @@ func (m *Mask) Save(db *sql.DB, user *auth.User) error { userId := user.GetID(db) if m.Id == -1 { - _, err := db.Exec( + _, err := globals.ExecDb(db, "INSERT INTO mask (mask.user_id, avatar, name, description, context) VALUES (?, ?, ?, ?, ?)", userId, m.Avatar, m.Name, m.Description, utils.Marshal(m.Context), ) return err } - _, err := db.Exec( + _, err := globals.ExecDb(db, "UPDATE mask SET avatar = ?, name = ?, description = ?, context = ? WHERE id = ? AND user_id = ?", m.Avatar, m.Name, m.Description, utils.Marshal(m.Context), m.Id, userId, ) @@ -41,12 +41,12 @@ func (m *Mask) Save(db *sql.DB, user *auth.User) error { } func (m *Mask) Delete(db *sql.DB, user *auth.User) error { - _, err := db.Exec("DELETE FROM mask WHERE id = ? AND user_id = ?", m.Id, user.GetID(db)) + _, err := globals.ExecDb(db, "DELETE FROM mask WHERE id = ? AND user_id = ?", m.Id, user.GetID(db)) return err } func LoadMask(db *sql.DB, user *auth.User) ([]Mask, error) { - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT id, avatar, name, description, context FROM mask WHERE user_id = ? ORDER BY id DESC diff --git a/manager/conversation/shared.go b/manager/conversation/shared.go index 7cf365c..cf249dd 100644 --- a/manager/conversation/shared.go +++ b/manager/conversation/shared.go @@ -49,7 +49,7 @@ func ShareConversation(db *sql.DB, user *auth.User, id int64, refs []int) (strin Refs: refs, }) - if _, err := db.Exec(` + if _, err := globals.ExecDb(db, ` INSERT INTO sharing (hash, user_id, conversation_id, refs) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE refs = ? `, hash, user.GetID(db), id, ref, ref); err != nil { @@ -86,7 +86,7 @@ func ListSharedConversation(db *sql.DB, user *auth.User) []SharedPreviewForm { } id := user.GetID(db) - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT conversation.conversation_name, conversation.conversation_id, sharing.updated_at, sharing.hash FROM sharing INNER JOIN conversation @@ -120,7 +120,7 @@ func DeleteSharedConversation(db *sql.DB, user *auth.User, hash string) error { } id := user.GetID(db) - if _, err := db.Exec(` + if _, err := globals.ExecDb(db, ` DELETE FROM sharing WHERE user_id = ? AND hash = ? `, id, hash); err != nil { return err @@ -136,7 +136,7 @@ func GetSharedConversation(db *sql.DB, hash string) (*SharedForm, error) { ref string updated []uint8 ) - if err := db.QueryRow(` + if err := globals.QueryRowDb(db, ` SELECT auth.username, sharing.refs, sharing.updated_at, conversation.conversation_name, sharing.user_id, sharing.conversation_id FROM sharing diff --git a/manager/conversation/storage.go b/manager/conversation/storage.go index e245b7f..64c947e 100644 --- a/manager/conversation/storage.go +++ b/manager/conversation/storage.go @@ -20,7 +20,7 @@ func (c *Conversation) SaveConversation(db *sql.DB) bool { ON DUPLICATE KEY UPDATE conversation_name = VALUES(conversation_name), data = VALUES(data) ` - stmt, err := db.Prepare(query) + stmt, err := globals.PrepareDb(db, query) if err != nil { return false } @@ -40,7 +40,7 @@ func (c *Conversation) SaveConversation(db *sql.DB) bool { } func GetConversationLengthByUserID(db *sql.DB, userId int64) int64 { var length int64 - err := db.QueryRow("SELECT MAX(conversation_id) FROM conversation WHERE user_id = ?", userId).Scan(&length) + err := globals.QueryRowDb(db, "SELECT MAX(conversation_id) FROM conversation WHERE user_id = ?", userId).Scan(&length) if err != nil || length < 0 { return 0 } @@ -57,7 +57,7 @@ func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversat data string model interface{} ) - err := db.QueryRow(` + err := globals.QueryRowDb(db, ` SELECT conversation_name, model, data FROM conversation WHERE user_id = ? AND conversation_id = ? `, userId, conversationId).Scan(&conversation.Name, &model, &data) @@ -81,7 +81,7 @@ func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversat func LoadConversationList(db *sql.DB, userId int64) []Conversation { var conversationList []Conversation - rows, err := db.Query(` + rows, err := globals.QueryDb(db, ` SELECT conversation_id, conversation_name FROM conversation WHERE user_id = ? ORDER BY conversation_id DESC LIMIT 100 `, userId) @@ -108,7 +108,7 @@ func LoadConversationList(db *sql.DB, userId int64) []Conversation { } func (c *Conversation) DeleteConversation(db *sql.DB) bool { - _, err := db.Exec("DELETE FROM conversation WHERE user_id = ? AND conversation_id = ?", c.UserID, c.Id) + _, err := globals.ExecDb(db, "DELETE FROM conversation WHERE user_id = ? AND conversation_id = ?", c.UserID, c.Id) if err != nil { return false } @@ -116,6 +116,6 @@ func (c *Conversation) DeleteConversation(db *sql.DB) bool { } func DeleteAllConversations(db *sql.DB, user auth.User) error { - _, err := db.Exec("DELETE FROM conversation WHERE user_id = ?", user.GetID(db)) + _, err := globals.ExecDb(db, "DELETE FROM conversation WHERE user_id = ?", user.GetID(db)) return err }