coai/auth/invitation.go
2023-10-25 12:53:16 +08:00

114 lines
2.7 KiB
Go

package auth
import (
"chat/utils"
"database/sql"
"errors"
"fmt"
)
// Invitation mirrors one row of the `invitation` table as read by
// GetInvitation.
type Invitation struct {
// Id is the row's `id` column; Use targets the row by this value.
Id int64 `json:"id"`
// Code is the invitation code string the row is looked up by.
Code string `json:"code"`
// Quota is the amount passed to the user's IncreaseQuota on redemption.
Quota float32 `json:"quota"`
// Type is the `type` column; GenerateInvitations also uses it as the code prefix.
Type string `json:"type"`
// Used is the `used` column; Use sets it to TRUE in the database.
Used bool `json:"used"`
// UsedId is the `used_id` column (the redeeming user's id). The column is
// nullable; GetInvitation leaves this at zero when it is NULL.
UsedId int64 `json:"used_id"`
}
// GenerateInvitations creates num invitation codes of type t, each worth
// quota, and stores them through CreateInvitationCode.
//
// Every code has the form "<t>-<suffix>", where the suffix comes from
// utils.GenerateChar(24). It returns the codes that were stored, or nil and a
// wrapped error as soon as one insert fails. A non-positive num yields an
// empty, non-nil slice.
func GenerateInvitations(db *sql.DB, num int, quota float32, t string) ([]string, error) {
	codes := make([]string, 0)

	// Keep going until num codes have actually been stored; a retried
	// iteration does not count towards the total.
	for len(codes) < num {
		candidate := fmt.Sprintf("%s-%s", t, utils.GenerateChar(24))

		err := CreateInvitationCode(db, candidate, quota, t)
		switch {
		case err == nil:
			codes = append(codes, candidate)
		case errors.Is(err, sql.ErrNoRows):
			// Intended as the "unique constraint" retry path.
			// NOTE(review): db.Exec is not documented to return sql.ErrNoRows,
			// so a duplicate-code error most likely reaches the default
			// branch instead — confirm against the driver in use.
		default:
			return nil, fmt.Errorf("failed to generate code: %w", err)
		}
	}

	return codes, nil
}
// CreateInvitationCode inserts a single invitation row with the given code,
// quota and type. It returns whatever error db.Exec reports, unwrapped.
func CreateInvitationCode(db *sql.DB, code string, quota float32, t string) error {
	const insertInvitation = `
INSERT INTO invitation (code, quota, type)
VALUES (?, ?, ?)
`
	if _, err := db.Exec(insertInvitation, code, quota, t); err != nil {
		return err
	}
	return nil
}
// GetInvitation loads the invitation row whose code column equals code.
//
// It returns an "invitation code not found" error when no row matches, and a
// wrapped error for any other scan failure. A NULL used_id leaves UsedId at
// its zero value.
func GetInvitation(db *sql.DB, code string) (*Invitation, error) {
	const selectInvitation = `
SELECT id, code, quota, type, used, used_id
FROM invitation
WHERE code = ?
`
	var (
		result Invitation
		usedBy sql.NullInt64 // used_id is nullable, so scan it indirectly
	)

	err := db.QueryRow(selectInvitation, code).Scan(
		&result.Id,
		&result.Code,
		&result.Quota,
		&result.Type,
		&result.Used,
		&usedBy,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, fmt.Errorf("invitation code not found")
	case err != nil:
		return nil, fmt.Errorf("failed to get invitation: %w", err)
	}

	if usedBy.Valid {
		result.UsedId = usedBy.Int64
	}
	return &result, nil
}
// IsUsed reports the value of the Used field as loaded into this struct; it
// does not query the database.
func (i *Invitation) IsUsed() bool {
return i.Used
}
// Use atomically claims this invitation's row for userId.
//
// The UPDATE only matches a row that is still unused, so when several
// requests race on the same code exactly one of them flips the flag. The
// others see zero affected rows and receive an error instead of silently
// overwriting used_id (which previously let one code be redeemed twice).
//
// It returns the db.Exec error unchanged, a wrapped error if the affected-row
// count cannot be read, or an error when no unused row with this Id exists.
// The in-memory Used and UsedId fields are not modified.
func (i *Invitation) Use(db *sql.DB, userId int64) error {
	res, err := db.Exec(`
		UPDATE invitation SET used = TRUE, used_id = ? WHERE id = ? AND used = FALSE
	`, userId, i.Id)
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to check invitation update: %w", err)
	}
	if affected == 0 {
		// Either another request claimed the code first or the row is gone.
		return fmt.Errorf("invitation has already been used or does not exist")
	}
	return nil
}
// GetQuota returns the value of the Quota field.
func (i *Invitation) GetQuota() float32 {
return i.Quota
}
// UseInvitation redeems this invitation for user.
//
// It refuses an invitation already flagged as used, marks the row as used by
// the user's id via Use, and then credits the invitation's quota through
// user.IncreaseQuota. Each failure is reported as a descriptive error; the
// two steps are separate statements and are not wrapped in a transaction
// here.
func (i *Invitation) UseInvitation(db *sql.DB, user User) error {
	if i.IsUsed() {
		return fmt.Errorf("this invitation has been used")
	}

	err := i.Use(db, user.GetID(db))
	switch {
	case err == nil:
		// Row marked as used; continue to the quota credit below.
	case errors.Is(err, sql.ErrNoRows):
		return fmt.Errorf("invitation code not found")
	case errors.Is(err, sql.ErrTxDone):
		return fmt.Errorf("transaction has been closed")
	default:
		return fmt.Errorf("failed to use invitation: %w", err)
	}

	if credited := user.IncreaseQuota(db, i.GetQuota()); !credited {
		return fmt.Errorf("failed to increase quota for user")
	}
	return nil
}
// UseInvitation looks up the invitation identified by code and redeems it
// for u.
//
// It returns the invitation's quota on success. A lookup error is returned
// as is; a redemption error is wrapped with "failed to use invitation". The
// quota result is 0 whenever an error is returned.
func (u *User) UseInvitation(db *sql.DB, code string) (float32, error) {
	invitation, err := GetInvitation(db, code)
	if err != nil {
		return 0, err
	}

	if err = invitation.UseInvitation(db, *u); err != nil {
		return 0, fmt.Errorf("failed to use invitation: %w", err)
	}

	return invitation.GetQuota(), nil
}