go: Move db into its own package

pull/7/head
Luther Wen Xu 2019-10-12 16:07:10 +07:00
parent 2b5c88059d
commit cc5e832859
Signed by: chanbakjsd
GPG Key ID: B7D77E3E9D102B70
11 changed files with 93 additions and 78 deletions

@ -5,6 +5,8 @@ import (
"strings" "strings"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"TerraOceanBot/db"
) )
var pendingInviteConfirmation = make(map[string]invite) var pendingInviteConfirmation = make(map[string]invite)
@ -30,10 +32,10 @@ func createInvite(s *discordgo.Session, m *discordgo.MessageCreate, command []st
return return
} }
err := setInviteOwner(command[1], m.Author.ID) err := db.SetInviteOwner(command[1], m.Author.ID)
switch err { switch err {
case errAlreadyExists: case db.ErrAlreadyExists:
owner, _ := getInviteOwner(command[1]) owner, _ := db.GetInviteOwner(command[1])
if owner == m.Author.ID { if owner == m.Author.ID {
s.ChannelMessageSend( s.ChannelMessageSend(
m.ChannelID, m.ChannelID,
@ -45,7 +47,7 @@ func createInvite(s *discordgo.Session, m *discordgo.MessageCreate, command []st
m.ChannelID, m.ChannelID,
"这个验证码已被其他会员注册。请使用别的验证码。\nThis validation code is already registered. Please try another one.", "这个验证码已被其他会员注册。请使用别的验证码。\nThis validation code is already registered. Please try another one.",
) )
case errInviteUsed: case db.ErrInviteUsed:
s.ChannelMessageSend( s.ChannelMessageSend(
m.ChannelID, m.ChannelID,
"这个验证码已被使用。请使用别的验证码。\nThis validation code has already been used. Please try another one.", "这个验证码已被使用。请使用别的验证码。\nThis validation code has already been used. Please try another one.",
@ -77,12 +79,12 @@ func checkUseInvite(s *discordgo.Session, m *discordgo.MessageCreate, command []
} }
messageSplit := strings.SplitN(m.Content, " ", 3) messageSplit := strings.SplitN(m.Content, " ", 3)
inviter, err := getInviteOwner(messageSplit[1]) inviter, err := db.GetInviteOwner(messageSplit[1])
if err == errInviteUsed { if err == db.ErrInviteUsed {
sendPrivateMessage(s, m.Author.ID, "该验证码已被使用过。请向验证码制造者要求新的验证码。\nThis validation code has been used before. Please request a new one.") sendPrivateMessage(s, m.Author.ID, "该验证码已被使用过。请向验证码制造者要求新的验证码。\nThis validation code has been used before. Please request a new one.")
return return
} }
if err == errNotFound { if err == db.ErrNotFound {
sendPrivateMessage(s, m.Author.ID, "该验证码不存在。请使用存在的验证码。\nThis validation code doesn't exist. Please use an existing one.") sendPrivateMessage(s, m.Author.ID, "该验证码不存在。请使用存在的验证码。\nThis validation code doesn't exist. Please use an existing one.")
return return
} }
@ -92,7 +94,7 @@ func checkUseInvite(s *discordgo.Session, m *discordgo.MessageCreate, command []
return return
} }
useInvite(messageSplit[1]) db.UseInvite(messageSplit[1])
member, err := s.GuildMember(guildID, inviter) member, err := s.GuildMember(guildID, inviter)
if err != nil { if err != nil {
sendPrivateMessage(s, m.Author.ID, "验证码制造者不是会员。\nValidation code creator is no longer a member.") sendPrivateMessage(s, m.Author.ID, "验证码制造者不是会员。\nValidation code creator is no longer a member.")
@ -159,7 +161,7 @@ func checkForInvite(s *discordgo.Session, r *discordgo.MessageReactionAdd) {
) )
return return
} }
id, err := createInviteVote(msg.ID, invite.User, invite.Reason) id, err := db.CreateInviteVote(msg.ID, invite.User, invite.Reason)
if err != nil { if err != nil {
sendPrivateMessage(s, r.UserID, "创造投票失败。Failed to create vote.") sendPrivateMessage(s, r.UserID, "创造投票失败。Failed to create vote.")
auditLog(s, auditLog(s,

@ -5,6 +5,8 @@ import (
"strings" "strings"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"TerraOceanBot/db"
) )
var trustMessage = make(map[string]string) var trustMessage = make(map[string]string)
@ -84,14 +86,14 @@ func checkForTrustUpdate(s *discordgo.Session, r *discordgo.MessageReactionAdd)
case emojiFive: case emojiFive:
value = 5 value = 5
case emojiCheck: case emojiCheck:
value = nuclearOptionVote value = db.NuclearOptionVote
default: default:
auditLog(s, "Reaction "+r.Emoji.Name+" was added to trust message ID "+r.MessageID) auditLog(s, "Reaction "+r.Emoji.Name+" was added to trust message ID "+r.MessageID)
return return
} }
err := updateTrust(r.UserID, target, value) err := updateTrust(r.UserID, target, value)
if err == errRecentlyChanged { if err == db.ErrRecentlyChanged {
s.ChannelMessageSend(r.ChannelID, "你在一个月内有设定过该名玩家的分数。你这次的更动没有被记录。\nYou have changed your score for this player within the last month. Your change was not recorded.") s.ChannelMessageSend(r.ChannelID, "你在一个月内有设定过该名玩家的分数。你这次的更动没有被记录。\nYou have changed your score for this player within the last month. Your change was not recorded.")
return return
} }

@ -6,6 +6,8 @@ import (
"strings" "strings"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"TerraOceanBot/db"
) )
type voteType struct { type voteType struct {
@ -41,8 +43,8 @@ func checkForVote(s *discordgo.Session, r *discordgo.MessageReactionAdd) {
if r.UserID == s.State.User.ID { if r.UserID == s.State.User.ID {
return return
} }
voteID, err := getVoteFromMessageID(r.MessageID) voteID, err := db.GetVoteFromMessageID(r.MessageID)
if err == errNotFound { if err == db.ErrNotFound {
return return
} }
if err != nil { if err != nil {
@ -55,7 +57,7 @@ func checkForVote(s *discordgo.Session, r *discordgo.MessageReactionAdd) {
var value int var value int
switch r.Emoji.Name { switch r.Emoji.Name {
case emojiX: case emojiX:
value = forceRejectionVote value = db.ForceRejectionVote
case emojiOne: case emojiOne:
value = 1 value = 1
case emojiTwo: case emojiTwo:
@ -67,26 +69,26 @@ func checkForVote(s *discordgo.Session, r *discordgo.MessageReactionAdd) {
case emojiFive: case emojiFive:
value = 5 value = 5
case emojiCheck: case emojiCheck:
value = nuclearOptionVote value = db.NuclearOptionVote
default: default:
auditLog(s, "Reaction "+r.Emoji.Name+" was added to vote #"+strconv.Itoa(voteID)) auditLog(s, "Reaction "+r.Emoji.Name+" was added to vote #"+strconv.Itoa(voteID))
return return
} }
err = updateVote(voteID, r.UserID, value) err = db.UpdateVote(voteID, r.UserID, value)
if err == errForceRejectionVoteReuse { if err == db.ErrForceRejectionVoteReuse {
sendPrivateMessage(s, r.UserID, "您在这个月内已使用过:x:。请选择其他选项。\nYou have used :x: this month. Please choose another option.") sendPrivateMessage(s, r.UserID, "您在这个月内已使用过:x:。请选择其他选项。\nYou have used :x: this month. Please choose another option.")
return return
} }
if err == errVoteIsOver { if err == db.ErrVoteIsOver {
sendPrivateMessage(s, r.UserID, "这个投票已结束。你投的票没有被记录。\nThe vote is over so your vote is not recorded.") sendPrivateMessage(s, r.UserID, "这个投票已结束。你投的票没有被记录。\nThe vote is over so your vote is not recorded.")
return return
} }
var voteName string var voteName string
if err == nil { if err == nil {
voteName, err = getVoteName(voteID) voteName, err = db.GetVoteName(voteID)
} }
if err != nil { if err != nil {
@ -126,7 +128,7 @@ func voteSuggestion(s *discordgo.Session, m *discordgo.MessageCreate) {
) )
return return
} }
id, err := createCustomVote(msg.ID, args[1]) id, err := db.CreateCustomVote(msg.ID, args[1])
if err != nil { if err != nil {
sendPrivateMessage(s, m.Author.ID, "创造投票失败。Failed to create vote.") sendPrivateMessage(s, m.Author.ID, "创造投票失败。Failed to create vote.")
auditLog(s, auditLog(s,

@ -0,0 +1,6 @@
package common
type VoteChoice struct {
UserID string
Value int
}

@ -1,4 +1,4 @@
package main package db
import ( import (
"database/sql" "database/sql"
@ -10,8 +10,8 @@ import (
var db *sql.DB var db *sql.DB
var errNotFound = errors.New("db: the requested entry was not found") var ErrNotFound = errors.New("db: the requested entry was not found")
var errAlreadyExists = errors.New("db: attempting to write to an entry that already exists") var ErrAlreadyExists = errors.New("db: attempting to write to an entry that already exists")
func init() { func init() {
var err error var err error

@ -1,10 +1,10 @@
package main package db
import "errors" import "errors"
var errInviteUsed = errors.New("db: the requested invite has already been used") var ErrInviteUsed = errors.New("db: the requested invite has already been used")
func getInviteOwner(inviteCode string) (string, error) { func GetInviteOwner(inviteCode string) (string, error) {
rows, err := db.Query("SELECT owner, used FROM invite WHERE code=?", inviteCode) rows, err := db.Query("SELECT owner, used FROM invite WHERE code=?", inviteCode)
if err != nil { if err != nil {
return "", err return "", err
@ -18,27 +18,27 @@ func getInviteOwner(inviteCode string) (string, error) {
return "", err return "", err
} }
if used { if used {
return "", errInviteUsed return "", ErrInviteUsed
} }
return owner, nil return owner, nil
} }
return "", errNotFound return "", ErrNotFound
} }
func setInviteOwner(inviteCode, owner string) error { func SetInviteOwner(inviteCode, owner string) error {
_, err := getInviteOwner(inviteCode) _, err := GetInviteOwner(inviteCode)
if err == nil { if err == nil {
return errAlreadyExists return ErrAlreadyExists
} }
if err != errNotFound { if err != ErrNotFound {
return err return err
} }
_, err = db.Exec("INSERT INTO invite(code, owner, used) VALUES(?,?,?)", inviteCode, owner, false) _, err = db.Exec("INSERT INTO invite(code, owner, used) VALUES(?,?,?)", inviteCode, owner, false)
return err return err
} }
func useInvite(inviteCode string) error { func UseInvite(inviteCode string) error {
owner, err := getInviteOwner(inviteCode) owner, err := GetInviteOwner(inviteCode)
if err != nil { if err != nil {
return err return err
} }

@ -1,13 +1,13 @@
package main package db
import ( import (
"errors" "errors"
"time" "time"
) )
var errRecentlyChanged = errors.New("db: attempt to change the trust for a user twice in a month") var ErrRecentlyChanged = errors.New("db: attempt to change the trust for a user twice in a month")
func getTrustVote(targetID string) ([]int, error) { func GetTrustVote(targetID string) ([]int, error) {
rows, err := db.Query("SELECT trust FROM trustVote WHERE targetUser=?", targetID) rows, err := db.Query("SELECT trust FROM trustVote WHERE targetUser=?", targetID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -25,7 +25,7 @@ func getTrustVote(targetID string) ([]int, error) {
return array, nil return array, nil
} }
func updateDbTrust(sourceID, targetID string, voteValue int, force bool) error { func UpdateTrust(sourceID, targetID string, voteValue int, force bool) error {
if !force { if !force {
rows, err := db.Query("SELECT lastUpdated FROM trustVote WHERE sourceUser=? AND targetUser=?", sourceID, targetID) rows, err := db.Query("SELECT lastUpdated FROM trustVote WHERE sourceUser=? AND targetUser=?", sourceID, targetID)
if err != nil { if err != nil {
@ -37,7 +37,7 @@ func updateDbTrust(sourceID, targetID string, voteValue int, force bool) error {
rows.Scan(&lastUpdated) rows.Scan(&lastUpdated)
if time.Now().Sub(lastUpdated) > 24*30*time.Hour { if time.Now().Sub(lastUpdated) > 24*30*time.Hour {
rows.Close() rows.Close()
return errRecentlyChanged return ErrRecentlyChanged
} }
} }
rows.Close() rows.Close()

@ -1,19 +1,21 @@
package main package db
import ( import (
"errors" "errors"
"time" "time"
"TerraOceanBot/common"
) )
const ( const (
forceRejectionVote = -1 ForceRejectionVote = -1
nuclearOptionVote = 10 NuclearOptionVote = 10
) )
var errForceRejectionVoteReuse = errors.New("db: the user has used force rejection vote in the last month") var ErrForceRejectionVoteReuse = errors.New("db: the user has used force rejection vote in the last month")
var errVoteIsOver = errors.New("db: the vote cannot be changed once it's over") var ErrVoteIsOver = errors.New("db: the vote cannot be changed once it's over")
func createCustomVote(messageID, text string) (int, error) { func CreateCustomVote(messageID, text string) (int, error) {
result, err := db.Exec("INSERT INTO vote(messageId, name, type, finished) VALUES(?, ?, ?, ?)", messageID, text, "custom", false) result, err := db.Exec("INSERT INTO vote(messageId, name, type, finished) VALUES(?, ?, ?, ?)", messageID, text, "custom", false)
if err != nil { if err != nil {
return 0, err return 0, err
@ -22,7 +24,7 @@ func createCustomVote(messageID, text string) (int, error) {
return int(lastID), nil return int(lastID), nil
} }
func createInviteVote(messageID, username, reason string) (int, error) { func CreateInviteVote(messageID, username, reason string) (int, error) {
result, err := db.Exec("INSERT INTO vote(messageId, name, type, finished) VALUES(?, ?, ?, ?)", messageID, username+":"+reason, "invite", false) result, err := db.Exec("INSERT INTO vote(messageId, name, type, finished) VALUES(?, ?, ?, ?)", messageID, username+":"+reason, "invite", false)
if err != nil { if err != nil {
return 0, err return 0, err
@ -31,7 +33,7 @@ func createInviteVote(messageID, username, reason string) (int, error) {
return int(lastID), nil return int(lastID), nil
} }
func getVoteName(id int) (string, error) { func GetVoteName(id int) (string, error) {
rows, err := db.Query("SELECT name FROM vote WHERE id=?", id) rows, err := db.Query("SELECT name FROM vote WHERE id=?", id)
if err != nil { if err != nil {
return "", err return "", err
@ -45,10 +47,10 @@ func getVoteName(id int) (string, error) {
} }
return name, nil return name, nil
} }
return "", errNotFound return "", ErrNotFound
} }
func getVoteFromMessageID(msgID string) (int, error) { func GetVoteFromMessageID(msgID string) (int, error) {
rows, err := db.Query("SELECT id FROM vote WHERE messageId=?", msgID) rows, err := db.Query("SELECT id FROM vote WHERE messageId=?", msgID)
if err != nil { if err != nil {
return 0, err return 0, err
@ -62,10 +64,10 @@ func getVoteFromMessageID(msgID string) (int, error) {
} }
return id, nil return id, nil
} }
return 0, errNotFound return 0, ErrNotFound
} }
func getMessageIDFromVote(voteID int) (string, error) { func GetMessageIDFromVote(voteID int) (string, error) {
rows, err := db.Query("SELECT messageId FROM vote WHERE id=?", voteID) rows, err := db.Query("SELECT messageId FROM vote WHERE id=?", voteID)
if err != nil { if err != nil {
return "", err return "", err
@ -79,10 +81,10 @@ func getMessageIDFromVote(voteID int) (string, error) {
} }
return id, nil return id, nil
} }
return "", errNotFound return "", ErrNotFound
} }
func getVoteType(voteID int) (string, error) { func GetVoteType(voteID int) (string, error) {
rows, err := db.Query("SELECT type FROM vote WHERE id=?", voteID) rows, err := db.Query("SELECT type FROM vote WHERE id=?", voteID)
if err != nil { if err != nil {
return "", err return "", err
@ -96,18 +98,18 @@ func getVoteType(voteID int) (string, error) {
} }
return messageType, nil return messageType, nil
} }
return "", errNotFound return "", ErrNotFound
} }
func updateVote(voteID int, userID string, voteValue int) error { func UpdateVote(voteID int, userID string, voteValue int) error {
if voteValue == forceRejectionVote { if voteValue == ForceRejectionVote {
//Check if they used it within a month. //Check if they used it within a month.
rows, err := db.Query("SELECT voteId FROM choices WHERE date >= ? AND value=?", time.Now().AddDate(0, -1, 0)) rows, err := db.Query("SELECT voteId FROM choices WHERE date >= ? AND value=?", time.Now().AddDate(0, -1, 0))
if err != nil { if err != nil {
return err return err
} }
if rows.Next() { if rows.Next() {
return errForceRejectionVoteReuse return ErrForceRejectionVoteReuse
} }
rows.Close() rows.Close()
} }
@ -116,30 +118,30 @@ func updateVote(voteID int, userID string, voteValue int) error {
rows, err := db.Query("SELECT finished FROM vote WHERE id=? AND finished=?", voteID, true) rows, err := db.Query("SELECT finished FROM vote WHERE id=? AND finished=?", voteID, true)
defer rows.Close() defer rows.Close()
if rows.Next() { if rows.Next() {
return errVoteIsOver return ErrVoteIsOver
} }
_, err = db.Exec("REPLACE INTO choices (voteId, userId, date, value) VALUES (?, ?, ?, ?)", voteID, userID, time.Now(), voteValue) _, err = db.Exec("REPLACE INTO choices (voteId, userId, date, value) VALUES (?, ?, ?, ?)", voteID, userID, time.Now(), voteValue)
return err return err
} }
func finishVote(voteID int) error { func FinishVote(voteID int) error {
_, err := db.Exec("UPDATE vote SET finished=true WHERE id=?", voteID) _, err := db.Exec("UPDATE vote SET finished=true WHERE id=?", voteID)
return err return err
} }
func getAllVoteChoices(voteID int) ([]voteChoice, error) { func GetAllVoteChoices(voteID int) ([]common.VoteChoice, error) {
rows, err := db.Query("SELECT userId, value FROM choices WHERE voteId=?", voteID) rows, err := db.Query("SELECT userId, value FROM choices WHERE voteId=?", voteID)
if err != nil { if err != nil {
return []voteChoice{}, err return []common.VoteChoice{}, err
} }
defer rows.Close() defer rows.Close()
array := make([]voteChoice, 0) array := make([]common.VoteChoice, 0)
for rows.Next() { for rows.Next() {
var userID string var userID string
var value int var value int
rows.Scan(&userID, &value) rows.Scan(&userID, &value)
array = append(array, voteChoice{ array = append(array, common.VoteChoice{
UserID: userID, UserID: userID,
Value: value, Value: value,
}) })

@ -4,6 +4,8 @@ import (
"errors" "errors"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"TerraOceanBot/db"
) )
var trustCache map[string]float64 var trustCache map[string]float64
@ -27,7 +29,7 @@ func getTrust(s *discordgo.Session, discordID string) (float64, error) {
if !isMember { if !isMember {
return 0.0, errNotMember return 0.0, errNotMember
} }
votes, err := getTrustVote(discordID) votes, err := db.GetTrustVote(discordID)
if err != nil { if err != nil {
return 0.0, err return 0.0, err
} }
@ -65,7 +67,7 @@ func getTotalTrust(s *discordgo.Session) (float64, error) {
} }
func updateTrust(sourceID, targetID string, voteValue int) error { func updateTrust(sourceID, targetID string, voteValue int) error {
err := updateDbTrust(sourceID, targetID, voteValue, false) err := db.UpdateTrust(sourceID, targetID, voteValue, false)
if err != nil { if err != nil {
return err return err
} }

@ -6,6 +6,8 @@ import (
"strings" "strings"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"TerraOceanBot/db"
) )
const announceInviteChannel = "627165467269922864" const announceInviteChannel = "627165467269922864"
@ -30,12 +32,12 @@ func handleInviteResult(s *discordgo.Session, id int, name string, isPositive bo
return return
} }
s.ChannelMessageSend(announceInviteChannel, fmt.Sprintf("欢迎<@%s>的加入Welcome <@%s> to this server!", member.User.ID, member.User.ID)) s.ChannelMessageSend(announceInviteChannel, fmt.Sprintf("欢迎<@%s>的加入Welcome <@%s> to this server!", member.User.ID, member.User.ID))
choices, err := getAllVoteChoices(id) choices, err := db.GetAllVoteChoices(id)
if err != nil { if err != nil {
auditLog(s, "Error while pulling vote choices. "+err.Error()) auditLog(s, "Error while pulling vote choices. "+err.Error())
return return
} }
for _, choice := range choices { for _, choice := range choices {
updateDbTrust(choice.UserID, member.User.ID, choice.Value, true) db.UpdateTrust(choice.UserID, member.User.ID, choice.Value, true)
} }
} }

@ -6,12 +6,9 @@ import (
"time" "time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
)
type voteChoice struct { "TerraOceanBot/db"
UserID string )
Value int
}
type confirmedResult struct { type confirmedResult struct {
VoteID int VoteID int
@ -26,12 +23,12 @@ func listenToVoteFinishes(s *discordgo.Session) {
voteMutex.Lock() voteMutex.Lock()
for _, v := range toAnnounceResultList { for _, v := range toAnnounceResultList {
auditLog(s, fmt.Sprintf("Announcing the result of vote %d.", v.VoteID)) auditLog(s, fmt.Sprintf("Announcing the result of vote %d.", v.VoteID))
if err := finishVote(v.VoteID); err != nil { if err := db.FinishVote(v.VoteID); err != nil {
auditLog(s, fmt.Sprintf("Error while finishing vote %d.", v.VoteID)) auditLog(s, fmt.Sprintf("Error while finishing vote %d.", v.VoteID))
} }
voteType, _ := getVoteType(v.VoteID) voteType, _ := db.GetVoteType(v.VoteID)
voteName, _ := getVoteName(v.VoteID) voteName, _ := db.GetVoteName(v.VoteID)
messageID, _ := getMessageIDFromVote(v.VoteID) messageID, _ := db.GetMessageIDFromVote(v.VoteID)
s.ChannelMessageEditEmbed(voteChannel, messageID, showVoteStatus(voteTypes[voteType].EmbedBuilder(v.VoteID, voteName), v.IsPositive)) s.ChannelMessageEditEmbed(voteChannel, messageID, showVoteStatus(voteTypes[voteType].EmbedBuilder(v.VoteID, voteName), v.IsPositive))
voteTypes[voteType].ResultHandler(s, v.VoteID, voteName, v.IsPositive) voteTypes[voteType].ResultHandler(s, v.VoteID, voteName, v.IsPositive)
} }
@ -42,7 +39,7 @@ func listenToVoteFinishes(s *discordgo.Session) {
} }
func checkForVoteResult(s *discordgo.Session, id int) { func checkForVoteResult(s *discordgo.Session, id int) {
votes, err := getAllVoteChoices(id) votes, err := db.GetAllVoteChoices(id)
if err != nil { if err != nil {
auditLog(s, fmt.Sprintf("Error while calculating vote result for ID %d: %s", id, err.Error())) auditLog(s, fmt.Sprintf("Error while calculating vote result for ID %d: %s", id, err.Error()))
return return
@ -50,7 +47,7 @@ func checkForVoteResult(s *discordgo.Session, id int) {
var currentScore, totalTrust float64 var currentScore, totalTrust float64
var absoluteRejectionVote bool var absoluteRejectionVote bool
for _, vote := range votes { for _, vote := range votes {
if vote.Value == forceRejectionVote { if vote.Value == db.ForceRejectionVote {
absoluteRejectionVote = true absoluteRejectionVote = true
} }
trust, err := getTrust(s, vote.UserID) trust, err := getTrust(s, vote.UserID)