2021-09-28 23:57:59 +07:00
|
|
|
package db
|
|
|
|
|
|
|
|
import (
|
2021-09-30 10:32:18 +07:00
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"os"
|
2021-09-28 23:57:59 +07:00
|
|
|
"sync"
|
|
|
|
"time"
|
2021-09-30 10:32:18 +07:00
|
|
|
|
2021-12-21 10:18:12 +07:00
|
|
|
"github.com/hhhapz/codequest/models"
|
2021-09-30 10:32:18 +07:00
|
|
|
"golang.org/x/oauth2"
|
|
|
|
"golang.org/x/oauth2/google"
|
2021-09-28 23:57:59 +07:00
|
|
|
)
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
// Google API scope URLs requested when building the OAuth2 config.
const (
	// base is the host shared by every Google API scope URL below.
	base = "https://www.googleapis.com"

	// scopeEmail is the Google "userinfo.email" OAuth scope.
	scopeEmail = base + "/auth/userinfo.email"
	// scopeProfile is the Google "userinfo.profile" OAuth scope.
	scopeProfile = base + "/auth/userinfo.profile"
)
|
2021-09-28 23:57:59 +07:00
|
|
|
|
|
|
|
// OAuthState tracks pending OAuth2 sign-in attempts.
//
// Create stores a freshly generated state value (from createToken) in states
// and returns the provider URL carrying it; Validate looks the state up and
// removes it. Every method in this file locks m before touching states.
type OAuthState struct {
	// config is the OAuth2 client configuration (NewOAuthState loads it from
	// a Google credentials JSON file) used to build auth URLs and to exchange
	// authorization codes.
	config *oauth2.Config

	// states maps each issued, not-yet-consumed state value to its entry.
	states map[string]oauthEntry
	// m guards states.
	m sync.Mutex
}
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
// oauthEntry is the bookkeeping kept for one pending OAuth state.
type oauthEntry struct {
	// created is when Create issued the state; GarbageCycle deletes entries
	// older than its duration argument.
	created time.Time

	// callback is the value passed to Create, returned again by Validate
	// when the state is consumed.
	callback string
}
|
|
|
|
|
|
|
|
func NewOAuthState(path string) (*OAuthState, error) {
|
|
|
|
key, err := os.ReadFile(path)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not open file: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
config, err := google.ConfigFromJSON(key, scopeEmail, scopeProfile)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not load config: %w", err)
|
|
|
|
}
|
|
|
|
|
2021-09-28 23:57:59 +07:00
|
|
|
return &OAuthState{
|
2021-12-21 10:18:12 +07:00
|
|
|
config: config,
|
2021-09-28 23:57:59 +07:00
|
|
|
states: make(map[string]oauthEntry),
|
|
|
|
m: sync.Mutex{},
|
2021-09-30 10:32:18 +07:00
|
|
|
}, nil
|
2021-09-28 23:57:59 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
// Create registers a new pending OAuth state bound to callback and returns the
// provider authorization URL that carries it.
//
// A token from createToken(32) is drawn repeatedly until it does not collide
// with any state already pending. The entry is stamped with the current time
// so GarbageCycle can expire it. The URL is built with AccessTypeOffline.
func (o *OAuthState) Create(callback string) string {
	o.m.Lock()
	defer o.m.Unlock()

	var state string
	for {
		state = createToken(32)
		if _, taken := o.states[state]; !taken {
			break
		}
	}

	o.states[state] = oauthEntry{callback: callback, created: time.Now()}
	return o.config.AuthCodeURL(state, oauth2.AccessTypeOffline)
}
|
|
|
|
|
2021-12-21 10:18:12 +07:00
|
|
|
func (o *OAuthState) Validate(ctx context.Context, state string, code string) (*oauth2.Token, string, error) {
|
2021-09-28 23:57:59 +07:00
|
|
|
o.m.Lock()
|
|
|
|
defer o.m.Unlock()
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
entry, ok := o.states[state]
|
|
|
|
if !ok {
|
2021-12-21 10:18:12 +07:00
|
|
|
return nil, "", models.NewUserError("invalid state: %q", state)
|
2021-09-30 10:32:18 +07:00
|
|
|
}
|
2021-12-19 15:30:21 +07:00
|
|
|
delete(o.states, state)
|
2021-09-30 10:32:18 +07:00
|
|
|
|
2021-12-21 10:18:12 +07:00
|
|
|
tk, err := o.config.Exchange(ctx, code, oauth2.AccessTypeOffline)
|
|
|
|
return tk, entry.callback, err
|
2021-09-28 23:57:59 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
func (o *OAuthState) GarbageCycle(period time.Duration, duration time.Duration) {
|
|
|
|
tick := time.NewTicker(period)
|
|
|
|
|
2021-12-21 10:18:12 +07:00
|
|
|
for range tick.C {
|
2021-09-28 23:57:59 +07:00
|
|
|
o.m.Lock()
|
|
|
|
|
|
|
|
now := time.Now()
|
|
|
|
for state, entry := range o.states {
|
|
|
|
if now.After(entry.created.Add(duration)) {
|
|
|
|
delete(o.states, state)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-12-21 10:18:12 +07:00
|
|
|
o.m.Unlock()
|
2021-09-28 23:57:59 +07:00
|
|
|
}
|
|
|
|
}
|