mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 13:00:14 +09:00
feat alpha: support sqlite and connection stack restruct
This commit is contained in:
parent
d1573dfa2e
commit
67cb512eb4
@ -30,7 +30,7 @@ func getFormat(t time.Time) string {
|
|||||||
|
|
||||||
func GetSubscriptionUsers(db *sql.DB) int64 {
|
func GetSubscriptionUsers(db *sql.DB) int64 {
|
||||||
var count int64
|
var count int64
|
||||||
err := db.QueryRow(`
|
err := globals.QueryRowDb(db, `
|
||||||
SELECT COUNT(*) FROM subscription WHERE expired_at > NOW()
|
SELECT COUNT(*) FROM subscription WHERE expired_at > NOW()
|
||||||
`).Scan(&count)
|
`).Scan(&count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -117,7 +117,7 @@ func GetUserTypeData(db *sql.DB) (UserTypeForm, error) {
|
|||||||
var form UserTypeForm
|
var form UserTypeForm
|
||||||
|
|
||||||
// get total users
|
// get total users
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT COUNT(*) FROM auth
|
SELECT COUNT(*) FROM auth
|
||||||
`).Scan(&form.Total); err != nil {
|
`).Scan(&form.Total); err != nil {
|
||||||
return form, err
|
return form, err
|
||||||
@ -125,7 +125,7 @@ func GetUserTypeData(db *sql.DB) (UserTypeForm, error) {
|
|||||||
|
|
||||||
// get subscription users count (current subscription)
|
// get subscription users count (current subscription)
|
||||||
// level 1: basic plan, level 2: standard plan, level 3: pro plan
|
// level 1: basic plan, level 2: standard plan, level 3: pro plan
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT
|
SELECT
|
||||||
(SELECT COUNT(*) FROM subscription WHERE level = 1 AND expired_at > NOW()),
|
(SELECT COUNT(*) FROM subscription WHERE level = 1 AND expired_at > NOW()),
|
||||||
(SELECT COUNT(*) FROM subscription WHERE level = 2 AND expired_at > NOW()),
|
(SELECT COUNT(*) FROM subscription WHERE level = 2 AND expired_at > NOW()),
|
||||||
@ -136,7 +136,7 @@ func GetUserTypeData(db *sql.DB) (UserTypeForm, error) {
|
|||||||
|
|
||||||
// get normal users count (no subscription in `subscription` table and `quota` + `used` < initial quota in `quota` table)
|
// get normal users count (no subscription in `subscription` table and `quota` + `used` < initial quota in `quota` table)
|
||||||
initialQuota := channel.SystemInstance.GetInitialQuota()
|
initialQuota := channel.SystemInstance.GetInitialQuota()
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT COUNT(*) FROM auth
|
SELECT COUNT(*) FROM auth
|
||||||
WHERE id NOT IN (SELECT user_id FROM subscription WHERE total_month > 0)
|
WHERE id NOT IN (SELECT user_id FROM subscription WHERE total_month > 0)
|
||||||
AND id IN (SELECT user_id FROM quota WHERE quota + used <= ?)
|
AND id IN (SELECT user_id FROM quota WHERE quota + used <= ?)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
@ -12,7 +13,7 @@ import (
|
|||||||
func GetInvitationPagination(db *sql.DB, page int64) PaginationForm {
|
func GetInvitationPagination(db *sql.DB, page int64) PaginationForm {
|
||||||
var invitations []interface{}
|
var invitations []interface{}
|
||||||
var total int64
|
var total int64
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT COUNT(*) FROM invitation
|
SELECT COUNT(*) FROM invitation
|
||||||
`).Scan(&total); err != nil {
|
`).Scan(&total); err != nil {
|
||||||
return PaginationForm{
|
return PaginationForm{
|
||||||
@ -21,7 +22,7 @@ func GetInvitationPagination(db *sql.DB, page int64) PaginationForm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT code, quota, type, used, updated_at FROM invitation
|
SELECT code, quota, type, used, updated_at FROM invitation
|
||||||
ORDER BY id DESC LIMIT ? OFFSET ?
|
ORDER BY id DESC LIMIT ? OFFSET ?
|
||||||
`, pagination, page*pagination)
|
`, pagination, page*pagination)
|
||||||
@ -53,7 +54,7 @@ func GetInvitationPagination(db *sql.DB, page int64) PaginationForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewInvitationCode(db *sql.DB, code string, quota float32, t string) error {
|
func NewInvitationCode(db *sql.DB, code string, quota float32, t string) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO invitation (code, quota, type)
|
INSERT INTO invitation (code, quota, type)
|
||||||
VALUES (?, ?, ?)
|
VALUES (?, ?, ?)
|
||||||
`, code, quota, t)
|
`, code, quota, t)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -10,7 +11,7 @@ import (
|
|||||||
func GetRedeemData(db *sql.DB) []RedeemData {
|
func GetRedeemData(db *sql.DB) []RedeemData {
|
||||||
var data []RedeemData
|
var data []RedeemData
|
||||||
|
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT quota, COUNT(*) AS total, SUM(IF(used = 0, 0, 1)) AS used
|
SELECT quota, COUNT(*) AS total, SUM(IF(used = 0, 0, 1)) AS used
|
||||||
FROM redeem
|
FROM redeem
|
||||||
GROUP BY quota
|
GROUP BY quota
|
||||||
@ -54,7 +55,7 @@ func GenerateRedeemCodes(db *sql.DB, num int, quota float32) RedeemGenerateRespo
|
|||||||
|
|
||||||
func CreateRedeemCode(db *sql.DB, quota float32) (string, error) {
|
func CreateRedeemCode(db *sql.DB, quota float32) (string, error) {
|
||||||
code := fmt.Sprintf("nio-%s", utils.GenerateChar(32))
|
code := fmt.Sprintf("nio-%s", utils.GenerateChar(32))
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO redeem (code, quota) VALUES (?, ?)
|
INSERT INTO redeem (code, quota) VALUES (?, ?)
|
||||||
`, code, quota)
|
`, code, quota)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/channel"
|
"chat/channel"
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@ -31,7 +32,7 @@ func getUsersForm(db *sql.DB, page int64, search string) PaginationForm {
|
|||||||
var users []interface{}
|
var users []interface{}
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT COUNT(*) FROM auth
|
SELECT COUNT(*) FROM auth
|
||||||
WHERE username LIKE ?
|
WHERE username LIKE ?
|
||||||
`, "%"+search+"%").Scan(&total); err != nil {
|
`, "%"+search+"%").Scan(&total); err != nil {
|
||||||
@ -41,7 +42,7 @@ func getUsersForm(db *sql.DB, page int64, search string) PaginationForm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT
|
SELECT
|
||||||
auth.id, auth.username, auth.email, auth.is_admin,
|
auth.id, auth.username, auth.email, auth.is_admin,
|
||||||
quota.quota, quota.used,
|
quota.quota, quota.used,
|
||||||
@ -116,7 +117,7 @@ func passwordMigration(db *sql.DB, cache *redis.Client, id int64, password strin
|
|||||||
return fmt.Errorf("password length must be between 6 and 36")
|
return fmt.Errorf("password length must be between 6 and 36")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
UPDATE auth SET password = ? WHERE id = ?
|
UPDATE auth SET password = ? WHERE id = ?
|
||||||
`, utils.Sha2Encrypt(password), id)
|
`, utils.Sha2Encrypt(password), id)
|
||||||
|
|
||||||
@ -126,7 +127,7 @@ func passwordMigration(db *sql.DB, cache *redis.Client, id int64, password strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func emailMigration(db *sql.DB, id int64, email string) error {
|
func emailMigration(db *sql.DB, id int64, email string) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
UPDATE auth SET email = ? WHERE id = ?
|
UPDATE auth SET email = ? WHERE id = ?
|
||||||
`, email, id)
|
`, email, id)
|
||||||
|
|
||||||
@ -134,7 +135,7 @@ func emailMigration(db *sql.DB, id int64, email string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setAdmin(db *sql.DB, id int64, isAdmin bool) error {
|
func setAdmin(db *sql.DB, id int64, isAdmin bool) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
UPDATE auth SET is_admin = ? WHERE id = ?
|
UPDATE auth SET is_admin = ? WHERE id = ?
|
||||||
`, isAdmin, id)
|
`, isAdmin, id)
|
||||||
|
|
||||||
@ -142,7 +143,7 @@ func setAdmin(db *sql.DB, id int64, isAdmin bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func banUser(db *sql.DB, id int64, isBanned bool) error {
|
func banUser(db *sql.DB, id int64, isBanned bool) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
UPDATE auth SET is_banned = ? WHERE id = ?
|
UPDATE auth SET is_banned = ? WHERE id = ?
|
||||||
`, isBanned, id)
|
`, isBanned, id)
|
||||||
|
|
||||||
@ -154,7 +155,7 @@ func quotaMigration(db *sql.DB, id int64, quota float32, override bool) error {
|
|||||||
// if quota is positive, then increase quota
|
// if quota is positive, then increase quota
|
||||||
|
|
||||||
if override {
|
if override {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?)
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?)
|
||||||
ON DUPLICATE KEY UPDATE quota = ?
|
ON DUPLICATE KEY UPDATE quota = ?
|
||||||
`, id, quota, 0., quota)
|
`, id, quota, 0., quota)
|
||||||
@ -162,7 +163,7 @@ func quotaMigration(db *sql.DB, id int64, quota float32, override bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?)
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?)
|
||||||
ON DUPLICATE KEY UPDATE quota = quota + ?
|
ON DUPLICATE KEY UPDATE quota = quota + ?
|
||||||
`, id, quota, 0., quota)
|
`, id, quota, 0., quota)
|
||||||
@ -176,7 +177,7 @@ func subscriptionMigration(db *sql.DB, id int64, month int64) error {
|
|||||||
|
|
||||||
expireAt := time.Now().AddDate(0, int(month), 0)
|
expireAt := time.Now().AddDate(0, int(month), 0)
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO subscription (user_id, total_month, expired_at) VALUES (?, ?, ?)
|
INSERT INTO subscription (user_id, total_month, expired_at) VALUES (?, ?, ?)
|
||||||
ON DUPLICATE KEY UPDATE total_month = total_month + ?, expired_at = DATE_ADD(expired_at, INTERVAL ? MONTH)
|
ON DUPLICATE KEY UPDATE total_month = total_month + ?, expired_at = DATE_ADD(expired_at, INTERVAL ? MONTH)
|
||||||
`, id, month, expireAt, month, month)
|
`, id, month, expireAt, month, month)
|
||||||
@ -189,7 +190,7 @@ func subscriptionLevelMigration(db *sql.DB, id int64, level int64) error {
|
|||||||
return fmt.Errorf("invalid subscription level")
|
return fmt.Errorf("invalid subscription level")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO subscription (user_id, level) VALUES (?, ?)
|
INSERT INTO subscription (user_id, level) VALUES (?, ?)
|
||||||
ON DUPLICATE KEY UPDATE level = ?
|
ON DUPLICATE KEY UPDATE level = ?
|
||||||
`, id, level, level)
|
`, id, level, level)
|
||||||
@ -199,7 +200,7 @@ func subscriptionLevelMigration(db *sql.DB, id int64, level int64) error {
|
|||||||
|
|
||||||
func releaseUsage(db *sql.DB, cache *redis.Client, id int64) error {
|
func releaseUsage(db *sql.DB, cache *redis.Client, id int64) error {
|
||||||
var level sql.NullInt64
|
var level sql.NullInt64
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT level FROM subscription WHERE user_id = ?
|
SELECT level FROM subscription WHERE user_id = ?
|
||||||
`, id).Scan(&level); err != nil {
|
`, id).Scan(&level); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -225,7 +226,7 @@ func UpdateRootPassword(db *sql.DB, cache *redis.Client, password string) error
|
|||||||
return fmt.Errorf("password length must be between 6 and 36")
|
return fmt.Errorf("password length must be between 6 and 36")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(`
|
if _, err := globals.ExecDb(db, `
|
||||||
UPDATE auth SET password = ? WHERE username = 'root'
|
UPDATE auth SET password = ? WHERE username = 'root'
|
||||||
`, utils.Sha2Encrypt(password)); err != nil {
|
`, utils.Sha2Encrypt(password)); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { tokenField, websocketEndpoint } from "@/conf/bootstrap.ts";
|
import { tokenField, websocketEndpoint } from "@/conf/bootstrap.ts";
|
||||||
import { getMemory } from "@/utils/memory.ts";
|
import { getMemory } from "@/utils/memory.ts";
|
||||||
import { getErrorMessage } from "@/utils/base.ts";
|
import { getErrorMessage } from "@/utils/base.ts";
|
||||||
|
import { Mask } from "@/masks/types.ts";
|
||||||
|
|
||||||
export const endpoint = `${websocketEndpoint}/chat`;
|
export const endpoint = `${websocketEndpoint}/chat`;
|
||||||
export const maxRetry = 60; // 30s max websocket retry
|
export const maxRetry = 60; // 30s max websocket retry
|
||||||
@ -31,11 +32,12 @@ export type ChatProps = {
|
|||||||
repetition_penalty?: number;
|
repetition_penalty?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
type StreamCallback = (message: StreamMessage) => void;
|
type StreamCallback = (id: number, message: StreamMessage) => void;
|
||||||
|
|
||||||
export class Connection {
|
export class Connection {
|
||||||
protected connection?: WebSocket;
|
protected connection?: WebSocket;
|
||||||
protected callback?: StreamCallback;
|
protected callback?: StreamCallback;
|
||||||
|
protected stack?: string;
|
||||||
public id: number;
|
public id: number;
|
||||||
public state: boolean;
|
public state: boolean;
|
||||||
|
|
||||||
@ -66,6 +68,10 @@ export class Connection {
|
|||||||
const message = JSON.parse(event.data);
|
const message = JSON.parse(event.data);
|
||||||
this.triggerCallback(message as StreamMessage);
|
this.triggerCallback(message as StreamMessage);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
this.connection.onclose = (event) => {
|
||||||
|
this.stack = `websocket connection failed (code: ${event.code}, reason: ${event.reason}, endpoint: ${endpoint})`;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public reconnect(): void {
|
public reconnect(): void {
|
||||||
@ -99,23 +105,56 @@ export class Connection {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const trace = {
|
const trace =
|
||||||
message: data.message,
|
this.stack ||
|
||||||
endpoint: endpoint,
|
JSON.stringify(
|
||||||
};
|
{
|
||||||
|
message: data.message,
|
||||||
|
endpoint: endpoint,
|
||||||
|
},
|
||||||
|
null,
|
||||||
|
2,
|
||||||
|
);
|
||||||
|
|
||||||
t &&
|
t &&
|
||||||
this.triggerCallback({
|
this.triggerCallback({
|
||||||
message: `
|
message: `${t("request-failed")}\n\`\`\`json\n${trace}\n\`\`\`\n`,
|
||||||
${t("request-failed")}
|
|
||||||
\`\`\`json
|
|
||||||
${JSON.stringify(trace, null, 2)}
|
|
||||||
\`\`\`
|
|
||||||
`,
|
|
||||||
end: true,
|
end: true,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public sendEvent(t: any, event: string, data?: string) {
|
||||||
|
this.sendWithRetry(t, {
|
||||||
|
type: event,
|
||||||
|
message: data || "",
|
||||||
|
model: "event",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendStopEvent(t: any) {
|
||||||
|
this.sendEvent(t, "stop");
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendRestartEvent(t: any) {
|
||||||
|
this.sendEvent(t, "restart");
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendMaskEvent(t: any, mask: Mask) {
|
||||||
|
this.sendEvent(t, "mask", JSON.stringify(mask.context));
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendEditEvent(t: any, id: number, message: string) {
|
||||||
|
this.sendEvent(t, "edit", `${id}:${message}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendRemoveEvent(t: any, id: number) {
|
||||||
|
this.sendEvent(t, "remove", id.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendShareEvent(t: any, refer: string) {
|
||||||
|
this.sendEvent(t, "share", refer);
|
||||||
|
}
|
||||||
|
|
||||||
public close(): void {
|
public close(): void {
|
||||||
if (!this.connection) return;
|
if (!this.connection) return;
|
||||||
this.connection.close();
|
this.connection.close();
|
||||||
@ -126,13 +165,91 @@ ${JSON.stringify(trace, null, 2)}
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected triggerCallback(message: StreamMessage): void {
|
protected triggerCallback(message: StreamMessage): void {
|
||||||
if (this.id === -1 && message.conversation) {
|
this.callback && this.callback(this.id, message);
|
||||||
this.setId(message.conversation);
|
|
||||||
}
|
|
||||||
this.callback && this.callback(message);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public setId(id: number): void {
|
public setId(id: number): void {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export class ConnectionStack {
|
||||||
|
protected connections: Connection[];
|
||||||
|
protected callback?: StreamCallback;
|
||||||
|
|
||||||
|
public constructor(callback?: StreamCallback) {
|
||||||
|
this.connections = [];
|
||||||
|
this.callback = callback;
|
||||||
|
}
|
||||||
|
|
||||||
|
public getConnection(id: number): Connection | undefined {
|
||||||
|
return this.connections.find((conn) => conn.id === id);
|
||||||
|
}
|
||||||
|
|
||||||
|
public addConnection(id: number): Connection {
|
||||||
|
const conn = new Connection(id, this.triggerCallback.bind(this));
|
||||||
|
this.connections.push(conn);
|
||||||
|
return conn;
|
||||||
|
}
|
||||||
|
|
||||||
|
public setCallback(callback?: StreamCallback): void {
|
||||||
|
this.callback = callback;
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendEvent(id: number, t: any, event: string, data?: string) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendEvent(t, event, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendStopEvent(id: number, t: any) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendStopEvent(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendRestartEvent(id: number, t: any) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendRestartEvent(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendMaskEvent(id: number, t: any, mask: Mask) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendMaskEvent(t, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendEditEvent(id: number, t: any, message: string) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendEditEvent(t, id, message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendRemoveEvent(id: number, t: any, messageId: number) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendRemoveEvent(t, messageId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendShareEvent(id: number, t: any, refer: string) {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.sendShareEvent(t, refer);
|
||||||
|
}
|
||||||
|
|
||||||
|
public close(id: number): void {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public closeAll(): void {
|
||||||
|
this.connections.forEach((conn) => conn.close());
|
||||||
|
}
|
||||||
|
|
||||||
|
public reconnect(id: number): void {
|
||||||
|
const conn = this.getConnection(id);
|
||||||
|
conn && conn.reconnect();
|
||||||
|
}
|
||||||
|
|
||||||
|
public reconnectAll(): void {
|
||||||
|
this.connections.forEach((conn) => conn.reconnect());
|
||||||
|
}
|
||||||
|
|
||||||
|
public triggerCallback(id: number, message: StreamMessage): void {
|
||||||
|
this.callback && this.callback(id, message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,6 +9,13 @@ import { Mask } from "@/masks/types.ts";
|
|||||||
|
|
||||||
type ConversationCallback = (idx: number, message: Message[]) => boolean;
|
type ConversationCallback = (idx: number, message: Message[]) => boolean;
|
||||||
|
|
||||||
|
export type ConversationSerialized = {
|
||||||
|
model: string;
|
||||||
|
end: boolean;
|
||||||
|
mask: Mask | null;
|
||||||
|
messages: Message[];
|
||||||
|
};
|
||||||
|
|
||||||
export class Conversation {
|
export class Conversation {
|
||||||
protected connection?: Connection;
|
protected connection?: Connection;
|
||||||
protected callback?: ConversationCallback;
|
protected callback?: ConversationCallback;
|
||||||
@ -100,14 +107,14 @@ export class Conversation {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public sendStopEvent() {
|
|
||||||
this.sendEvent("stop");
|
|
||||||
}
|
|
||||||
|
|
||||||
public isValidIndex(idx: number): boolean {
|
public isValidIndex(idx: number): boolean {
|
||||||
return idx >= 0 && idx < this.data.length;
|
return idx >= 0 && idx < this.data.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public sendStopEvent() {
|
||||||
|
this.sendEvent("stop");
|
||||||
|
}
|
||||||
|
|
||||||
public sendRestartEvent() {
|
public sendRestartEvent() {
|
||||||
this.sendEvent("restart");
|
this.sendEvent("restart");
|
||||||
}
|
}
|
||||||
@ -253,10 +260,6 @@ export class Conversation {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public getSegmentData(length: number): Message[] {
|
|
||||||
return this.data.slice(this.data.length - length - 1, this.data.length - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
public send(t: any, props: ChatProps) {
|
public send(t: any, props: ChatProps) {
|
||||||
if (!this.connection) {
|
if (!this.connection) {
|
||||||
this.connection = new Connection(this.id);
|
this.connection = new Connection(this.id);
|
||||||
|
@ -91,7 +91,7 @@ strong {
|
|||||||
|
|
||||||
.flex-dialog {
|
.flex-dialog {
|
||||||
border-radius: var(--radius) !important;
|
border-radius: var(--radius) !important;
|
||||||
max-height: calc(100vh - 2rem) !important;
|
max-height: calc(95vh - 2rem) !important;
|
||||||
overflow-x: hidden;
|
overflow-x: hidden;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
scrollbar-width: none;
|
scrollbar-width: none;
|
||||||
@ -128,7 +128,7 @@ strong {
|
|||||||
|
|
||||||
.fixed-dialog {
|
.fixed-dialog {
|
||||||
border-radius: var(--radius) !important;
|
border-radius: var(--radius) !important;
|
||||||
max-height: calc(100vh - 2rem) !important;
|
max-height: calc(95vh - 2rem) !important;
|
||||||
min-height: 60vh;
|
min-height: 60vh;
|
||||||
overflow-x: hidden;
|
overflow-x: hidden;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
@ -197,4 +197,4 @@ strong {
|
|||||||
.chat-logo {
|
.chat-logo {
|
||||||
border-radius: var(--radius);
|
border-radius: var(--radius);
|
||||||
user-select: none;
|
user-select: none;
|
||||||
}
|
}
|
||||||
|
@ -497,7 +497,7 @@
|
|||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
user-select: none;
|
user-select: none;
|
||||||
|
|
||||||
&:before {
|
&:not(.loading):before {
|
||||||
content: "#";
|
content: "#";
|
||||||
font-size: 12px;
|
font-size: 12px;
|
||||||
margin-right: 1px;
|
margin-right: 1px;
|
||||||
|
@ -65,7 +65,7 @@ function ModelUsageChart({ labels, datasets }: ModelChartProps) {
|
|||||||
<div className="chart-tooltip min-w-56 w-max z-10 rounded-tremor-default border border-tremor-border bg-tremor-background p-2 text-tremor-default shadow-tremor-dropdown">
|
<div className="chart-tooltip min-w-56 w-max z-10 rounded-tremor-default border border-tremor-border bg-tremor-background p-2 text-tremor-default shadow-tremor-dropdown">
|
||||||
<div className="flex flex-1 space-x-2.5">
|
<div className="flex flex-1 space-x-2.5">
|
||||||
<div
|
<div
|
||||||
className={`flex w-1.5 flex-col bg-${categoryPayload?.color}-500 rounded`}
|
className={`flex w-1.5 flex-col bg-${categoryPayload?.color} rounded`}
|
||||||
/>
|
/>
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
<div className="flex items-center justify-between space-x-8">
|
<div className="flex items-center justify-between space-x-8">
|
||||||
|
@ -2,7 +2,13 @@ import { toggleConversation } from "@/api/history.ts";
|
|||||||
import { mobile } from "@/utils/device.ts";
|
import { mobile } from "@/utils/device.ts";
|
||||||
import { filterMessage } from "@/utils/processor.ts";
|
import { filterMessage } from "@/utils/processor.ts";
|
||||||
import { setMenu } from "@/store/menu.ts";
|
import { setMenu } from "@/store/menu.ts";
|
||||||
import { MessageSquare, MoreHorizontal, Share2, Trash2 } from "lucide-react";
|
import {
|
||||||
|
Loader2,
|
||||||
|
MessageSquare,
|
||||||
|
MoreHorizontal,
|
||||||
|
Share2,
|
||||||
|
Trash2,
|
||||||
|
} from "lucide-react";
|
||||||
import {
|
import {
|
||||||
DropdownMenu,
|
DropdownMenu,
|
||||||
DropdownMenuContent,
|
DropdownMenuContent,
|
||||||
@ -34,6 +40,8 @@ function ConversationSegment({
|
|||||||
const [open, setOpen] = useState(false);
|
const [open, setOpen] = useState(false);
|
||||||
const [offset, setOffset] = useState(0);
|
const [offset, setOffset] = useState(0);
|
||||||
|
|
||||||
|
const loading = conversation.id <= 0;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn("conversation", current === conversation.id && "active")}
|
className={cn("conversation", current === conversation.id && "active")}
|
||||||
@ -51,7 +59,13 @@ function ConversationSegment({
|
|||||||
>
|
>
|
||||||
<MessageSquare className={`h-4 w-4 mr-1`} />
|
<MessageSquare className={`h-4 w-4 mr-1`} />
|
||||||
<div className={`title`}>{filterMessage(conversation.name)}</div>
|
<div className={`title`}>{filterMessage(conversation.name)}</div>
|
||||||
<div className={`id`}>{conversation.id}</div>
|
<div className={cn("id", loading && "loading")}>
|
||||||
|
{loading ? (
|
||||||
|
<Loader2 className={`mr-0.5 h-4 w-4 animate-spin`} />
|
||||||
|
) : (
|
||||||
|
conversation.id
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
<DropdownMenu
|
<DropdownMenu
|
||||||
open={open}
|
open={open}
|
||||||
onOpenChange={(state: boolean) => {
|
onOpenChange={(state: boolean) => {
|
||||||
|
@ -17,10 +17,15 @@ import {
|
|||||||
} from "@/conf/storage.ts";
|
} from "@/conf/storage.ts";
|
||||||
import { CustomMask } from "@/masks/types.ts";
|
import { CustomMask } from "@/masks/types.ts";
|
||||||
import { listMasks } from "@/api/mask.ts";
|
import { listMasks } from "@/api/mask.ts";
|
||||||
|
import { ConversationSerialized } from "@/api/conversation.ts";
|
||||||
|
import { useSelector } from "react-redux";
|
||||||
|
import { useMemo } from "react";
|
||||||
|
import { ConnectionStack } from "@/api/connection.ts";
|
||||||
|
|
||||||
type initialStateType = {
|
type initialStateType = {
|
||||||
history: ConversationInstance[];
|
history: ConversationInstance[];
|
||||||
messages: Message[];
|
messages: Message[];
|
||||||
|
conversations: Record<number, ConversationSerialized>;
|
||||||
model: string;
|
model: string;
|
||||||
web: boolean;
|
web: boolean;
|
||||||
current: number;
|
current: number;
|
||||||
@ -60,12 +65,14 @@ export function getModelList(
|
|||||||
return target;
|
return target;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const stack = new ConnectionStack();
|
||||||
const offline = loadPreferenceModels(getOfflineModels());
|
const offline = loadPreferenceModels(getOfflineModels());
|
||||||
const chatSlice = createSlice({
|
const chatSlice = createSlice({
|
||||||
name: "chat",
|
name: "chat",
|
||||||
initialState: {
|
initialState: {
|
||||||
history: [],
|
history: [],
|
||||||
messages: [],
|
messages: [],
|
||||||
|
conversations: {},
|
||||||
web: getBooleanMemory("web", false),
|
web: getBooleanMemory("web", false),
|
||||||
current: -1,
|
current: -1,
|
||||||
model: getModel(offline, getMemory("model")),
|
model: getModel(offline, getMemory("model")),
|
||||||
@ -214,6 +221,9 @@ export const selectHistory = (state: RootState): ConversationInstance[] =>
|
|||||||
state.chat.history;
|
state.chat.history;
|
||||||
export const selectMessages = (state: RootState): Message[] =>
|
export const selectMessages = (state: RootState): Message[] =>
|
||||||
state.chat.messages;
|
state.chat.messages;
|
||||||
|
export const selectConversations = (
|
||||||
|
state: RootState,
|
||||||
|
): Record<number, ConversationSerialized> => state.chat.conversations;
|
||||||
export const selectModel = (state: RootState): string => state.chat.model;
|
export const selectModel = (state: RootState): string => state.chat.model;
|
||||||
export const selectWeb = (state: RootState): boolean => state.chat.web;
|
export const selectWeb = (state: RootState): boolean => state.chat.web;
|
||||||
export const selectCurrent = (state: RootState): number => state.chat.current;
|
export const selectCurrent = (state: RootState): number => state.chat.current;
|
||||||
@ -226,6 +236,23 @@ export const selectCustomMasks = (state: RootState): CustomMask[] =>
|
|||||||
export const selectSupportModels = (state: RootState): Model[] =>
|
export const selectSupportModels = (state: RootState): Model[] =>
|
||||||
state.chat.support_models;
|
state.chat.support_models;
|
||||||
|
|
||||||
|
export function useConversation(): ConversationSerialized | undefined {
|
||||||
|
const conversations = useSelector(selectConversations);
|
||||||
|
const current = useSelector(selectCurrent);
|
||||||
|
|
||||||
|
return useMemo(() => conversations[current], [conversations, current]);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useMessages(): Message[] {
|
||||||
|
const conversations = useSelector(selectConversations);
|
||||||
|
const current = useSelector(selectCurrent);
|
||||||
|
|
||||||
|
return useMemo(
|
||||||
|
() => conversations[current]?.messages || [],
|
||||||
|
[conversations, current],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export const updateMasks = async (dispatch: AppDispatch) => {
|
export const updateMasks = async (dispatch: AppDispatch) => {
|
||||||
const resp = await listMasks();
|
const resp = await listMasks();
|
||||||
resp.data.length > 0 && dispatch(setCustomMasks(resp.data));
|
resp.data.length > 0 && dispatch(setCustomMasks(resp.data));
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
@ -10,7 +11,7 @@ import (
|
|||||||
func (u *User) CreateApiKey(db *sql.DB) string {
|
func (u *User) CreateApiKey(db *sql.DB) string {
|
||||||
salt := utils.Sha2Encrypt(fmt.Sprintf("%s-%s", u.Username, utils.GenerateChar(utils.GetRandomInt(720, 1024))))
|
salt := utils.Sha2Encrypt(fmt.Sprintf("%s-%s", u.Username, utils.GenerateChar(utils.GetRandomInt(720, 1024))))
|
||||||
key := fmt.Sprintf("sk-%s", salt[:64]) // 64 bytes
|
key := fmt.Sprintf("sk-%s", salt[:64]) // 64 bytes
|
||||||
if _, err := db.Exec("INSERT INTO apikey (user_id, api_key) VALUES (?, ?)", u.GetID(db), key); err != nil {
|
if _, err := globals.ExecDb(db, "INSERT INTO apikey (user_id, api_key) VALUES (?, ?)", u.GetID(db), key); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return key
|
return key
|
||||||
@ -18,14 +19,14 @@ func (u *User) CreateApiKey(db *sql.DB) string {
|
|||||||
|
|
||||||
func (u *User) GetApiKey(db *sql.DB) string {
|
func (u *User) GetApiKey(db *sql.DB) string {
|
||||||
var key string
|
var key string
|
||||||
if err := db.QueryRow("SELECT api_key FROM apikey WHERE user_id = ?", u.GetID(db)).Scan(&key); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT api_key FROM apikey WHERE user_id = ?", u.GetID(db)).Scan(&key); err != nil {
|
||||||
return u.CreateApiKey(db)
|
return u.CreateApiKey(db)
|
||||||
}
|
}
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) ResetApiKey(db *sql.DB) (string, error) {
|
func (u *User) ResetApiKey(db *sql.DB) (string, error) {
|
||||||
if _, err := db.Exec("DELETE FROM apikey WHERE user_id = ?", u.GetID(db)); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if _, err := globals.ExecDb(db, "DELETE FROM apikey WHERE user_id = ?", u.GetID(db)); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return u.CreateApiKey(db), nil
|
return u.CreateApiKey(db), nil
|
||||||
|
20
auth/auth.go
20
auth/auth.go
@ -47,7 +47,7 @@ func ParseApiKey(c *gin.Context, key string) *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT auth.id, auth.username, auth.password FROM auth
|
SELECT auth.id, auth.username, auth.password FROM auth
|
||||||
INNER JOIN apikey ON auth.id = apikey.user_id
|
INNER JOIN apikey ON auth.id = apikey.user_id
|
||||||
WHERE apikey.api_key = ?
|
WHERE apikey.api_key = ?
|
||||||
@ -143,7 +143,7 @@ func SignUp(c *gin.Context, form RegisterForm) (string, error) {
|
|||||||
Token: utils.Sha2Encrypt(email + username),
|
Token: utils.Sha2Encrypt(email + username),
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(`
|
if _, err := globals.ExecDb(db, `
|
||||||
INSERT INTO auth (username, password, email, bind_id, token)
|
INSERT INTO auth (username, password, email, bind_id, token)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
`, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil {
|
`, user.Username, user.Password, user.Email, user.BindID, user.Token); err != nil {
|
||||||
@ -170,7 +170,7 @@ func Login(c *gin.Context, form LoginForm) (string, error) {
|
|||||||
|
|
||||||
// get user from db by username (or email) and password
|
// get user from db by username (or email) and password
|
||||||
var user User
|
var user User
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT auth.id, auth.username, auth.password FROM auth
|
SELECT auth.id, auth.username, auth.password FROM auth
|
||||||
WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ?
|
WHERE (auth.username = ? OR auth.email = ?) AND auth.password = ?
|
||||||
`, username, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil {
|
`, username, username, hash).Scan(&user.ID, &user.Username, &user.Password); err != nil {
|
||||||
@ -202,7 +202,7 @@ func DeepLogin(c *gin.Context, token string) (string, error) {
|
|||||||
|
|
||||||
// register
|
// register
|
||||||
password := utils.GenerateChar(64)
|
password := utils.GenerateChar(64)
|
||||||
_ = db.QueryRow("INSERT INTO auth (bind_id, username, token, password) VALUES (?, ?, ?, ?)",
|
_ = globals.QueryRowDb(db, "INSERT INTO auth (bind_id, username, token, password) VALUES (?, ?, ?, ?)",
|
||||||
user.ID, user.Username, token, password)
|
user.ID, user.Username, token, password)
|
||||||
u := &User{
|
u := &User{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
@ -214,9 +214,9 @@ func DeepLogin(c *gin.Context, token string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// login
|
// login
|
||||||
_ = db.QueryRow("UPDATE auth SET token = ? WHERE username = ?", token, user.Username)
|
_ = globals.QueryRowDb(db, "UPDATE auth SET token = ? WHERE username = ?", token, user.Username)
|
||||||
var password string
|
var password string
|
||||||
err := db.QueryRow("SELECT password FROM auth WHERE username = ?", user.Username).Scan(&password)
|
err := globals.QueryRowDb(db, "SELECT password FROM auth WHERE username = ?", user.Username).Scan(&password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -273,7 +273,7 @@ func Reset(c *gin.Context, form ResetForm) error {
|
|||||||
func (u *User) UpdatePassword(db *sql.DB, cache *redis.Client, password string) error {
|
func (u *User) UpdatePassword(db *sql.DB, cache *redis.Client, password string) error {
|
||||||
hash := utils.Sha2Encrypt(password)
|
hash := utils.Sha2Encrypt(password)
|
||||||
|
|
||||||
if _, err := db.Exec(`
|
if _, err := globals.ExecDb(db, `
|
||||||
UPDATE auth SET password = ? WHERE id = ?
|
UPDATE auth SET password = ? WHERE id = ?
|
||||||
`, hash, u.ID); err != nil {
|
`, hash, u.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -296,7 +296,7 @@ func (u *User) Validate(c *gin.Context) bool {
|
|||||||
|
|
||||||
db := utils.GetDBFromContext(c)
|
db := utils.GetDBFromContext(c)
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ? AND password = ?", u.Username, u.Password).Scan(&count); err != nil || count == 0 {
|
if err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth WHERE username = ? AND password = ?", u.Username, u.Password).Scan(&count); err != nil || count == 0 {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globals.Warn(fmt.Sprintf("validate user error: %s", err.Error()))
|
globals.Warn(fmt.Sprintf("validate user error: %s", err.Error()))
|
||||||
}
|
}
|
||||||
@ -328,13 +328,13 @@ func (u *User) GenerateToken() (string, error) {
|
|||||||
|
|
||||||
func (u *User) GenerateTokenSafe(db *sql.DB) (string, error) {
|
func (u *User) GenerateTokenSafe(db *sql.DB) (string, error) {
|
||||||
if len(u.Username) == 0 {
|
if len(u.Username) == 0 {
|
||||||
if err := db.QueryRow("SELECT username FROM auth WHERE id = ?", u.ID).Scan(&u.Username); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT username FROM auth WHERE id = ?", u.ID).Scan(&u.Username); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(u.Password) == 0 {
|
if len(u.Password) == 0 {
|
||||||
if err := db.QueryRow("SELECT password FROM auth WHERE id = ?", u.ID).Scan(&u.Password); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT password FROM auth WHERE id = ?", u.ID).Scan(&u.Password); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
@ -36,7 +37,7 @@ func GenerateInvitations(db *sql.DB, num int, quota float32, t string) ([]string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateInvitationCode(db *sql.DB, code string, quota float32, t string) error {
|
func CreateInvitationCode(db *sql.DB, code string, quota float32, t string) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO invitation (code, quota, type)
|
INSERT INTO invitation (code, quota, type)
|
||||||
VALUES (?, ?, ?)
|
VALUES (?, ?, ?)
|
||||||
`, code, quota, t)
|
`, code, quota, t)
|
||||||
@ -44,7 +45,7 @@ func CreateInvitationCode(db *sql.DB, code string, quota float32, t string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetInvitation(db *sql.DB, code string) (*Invitation, error) {
|
func GetInvitation(db *sql.DB, code string) (*Invitation, error) {
|
||||||
row := db.QueryRow(`
|
row := globals.QueryRowDb(db, `
|
||||||
SELECT id, code, quota, type, used, used_id
|
SELECT id, code, quota, type, used, used_id
|
||||||
FROM invitation
|
FROM invitation
|
||||||
WHERE code = ?
|
WHERE code = ?
|
||||||
@ -69,7 +70,7 @@ func (i *Invitation) IsUsed() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *Invitation) Use(db *sql.DB, userId int64) error {
|
func (i *Invitation) Use(db *sql.DB, userId int64) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
UPDATE invitation SET used = TRUE, used_id = ? WHERE id = ?
|
UPDATE invitation SET used = TRUE, used_id = ? WHERE id = ?
|
||||||
`, userId, i.Id)
|
`, userId, i.Id)
|
||||||
return err
|
return err
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"chat/globals"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
type GiftResponse struct {
|
type GiftResponse struct {
|
||||||
Cert bool `json:"cert"`
|
Cert bool `json:"cert"`
|
||||||
@ -9,7 +12,7 @@ type GiftResponse struct {
|
|||||||
|
|
||||||
func (u *User) HasPackage(db *sql.DB, _t string) bool {
|
func (u *User) HasPackage(db *sql.DB, _t string) bool {
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow(`SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, u.ID, _t).Scan(&count); err != nil {
|
if err := globals.QueryRowDb(db, `SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, u.ID, _t).Scan(&count); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,7 +31,7 @@ func NewPackage(db *sql.DB, user *User, _t string) bool {
|
|||||||
id := user.GetID(db)
|
id := user.GetID(db)
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow(`SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, id, _t).Scan(&count); err != nil {
|
if err := globals.QueryRowDb(db, `SELECT COUNT(*) FROM package where user_id = ? AND type = ?`, id, _t).Scan(&count); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,7 +39,7 @@ func NewPackage(db *sql.DB, user *User, _t string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = db.QueryRow(`INSERT INTO package (user_id, type) VALUES (?, ?)`, id, _t)
|
_ = globals.QueryRowDb(db, `INSERT INTO package (user_id, type) VALUES (?, ?)`, id, _t)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,11 +2,12 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/channel"
|
"chat/channel"
|
||||||
|
"chat/globals"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (u *User) CreateInitialQuota(db *sql.DB) bool {
|
func (u *User) CreateInitialQuota(db *sql.DB) bool {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?)
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?)
|
||||||
`, u.GetID(db), channel.SystemInstance.GetInitialQuota(), 0.)
|
`, u.GetID(db), channel.SystemInstance.GetInitialQuota(), 0.)
|
||||||
return err == nil
|
return err == nil
|
||||||
@ -14,7 +15,7 @@ func (u *User) CreateInitialQuota(db *sql.DB) bool {
|
|||||||
|
|
||||||
func (u *User) GetQuota(db *sql.DB) float32 {
|
func (u *User) GetQuota(db *sql.DB) float32 {
|
||||||
var quota float32
|
var quota float32
|
||||||
if err := db.QueryRow("SELECT quota FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT quota FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil {
|
||||||
return 0.
|
return 0.
|
||||||
}
|
}
|
||||||
return quota
|
return quota
|
||||||
@ -22,44 +23,50 @@ func (u *User) GetQuota(db *sql.DB) float32 {
|
|||||||
|
|
||||||
func (u *User) GetUsedQuota(db *sql.DB) float32 {
|
func (u *User) GetUsedQuota(db *sql.DB) float32 {
|
||||||
var quota float32
|
var quota float32
|
||||||
if err := db.QueryRow("SELECT used FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT used FROM quota WHERE user_id = ?", u.GetID(db)).Scan("a); err != nil {
|
||||||
return 0.
|
return 0.
|
||||||
}
|
}
|
||||||
return quota
|
return quota
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) SetQuota(db *sql.DB, quota float32) bool {
|
func (u *User) SetQuota(db *sql.DB, quota float32) bool {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = ?
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = ?
|
||||||
`, u.GetID(db), quota, 0., quota)
|
`, u.GetID(db), quota, 0., quota)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) SetUsedQuota(db *sql.DB, used float32) bool {
|
func (u *User) SetUsedQuota(db *sql.DB, used float32) bool {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = ?
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = ?
|
||||||
`, u.GetID(db), 0., used, used)
|
`, u.GetID(db), 0., used, used)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) IncreaseQuota(db *sql.DB, quota float32) bool {
|
func (u *User) IncreaseQuota(db *sql.DB, quota float32) bool {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, globals.MultiSql(`
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota + ?
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota + ?
|
||||||
`, u.GetID(db), quota, 0., quota)
|
`, `
|
||||||
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET quota = quota + ?
|
||||||
|
`), u.GetID(db), quota, 0., quota)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) IncreaseUsedQuota(db *sql.DB, used float32) bool {
|
func (u *User) IncreaseUsedQuota(db *sql.DB, used float32) bool {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, globals.MultiSql(`
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = used + ?
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE used = used + ?
|
||||||
`, u.GetID(db), 0., used, used)
|
`, `
|
||||||
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET used = used + ?
|
||||||
|
`), u.GetID(db), 0., used, used)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) DecreaseQuota(db *sql.DB, quota float32) bool {
|
func (u *User) DecreaseQuota(db *sql.DB, quota float32) bool {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, globals.MultiSql(`
|
||||||
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota - ?
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE quota = quota - ?
|
||||||
`, u.GetID(db), quota, 0., quota)
|
`, `
|
||||||
|
INSERT INTO quota (user_id, quota, used) VALUES (?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET quota = quota - ?
|
||||||
|
`), u.GetID(db), quota, 0., quota)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
@ -34,14 +35,14 @@ func GenerateRedeemCodes(db *sql.DB, num int, quota float32) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateRedeemCode(db *sql.DB, code string, quota float32) error {
|
func CreateRedeemCode(db *sql.DB, code string, quota float32) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO redeem (code, quota) VALUES (?, ?)
|
INSERT INTO redeem (code, quota) VALUES (?, ?)
|
||||||
`, code, quota)
|
`, code, quota)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRedeemCode(db *sql.DB, code string) (*Redeem, error) {
|
func GetRedeemCode(db *sql.DB, code string) (*Redeem, error) {
|
||||||
row := db.QueryRow(`
|
row := globals.QueryRowDb(db, `
|
||||||
SELECT id, code, quota, used
|
SELECT id, code, quota, used
|
||||||
FROM redeem
|
FROM redeem
|
||||||
WHERE code = ?
|
WHERE code = ?
|
||||||
@ -62,7 +63,7 @@ func (r *Redeem) IsUsed() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Redeem) Use(db *sql.DB) error {
|
func (r *Redeem) Use(db *sql.DB) error {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
UPDATE redeem SET used = TRUE WHERE id = ? AND used = FALSE
|
UPDATE redeem SET used = TRUE WHERE id = ? AND used = FALSE
|
||||||
`, r.Id)
|
`, r.Id)
|
||||||
return err
|
return err
|
||||||
|
@ -22,7 +22,7 @@ type User struct {
|
|||||||
|
|
||||||
func GetUserById(db *sql.DB, id int64) *User {
|
func GetUserById(db *sql.DB, id int64) *User {
|
||||||
var user User
|
var user User
|
||||||
if err := db.QueryRow("SELECT id, username FROM auth WHERE id = ?", id).Scan(&user.ID, &user.Username); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT id, username FROM auth WHERE id = ?", id).Scan(&user.ID, &user.Username); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &user
|
return &user
|
||||||
@ -30,7 +30,7 @@ func GetUserById(db *sql.DB, id int64) *User {
|
|||||||
|
|
||||||
func GetUserByName(db *sql.DB, username string) *User {
|
func GetUserByName(db *sql.DB, username string) *User {
|
||||||
var user User
|
var user User
|
||||||
if err := db.QueryRow("SELECT id, username FROM auth WHERE username = ?", username).Scan(&user.ID, &user.Username); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT id, username FROM auth WHERE username = ?", username).Scan(&user.ID, &user.Username); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &user
|
return &user
|
||||||
@ -38,7 +38,7 @@ func GetUserByName(db *sql.DB, username string) *User {
|
|||||||
|
|
||||||
func GetUserByEmail(db *sql.DB, email string) *User {
|
func GetUserByEmail(db *sql.DB, email string) *User {
|
||||||
var user User
|
var user User
|
||||||
if err := db.QueryRow("SELECT id, username FROM auth WHERE email = ?", email).Scan(&user.ID, &user.Username); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT id, username FROM auth WHERE email = ?", email).Scan(&user.ID, &user.Username); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &user
|
return &user
|
||||||
@ -57,7 +57,7 @@ func (u *User) IsBanned(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var banned sql.NullBool
|
var banned sql.NullBool
|
||||||
if err := db.QueryRow("SELECT is_banned FROM auth WHERE username = ?", u.Username).Scan(&banned); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT is_banned FROM auth WHERE username = ?", u.Username).Scan(&banned); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
u.Banned = banned.Valid && banned.Bool
|
u.Banned = banned.Valid && banned.Bool
|
||||||
@ -71,7 +71,7 @@ func (u *User) IsAdmin(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var admin sql.NullBool
|
var admin sql.NullBool
|
||||||
if err := db.QueryRow("SELECT is_admin FROM auth WHERE username = ?", u.Username).Scan(&admin); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT is_admin FROM auth WHERE username = ?", u.Username).Scan(&admin); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ func (u *User) GetID(db *sql.DB) int64 {
|
|||||||
if u.ID > 0 {
|
if u.ID > 0 {
|
||||||
return u.ID
|
return u.ID
|
||||||
}
|
}
|
||||||
if err := db.QueryRow("SELECT id FROM auth WHERE username = ?", u.Username).Scan(&u.ID); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT id FROM auth WHERE username = ?", u.Username).Scan(&u.ID); err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return u.ID
|
return u.ID
|
||||||
@ -99,7 +99,7 @@ func (u *User) GetEmail(db *sql.DB) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var email sql.NullString
|
var email sql.NullString
|
||||||
if err := db.QueryRow("SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT email FROM auth WHERE username = ?", u.Username).Scan(&email); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ func (u *User) GetEmail(db *sql.DB) string {
|
|||||||
|
|
||||||
func IsUserExist(db *sql.DB, username string) bool {
|
func IsUserExist(db *sql.DB, username string) bool {
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return count > 0
|
return count > 0
|
||||||
@ -117,7 +117,7 @@ func IsUserExist(db *sql.DB, username string) bool {
|
|||||||
|
|
||||||
func IsEmailExist(db *sql.DB, email string) bool {
|
func IsEmailExist(db *sql.DB, email string) bool {
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth WHERE email = ?", email).Scan(&count); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return count > 0
|
return count > 0
|
||||||
@ -125,7 +125,7 @@ func IsEmailExist(db *sql.DB, email string) bool {
|
|||||||
|
|
||||||
func getMaxBindId(db *sql.DB) int64 {
|
func getMaxBindId(db *sql.DB) int64 {
|
||||||
var max int64
|
var max int64
|
||||||
if err := db.QueryRow("SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT MAX(bind_id) FROM auth").Scan(&max); err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return max
|
return max
|
||||||
|
@ -2,6 +2,7 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/channel"
|
"chat/channel"
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
@ -21,7 +22,7 @@ func (u *User) GetSubscription(db *sql.DB) (time.Time, int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var expiredAt []uint8
|
var expiredAt []uint8
|
||||||
if err := db.QueryRow("SELECT expired_at, level FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt, &u.Level); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT expired_at, level FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt, &u.Level); err != nil {
|
||||||
return time.Unix(0, 0), 0
|
return time.Unix(0, 0), 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,7 +63,7 @@ func (u *User) IsEnterprise(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var enterprise sql.NullBool
|
var enterprise sql.NullBool
|
||||||
if err := db.QueryRow("SELECT enterprise FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&enterprise); err != nil {
|
if err := globals.QueryRowDb(db, "SELECT enterprise FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&enterprise); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +82,7 @@ func (u *User) AddSubscription(db *sql.DB, month int, level int) bool {
|
|||||||
}
|
}
|
||||||
expiredAt := current.AddDate(0, month, 0)
|
expiredAt := current.AddDate(0, month, 0)
|
||||||
date := utils.ConvertSqlTime(expiredAt)
|
date := utils.ConvertSqlTime(expiredAt)
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO subscription (user_id, expired_at, total_month, level) VALUES (?, ?, ?, ?)
|
INSERT INTO subscription (user_id, expired_at, total_month, level) VALUES (?, ?, ?, ?)
|
||||||
ON DUPLICATE KEY UPDATE expired_at = ?, total_month = total_month + ?, level = ?
|
ON DUPLICATE KEY UPDATE expired_at = ?, total_month = total_month + ?, level = ?
|
||||||
`, u.GetID(db), date, month, level, date, month, level)
|
`, u.GetID(db), date, month, level, date, month, level)
|
||||||
@ -101,7 +102,7 @@ func (u *User) DowngradePlan(db *sql.DB, target int) error {
|
|||||||
// ceil expired time
|
// ceil expired time
|
||||||
expiredAt := now.Add(time.Duration(stamp)*time.Second).AddDate(0, 0, -1)
|
expiredAt := now.Add(time.Duration(stamp)*time.Second).AddDate(0, 0, -1)
|
||||||
date := utils.ConvertSqlTime(expiredAt)
|
date := utils.ConvertSqlTime(expiredAt)
|
||||||
_, err := db.Exec("UPDATE subscription SET level = ?, expired_at = ? WHERE user_id = ?", target, date, u.GetID(db))
|
_, err := globals.ExecDb(db, "UPDATE subscription SET level = ?, expired_at = ? WHERE user_id = ?", target, date, u.GetID(db))
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -118,7 +119,7 @@ func (u *User) CountUpgradePrice(db *sql.DB, target int) float32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) SetSubscriptionLevel(db *sql.DB, level int) bool {
|
func (u *User) SetSubscriptionLevel(db *sql.DB, level int) bool {
|
||||||
_, err := db.Exec("UPDATE subscription SET level = ? WHERE user_id = ?", level, u.GetID(db))
|
_, err := globals.ExecDb(db, "UPDATE subscription SET level = ? WHERE user_id = ?", level, u.GetID(db))
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BIN
chatnio.db
Normal file
BIN
chatnio.db
Normal file
Binary file not shown.
@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func UpdateRootCommand(args []string) {
|
func UpdateRootCommand(args []string) {
|
||||||
db := connection.ConnectMySQL()
|
db := connection.ConnectDatabase()
|
||||||
cache := connection.ConnectRedis()
|
cache := connection.ConnectRedis()
|
||||||
|
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func CreateInvitationCommand(args []string) {
|
func CreateInvitationCommand(args []string) {
|
||||||
db := connection.ConnectMySQL()
|
db := connection.ConnectDatabase()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
t = GetArgString(args, 0)
|
t = GetArgString(args, 0)
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func CreateTokenCommand(args []string) {
|
func CreateTokenCommand(args []string) {
|
||||||
db := connection.ConnectMySQL()
|
db := connection.ConnectDatabase()
|
||||||
id, _ := strconv.Atoi(args[0])
|
id, _ := strconv.Atoi(args[0])
|
||||||
|
|
||||||
user := auth.GetUserById(db, int64(id))
|
user := auth.GetUserById(db, int64(id))
|
||||||
|
@ -6,20 +6,32 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *sql.DB
|
var DB *sql.DB
|
||||||
|
|
||||||
func InitMySQLSafe() *sql.DB {
|
func InitMySQLSafe() *sql.DB {
|
||||||
ConnectMySQL()
|
ConnectDatabase()
|
||||||
|
|
||||||
// using DB as a global variable to point to the latest db connection
|
// using DB as a global variable to point to the latest db connection
|
||||||
MysqlWorker(DB)
|
MysqlWorker(DB)
|
||||||
return DB
|
return DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnectMySQL() *sql.DB {
|
func getConn() *sql.DB {
|
||||||
|
if viper.GetString("mysql.host") == "" {
|
||||||
|
globals.SqliteEngine = true
|
||||||
|
globals.Warn("[connection] mysql host is not set, using sqlite (chatnio.db)")
|
||||||
|
db, err := sql.Open("sqlite3", "chatnio.db")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
// connect to MySQL
|
// connect to MySQL
|
||||||
db, err := sql.Open("mysql", fmt.Sprintf(
|
db, err := sql.Open("mysql", fmt.Sprintf(
|
||||||
"%s:%s@tcp(%s:%d)/%s",
|
"%s:%s@tcp(%s:%d)/%s",
|
||||||
@ -29,6 +41,7 @@ func ConnectMySQL() *sql.DB {
|
|||||||
viper.GetInt("mysql.port"),
|
viper.GetInt("mysql.port"),
|
||||||
viper.GetString("mysql.db"),
|
viper.GetString("mysql.db"),
|
||||||
))
|
))
|
||||||
|
|
||||||
if pingErr := db.Ping(); err != nil || pingErr != nil {
|
if pingErr := db.Ping(); err != nil || pingErr != nil {
|
||||||
errMsg := utils.Multi[string](err != nil, utils.GetError(err), utils.GetError(pingErr)) // err.Error() may contain nil pointer
|
errMsg := utils.Multi[string](err != nil, utils.GetError(err), utils.GetError(pingErr)) // err.Error() may contain nil pointer
|
||||||
globals.Warn(
|
globals.Warn(
|
||||||
@ -40,11 +53,16 @@ func ConnectMySQL() *sql.DB {
|
|||||||
utils.Sleep(5000)
|
utils.Sleep(5000)
|
||||||
db.Close()
|
db.Close()
|
||||||
|
|
||||||
return ConnectMySQL()
|
return getConn()
|
||||||
} else {
|
|
||||||
globals.Debug(fmt.Sprintf("[connection] connected to mysql server (host: %s)", viper.GetString("mysql.host")))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
globals.Debug(fmt.Sprintf("[connection] connected to mysql server (host: %s)", viper.GetString("mysql.host")))
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectDatabase() *sql.DB {
|
||||||
|
db := getConn()
|
||||||
|
|
||||||
db.SetMaxOpenConns(512)
|
db.SetMaxOpenConns(512)
|
||||||
db.SetMaxIdleConns(64)
|
db.SetMaxIdleConns(64)
|
||||||
|
|
||||||
@ -72,7 +90,7 @@ func ConnectMySQL() *sql.DB {
|
|||||||
func InitRootUser(db *sql.DB) {
|
func InitRootUser(db *sql.DB) {
|
||||||
// create root user if totally empty
|
// create root user if totally empty
|
||||||
var count int
|
var count int
|
||||||
err := db.QueryRow("SELECT COUNT(*) FROM auth").Scan(&count)
|
err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth").Scan(&count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globals.Warn(fmt.Sprintf("[service] failed to query user count: %s", err.Error()))
|
globals.Warn(fmt.Sprintf("[service] failed to query user count: %s", err.Error()))
|
||||||
return
|
return
|
||||||
@ -80,7 +98,7 @@ func InitRootUser(db *sql.DB) {
|
|||||||
|
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
globals.Debug("[service] no user found, creating root user (username: root, password: chatnio123456, email: root@example.com)")
|
globals.Debug("[service] no user found, creating root user (username: root, password: chatnio123456, email: root@example.com)")
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
INSERT INTO auth (username, password, email, is_admin, bind_id, token)
|
INSERT INTO auth (username, password, email, is_admin, bind_id, token)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
`, "root", utils.Sha2Encrypt("chatnio123456"), "root@example.com", true, 0, "root")
|
`, "root", utils.Sha2Encrypt("chatnio123456"), "root@example.com", true, 0, "root")
|
||||||
@ -93,7 +111,7 @@ func InitRootUser(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateUserTable(db *sql.DB) {
|
func CreateUserTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS auth (
|
CREATE TABLE IF NOT EXISTS auth (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
bind_id INT UNIQUE,
|
bind_id INT UNIQUE,
|
||||||
@ -113,7 +131,7 @@ func CreateUserTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreatePackageTable(db *sql.DB) {
|
func CreatePackageTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS package (
|
CREATE TABLE IF NOT EXISTS package (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
user_id INT,
|
user_id INT,
|
||||||
@ -129,7 +147,7 @@ func CreatePackageTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateQuotaTable(db *sql.DB) {
|
func CreateQuotaTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS quota (
|
CREATE TABLE IF NOT EXISTS quota (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
user_id INT UNIQUE,
|
user_id INT UNIQUE,
|
||||||
@ -146,7 +164,7 @@ func CreateQuotaTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateConversationTable(db *sql.DB) {
|
func CreateConversationTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS conversation (
|
CREATE TABLE IF NOT EXISTS conversation (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
user_id INT,
|
user_id INT,
|
||||||
@ -164,7 +182,7 @@ func CreateConversationTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateMaskTable(db *sql.DB) {
|
func CreateMaskTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS mask (
|
CREATE TABLE IF NOT EXISTS mask (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
user_id INT,
|
user_id INT,
|
||||||
@ -184,7 +202,7 @@ func CreateMaskTable(db *sql.DB) {
|
|||||||
|
|
||||||
func CreateSharingTable(db *sql.DB) {
|
func CreateSharingTable(db *sql.DB) {
|
||||||
// refs is an array of message id, separated by comma (-1 means all messages)
|
// refs is an array of message id, separated by comma (-1 means all messages)
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS sharing (
|
CREATE TABLE IF NOT EXISTS sharing (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
hash CHAR(32) UNIQUE,
|
hash CHAR(32) UNIQUE,
|
||||||
@ -201,7 +219,7 @@ func CreateSharingTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateSubscriptionTable(db *sql.DB) {
|
func CreateSubscriptionTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS subscription (
|
CREATE TABLE IF NOT EXISTS subscription (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
level INT DEFAULT 1,
|
level INT DEFAULT 1,
|
||||||
@ -220,7 +238,7 @@ func CreateSubscriptionTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateApiKeyTable(db *sql.DB) {
|
func CreateApiKeyTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS apikey (
|
CREATE TABLE IF NOT EXISTS apikey (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
user_id INT UNIQUE,
|
user_id INT UNIQUE,
|
||||||
@ -235,7 +253,7 @@ func CreateApiKeyTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateInvitationTable(db *sql.DB) {
|
func CreateInvitationTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS invitation (
|
CREATE TABLE IF NOT EXISTS invitation (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
code VARCHAR(255) UNIQUE,
|
code VARCHAR(255) UNIQUE,
|
||||||
@ -255,7 +273,7 @@ func CreateInvitationTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateRedeemTable(db *sql.DB) {
|
func CreateRedeemTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS redeem (
|
CREATE TABLE IF NOT EXISTS redeem (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
code VARCHAR(255) UNIQUE,
|
code VARCHAR(255) UNIQUE,
|
||||||
@ -271,7 +289,7 @@ func CreateRedeemTable(db *sql.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateBroadcastTable(db *sql.DB) {
|
func CreateBroadcastTable(db *sql.DB) {
|
||||||
_, err := db.Exec(`
|
_, err := globals.ExecDb(db, `
|
||||||
CREATE TABLE IF NOT EXISTS broadcast (
|
CREATE TABLE IF NOT EXISTS broadcast (
|
||||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||||
poster_id INT,
|
poster_id INT,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -27,10 +28,14 @@ func checkSqlError(_ sql.Result, err error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func execSql(db *sql.DB, sql string, args ...interface{}) error {
|
func execSql(db *sql.DB, sql string, args ...interface{}) error {
|
||||||
return checkSqlError(db.Exec(sql, args...))
|
return checkSqlError(globals.ExecDb(db, sql, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func doMigration(db *sql.DB) error {
|
func doMigration(db *sql.DB) error {
|
||||||
|
if globals.SqliteEngine {
|
||||||
|
return doSqliteMigration(db)
|
||||||
|
}
|
||||||
|
|
||||||
// v3.10 migration
|
// v3.10 migration
|
||||||
|
|
||||||
// update `quota`, `used` field in `quota` table
|
// update `quota`, `used` field in `quota` table
|
||||||
@ -54,3 +59,9 @@ func doMigration(db *sql.DB) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func doSqliteMigration(db *sql.DB) error {
|
||||||
|
// v3.10 added sqlite support, no migration needed before this version
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -12,7 +12,7 @@ func MysqlWorker(db *sql.DB) {
|
|||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
if db == nil || db.Ping() != nil {
|
if db == nil || db.Ping() != nil {
|
||||||
db = ConnectMySQL()
|
db = ConnectDatabase()
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(tick)
|
time.Sleep(tick)
|
||||||
|
77
globals/sql.go
Normal file
77
globals/sql.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package globals
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SqliteEngine = false
|
||||||
|
|
||||||
|
type batch struct {
|
||||||
|
Old string
|
||||||
|
New string
|
||||||
|
Regex bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func batchReplace(sql string, batch []batch) string {
|
||||||
|
for _, item := range batch {
|
||||||
|
if item.Regex {
|
||||||
|
sql = regexp.MustCompile(item.Old).ReplaceAllString(sql, item.New)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = strings.ReplaceAll(sql, item.Old, item.New)
|
||||||
|
}
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
func MultiSql(mysqlSql string, sqliteSql string) string {
|
||||||
|
if SqliteEngine {
|
||||||
|
return sqliteSql
|
||||||
|
}
|
||||||
|
return mysqlSql
|
||||||
|
}
|
||||||
|
|
||||||
|
func PreflightSql(sql string) string {
|
||||||
|
if SqliteEngine {
|
||||||
|
sql = batchReplace(sql, []batch{
|
||||||
|
// KEYWORD REPLACEMENT
|
||||||
|
{`INT `, `INTEGER `, false},
|
||||||
|
{` AUTO_INCREMENT`, ` AUTOINCREMENT`, false},
|
||||||
|
{`DATETIME`, `TEXT`, false},
|
||||||
|
{`DECIMAL`, `REAL`, false},
|
||||||
|
{`MEDIUMTEXT`, `TEXT`, false},
|
||||||
|
{`VARCHAR`, `TEXT`, false},
|
||||||
|
|
||||||
|
// TEXT(65535) -> TEXT, REAL(10,2) -> REAL
|
||||||
|
{`TEXT\(\d+\)`, `TEXT`, true},
|
||||||
|
{`REAL\(\d+,\d+\)`, `REAL`, true},
|
||||||
|
|
||||||
|
// UNIQUE KEY -> UNIQUE
|
||||||
|
{`UNIQUE KEY`, `UNIQUE`, false},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecDb(db *sql.DB, sql string, args ...interface{}) (sql.Result, error) {
|
||||||
|
sql = PreflightSql(sql)
|
||||||
|
return db.Exec(sql, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareDb(db *sql.DB, sql string) (*sql.Stmt, error) {
|
||||||
|
sql = PreflightSql(sql)
|
||||||
|
return db.Prepare(sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryDb(db *sql.DB, sql string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
sql = PreflightSql(sql)
|
||||||
|
return db.Query(sql, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryRowDb(db *sql.DB, sql string, args ...interface{}) *sql.Row {
|
||||||
|
sql = PreflightSql(sql)
|
||||||
|
return db.QueryRow(sql, args...)
|
||||||
|
}
|
1
go.mod
1
go.mod
@ -53,6 +53,7 @@ require (
|
|||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
github.com/magiconair/properties v1.8.7 // indirect
|
github.com/magiconair/properties v1.8.7 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -403,6 +403,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
|
|||||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||||
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
||||||
|
@ -2,6 +2,7 @@ package broadcast
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/auth"
|
"chat/auth"
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"context"
|
"context"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -11,7 +12,7 @@ func createBroadcast(c *gin.Context, user *auth.User, content string) error {
|
|||||||
db := utils.GetDBFromContext(c)
|
db := utils.GetDBFromContext(c)
|
||||||
cache := utils.GetCacheFromContext(c)
|
cache := utils.GetCacheFromContext(c)
|
||||||
|
|
||||||
if _, err := db.Exec(`INSERT INTO broadcast (poster_id, content) VALUES (?, ?)`, user.GetID(db), content); err != nil {
|
if _, err := globals.ExecDb(db, `INSERT INTO broadcast (poster_id, content) VALUES (?, ?)`, user.GetID(db), content); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ func getBroadcastList(c *gin.Context) ([]Info, error) {
|
|||||||
db := utils.GetDBFromContext(c)
|
db := utils.GetDBFromContext(c)
|
||||||
|
|
||||||
var broadcastList []Info
|
var broadcastList []Info
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT broadcast.id, broadcast.content, auth.username, broadcast.created_at
|
SELECT broadcast.id, broadcast.content, auth.username, broadcast.created_at
|
||||||
FROM broadcast
|
FROM broadcast
|
||||||
INNER JOIN auth ON broadcast.poster_id = auth.id
|
INNER JOIN auth ON broadcast.poster_id = auth.id
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package broadcast
|
package broadcast
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"context"
|
"context"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -18,7 +19,7 @@ func getLatestBroadcast(c *gin.Context) *Broadcast {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var broadcast Broadcast
|
var broadcast Broadcast
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT id, content FROM broadcast ORDER BY id DESC LIMIT 1;
|
SELECT id, content FROM broadcast ORDER BY id DESC LIMIT 1;
|
||||||
`).Scan(&broadcast.Index, &broadcast.Content); err != nil {
|
`).Scan(&broadcast.Index, &broadcast.Content); err != nil {
|
||||||
return nil
|
return nil
|
||||||
|
@ -26,14 +26,14 @@ func (m *Mask) Save(db *sql.DB, user *auth.User) error {
|
|||||||
userId := user.GetID(db)
|
userId := user.GetID(db)
|
||||||
|
|
||||||
if m.Id == -1 {
|
if m.Id == -1 {
|
||||||
_, err := db.Exec(
|
_, err := globals.ExecDb(db,
|
||||||
"INSERT INTO mask (mask.user_id, avatar, name, description, context) VALUES (?, ?, ?, ?, ?)",
|
"INSERT INTO mask (mask.user_id, avatar, name, description, context) VALUES (?, ?, ?, ?, ?)",
|
||||||
userId, m.Avatar, m.Name, m.Description, utils.Marshal(m.Context),
|
userId, m.Avatar, m.Name, m.Description, utils.Marshal(m.Context),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := db.Exec(
|
_, err := globals.ExecDb(db,
|
||||||
"UPDATE mask SET avatar = ?, name = ?, description = ?, context = ? WHERE id = ? AND user_id = ?",
|
"UPDATE mask SET avatar = ?, name = ?, description = ?, context = ? WHERE id = ? AND user_id = ?",
|
||||||
m.Avatar, m.Name, m.Description, utils.Marshal(m.Context), m.Id, userId,
|
m.Avatar, m.Name, m.Description, utils.Marshal(m.Context), m.Id, userId,
|
||||||
)
|
)
|
||||||
@ -41,12 +41,12 @@ func (m *Mask) Save(db *sql.DB, user *auth.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mask) Delete(db *sql.DB, user *auth.User) error {
|
func (m *Mask) Delete(db *sql.DB, user *auth.User) error {
|
||||||
_, err := db.Exec("DELETE FROM mask WHERE id = ? AND user_id = ?", m.Id, user.GetID(db))
|
_, err := globals.ExecDb(db, "DELETE FROM mask WHERE id = ? AND user_id = ?", m.Id, user.GetID(db))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadMask(db *sql.DB, user *auth.User) ([]Mask, error) {
|
func LoadMask(db *sql.DB, user *auth.User) ([]Mask, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT id, avatar, name, description, context
|
SELECT id, avatar, name, description, context
|
||||||
FROM mask WHERE user_id = ?
|
FROM mask WHERE user_id = ?
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
|
@ -49,7 +49,7 @@ func ShareConversation(db *sql.DB, user *auth.User, id int64, refs []int) (strin
|
|||||||
Refs: refs,
|
Refs: refs,
|
||||||
})
|
})
|
||||||
|
|
||||||
if _, err := db.Exec(`
|
if _, err := globals.ExecDb(db, `
|
||||||
INSERT INTO sharing (hash, user_id, conversation_id, refs) VALUES (?, ?, ?, ?)
|
INSERT INTO sharing (hash, user_id, conversation_id, refs) VALUES (?, ?, ?, ?)
|
||||||
ON DUPLICATE KEY UPDATE refs = ?
|
ON DUPLICATE KEY UPDATE refs = ?
|
||||||
`, hash, user.GetID(db), id, ref, ref); err != nil {
|
`, hash, user.GetID(db), id, ref, ref); err != nil {
|
||||||
@ -86,7 +86,7 @@ func ListSharedConversation(db *sql.DB, user *auth.User) []SharedPreviewForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
id := user.GetID(db)
|
id := user.GetID(db)
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT conversation.conversation_name, conversation.conversation_id, sharing.updated_at, sharing.hash
|
SELECT conversation.conversation_name, conversation.conversation_id, sharing.updated_at, sharing.hash
|
||||||
FROM sharing
|
FROM sharing
|
||||||
INNER JOIN conversation
|
INNER JOIN conversation
|
||||||
@ -120,7 +120,7 @@ func DeleteSharedConversation(db *sql.DB, user *auth.User, hash string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
id := user.GetID(db)
|
id := user.GetID(db)
|
||||||
if _, err := db.Exec(`
|
if _, err := globals.ExecDb(db, `
|
||||||
DELETE FROM sharing WHERE user_id = ? AND hash = ?
|
DELETE FROM sharing WHERE user_id = ? AND hash = ?
|
||||||
`, id, hash); err != nil {
|
`, id, hash); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -136,7 +136,7 @@ func GetSharedConversation(db *sql.DB, hash string) (*SharedForm, error) {
|
|||||||
ref string
|
ref string
|
||||||
updated []uint8
|
updated []uint8
|
||||||
)
|
)
|
||||||
if err := db.QueryRow(`
|
if err := globals.QueryRowDb(db, `
|
||||||
SELECT auth.username, sharing.refs, sharing.updated_at, conversation.conversation_name,
|
SELECT auth.username, sharing.refs, sharing.updated_at, conversation.conversation_name,
|
||||||
sharing.user_id, sharing.conversation_id
|
sharing.user_id, sharing.conversation_id
|
||||||
FROM sharing
|
FROM sharing
|
||||||
|
@ -20,7 +20,7 @@ func (c *Conversation) SaveConversation(db *sql.DB) bool {
|
|||||||
ON DUPLICATE KEY UPDATE conversation_name = VALUES(conversation_name), data = VALUES(data)
|
ON DUPLICATE KEY UPDATE conversation_name = VALUES(conversation_name), data = VALUES(data)
|
||||||
`
|
`
|
||||||
|
|
||||||
stmt, err := db.Prepare(query)
|
stmt, err := globals.PrepareDb(db, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -40,7 +40,7 @@ func (c *Conversation) SaveConversation(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
func GetConversationLengthByUserID(db *sql.DB, userId int64) int64 {
|
func GetConversationLengthByUserID(db *sql.DB, userId int64) int64 {
|
||||||
var length int64
|
var length int64
|
||||||
err := db.QueryRow("SELECT MAX(conversation_id) FROM conversation WHERE user_id = ?", userId).Scan(&length)
|
err := globals.QueryRowDb(db, "SELECT MAX(conversation_id) FROM conversation WHERE user_id = ?", userId).Scan(&length)
|
||||||
if err != nil || length < 0 {
|
if err != nil || length < 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@ -57,7 +57,7 @@ func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversat
|
|||||||
data string
|
data string
|
||||||
model interface{}
|
model interface{}
|
||||||
)
|
)
|
||||||
err := db.QueryRow(`
|
err := globals.QueryRowDb(db, `
|
||||||
SELECT conversation_name, model, data FROM conversation
|
SELECT conversation_name, model, data FROM conversation
|
||||||
WHERE user_id = ? AND conversation_id = ?
|
WHERE user_id = ? AND conversation_id = ?
|
||||||
`, userId, conversationId).Scan(&conversation.Name, &model, &data)
|
`, userId, conversationId).Scan(&conversation.Name, &model, &data)
|
||||||
@ -81,7 +81,7 @@ func LoadConversation(db *sql.DB, userId int64, conversationId int64) *Conversat
|
|||||||
|
|
||||||
func LoadConversationList(db *sql.DB, userId int64) []Conversation {
|
func LoadConversationList(db *sql.DB, userId int64) []Conversation {
|
||||||
var conversationList []Conversation
|
var conversationList []Conversation
|
||||||
rows, err := db.Query(`
|
rows, err := globals.QueryDb(db, `
|
||||||
SELECT conversation_id, conversation_name FROM conversation WHERE user_id = ?
|
SELECT conversation_id, conversation_name FROM conversation WHERE user_id = ?
|
||||||
ORDER BY conversation_id DESC LIMIT 100
|
ORDER BY conversation_id DESC LIMIT 100
|
||||||
`, userId)
|
`, userId)
|
||||||
@ -108,7 +108,7 @@ func LoadConversationList(db *sql.DB, userId int64) []Conversation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) DeleteConversation(db *sql.DB) bool {
|
func (c *Conversation) DeleteConversation(db *sql.DB) bool {
|
||||||
_, err := db.Exec("DELETE FROM conversation WHERE user_id = ? AND conversation_id = ?", c.UserID, c.Id)
|
_, err := globals.ExecDb(db, "DELETE FROM conversation WHERE user_id = ? AND conversation_id = ?", c.UserID, c.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -116,6 +116,6 @@ func (c *Conversation) DeleteConversation(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAllConversations(db *sql.DB, user auth.User) error {
|
func DeleteAllConversations(db *sql.DB, user auth.User) error {
|
||||||
_, err := db.Exec("DELETE FROM conversation WHERE user_id = ?", user.GetID(db))
|
_, err := globals.ExecDb(db, "DELETE FROM conversation WHERE user_id = ?", user.GetID(db))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user