diff --git a/config.example.yaml b/config.example.yaml index 5302097..791f007 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -4,6 +4,7 @@ mysql: password: chatnio123456 port: 3306 user: root + tls: false redis: host: localhost diff --git a/connection/database.go b/connection/database.go index 7809008..8d0fe78 100644 --- a/connection/database.go +++ b/connection/database.go @@ -3,8 +3,10 @@ package connection import ( "chat/globals" "chat/utils" + "crypto/tls" "database/sql" "fmt" + "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "github.com/spf13/viper" @@ -32,15 +34,25 @@ func getConn() *sql.DB { return db } - // connect to MySQL - db, err := sql.Open("mysql", fmt.Sprintf( + mysqlUrl := fmt.Sprintf( "%s:%s@tcp(%s:%d)/%s", viper.GetString("mysql.user"), viper.GetString("mysql.password"), viper.GetString("mysql.host"), viper.GetInt("mysql.port"), viper.GetString("mysql.db"), - )) + ) + if viper.GetBool("mysql.tls") { + mysql.RegisterTLSConfig("tls", &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: viper.GetString("mysql.host"), + }) + + mysqlUrl += "?tls=tls" + } + + // connect to MySQL + db, err := sql.Open("mysql", mysqlUrl) if pingErr := db.Ping(); err != nil || pingErr != nil { errMsg := utils.Multi[string](err != nil, utils.GetError(err), utils.GetError(pingErr)) // err.Error() may contain nil pointer