coai/connection/database.go
2024-04-10 18:12:14 +08:00

317 lines
7.6 KiB
Go

package connection
import (
"chat/globals"
"chat/utils"
"crypto/tls"
"database/sql"
"fmt"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/spf13/viper"
)
var DB *sql.DB
func InitMySQLSafe() *sql.DB {
ConnectDatabase()
// using DB as a global variable to point to the latest db connection
MysqlWorker(DB)
return DB
}
func getConn() *sql.DB {
if viper.GetString("mysql.host") == "" {
globals.SqliteEngine = true
globals.Warn("[connection] mysql host is not set, using sqlite (~/db/chatnio.db)")
db, err := sql.Open("sqlite3", utils.FileSafe("./db/chatnio.db"))
if err != nil {
panic(err)
}
return db
}
mysqlUrl := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s",
viper.GetString("mysql.user"),
viper.GetString("mysql.password"),
viper.GetString("mysql.host"),
viper.GetInt("mysql.port"),
viper.GetString("mysql.db"),
)
if viper.GetBool("mysql.tls") {
mysql.RegisterTLSConfig("tls", &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: viper.GetString("mysql.host"),
})
mysqlUrl += "?tls=tls"
}
// connect to MySQL
db, err := sql.Open("mysql", mysqlUrl)
if pingErr := db.Ping(); err != nil || pingErr != nil {
errMsg := utils.Multi[string](err != nil, utils.GetError(err), utils.GetError(pingErr)) // err.Error() may contain nil pointer
globals.Warn(
fmt.Sprintf("[connection] failed to connect to mysql server: %s (message: %s), will retry in 5 seconds",
viper.GetString("mysql.host"), errMsg,
),
)
utils.Sleep(5000)
db.Close()
return getConn()
}
globals.Debug(fmt.Sprintf("[connection] connected to mysql server (host: %s)", viper.GetString("mysql.host")))
return db
}
func ConnectDatabase() *sql.DB {
db := getConn()
db.SetMaxOpenConns(512)
db.SetMaxIdleConns(64)
CreateUserTable(db)
CreateConversationTable(db)
CreateMaskTable(db)
CreateSharingTable(db)
CreatePackageTable(db)
CreateQuotaTable(db)
CreateSubscriptionTable(db)
CreateApiKeyTable(db)
CreateInvitationTable(db)
CreateRedeemTable(db)
CreateBroadcastTable(db)
if err := doMigration(db); err != nil {
fmt.Println(fmt.Sprintf("migration error: %s", err))
}
DB = db
return db
}
func InitRootUser(db *sql.DB) {
// create root user if totally empty
var count int
err := globals.QueryRowDb(db, "SELECT COUNT(*) FROM auth").Scan(&count)
if err != nil {
globals.Warn(fmt.Sprintf("[service] failed to query user count: %s", err.Error()))
return
}
if count == 0 {
globals.Debug("[service] no user found, creating root user (username: root, password: chatnio123456, email: root@example.com)")
_, err := globals.ExecDb(db, `
INSERT INTO auth (username, password, email, is_admin, bind_id, token)
VALUES (?, ?, ?, ?, ?, ?)
`, "root", utils.Sha2Encrypt("chatnio123456"), "root@example.com", true, 0, "root")
if err != nil {
globals.Warn(fmt.Sprintf("[service] failed to create root user: %s", err.Error()))
}
} else {
globals.Debug(fmt.Sprintf("[service] %d user(s) found, skip creating root user", count))
}
}
func CreateUserTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS auth (
id INT PRIMARY KEY AUTO_INCREMENT,
bind_id INT UNIQUE,
username VARCHAR(24) UNIQUE,
token VARCHAR(255) NOT NULL,
email VARCHAR(255) UNIQUE,
password VARCHAR(64) NOT NULL,
is_admin BOOLEAN DEFAULT FALSE,
is_banned BOOLEAN DEFAULT FALSE
);
`)
if err != nil {
fmt.Println(err)
}
InitRootUser(db)
}
func CreatePackageTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS package (
id INT PRIMARY KEY AUTO_INCREMENT,
user_id INT,
type VARCHAR(255),
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES auth(id),
UNIQUE KEY (user_id, type)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateQuotaTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS quota (
id INT PRIMARY KEY AUTO_INCREMENT,
user_id INT UNIQUE,
quota DECIMAL(24, 6),
used DECIMAL(24, 6),
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateConversationTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS conversation (
id INT PRIMARY KEY AUTO_INCREMENT,
user_id INT,
conversation_id INT,
conversation_name VARCHAR(255),
data MEDIUMTEXT,
model VARCHAR(255) NOT NULL DEFAULT 'gpt-3.5-turbo-0613',
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY (user_id, conversation_id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateMaskTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS mask (
id INT PRIMARY KEY AUTO_INCREMENT,
user_id INT,
avatar VARCHAR(255),
name VARCHAR(255),
description TEXT,
context TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateSharingTable(db *sql.DB) {
// refs is an array of message id, separated by comma (-1 means all messages)
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS sharing (
id INT PRIMARY KEY AUTO_INCREMENT,
hash CHAR(32) UNIQUE,
user_id INT,
conversation_id INT,
refs TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateSubscriptionTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS subscription (
id INT PRIMARY KEY AUTO_INCREMENT,
level INT DEFAULT 1,
user_id INT UNIQUE,
expired_at DATETIME,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
total_month INT DEFAULT 0,
enterprise BOOLEAN DEFAULT FALSE,
FOREIGN KEY (user_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateApiKeyTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS apikey (
id INT PRIMARY KEY AUTO_INCREMENT,
user_id INT UNIQUE,
api_key VARCHAR(255) UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateInvitationTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS invitation (
id INT PRIMARY KEY AUTO_INCREMENT,
code VARCHAR(255) UNIQUE,
quota DECIMAL(16, 4),
type VARCHAR(255),
used BOOLEAN DEFAULT FALSE,
used_id INT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY (used_id, type),
FOREIGN KEY (used_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateRedeemTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS redeem (
id INT PRIMARY KEY AUTO_INCREMENT,
code VARCHAR(255) UNIQUE,
quota DECIMAL(16, 4),
used BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
`)
if err != nil {
fmt.Println(err)
}
}
func CreateBroadcastTable(db *sql.DB) {
_, err := globals.ExecDb(db, `
CREATE TABLE IF NOT EXISTS broadcast (
id INT PRIMARY KEY AUTO_INCREMENT,
poster_id INT,
content TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (poster_id) REFERENCES auth(id)
);
`)
if err != nil {
fmt.Println(err)
}
}