2021-09-28 23:57:59 +07:00
|
|
|
package db
|
|
|
|
|
|
|
|
import (
|
2021-09-30 10:32:18 +07:00
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"os"
|
2021-09-28 23:57:59 +07:00
|
|
|
"sync"
|
|
|
|
"time"
|
2021-09-30 10:32:18 +07:00
|
|
|
|
|
|
|
"golang.org/x/oauth2"
|
|
|
|
"golang.org/x/oauth2/google"
|
2021-09-28 23:57:59 +07:00
|
|
|
)
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
// Google OAuth scope URLs. NewOAuthState passes scopeEmail and scopeProfile
// to google.ConfigFromJSON when it builds the client configuration.
const (
	// base is the shared prefix from which the scope URLs below are built.
	base = "https://www.googleapis.com"
	// scopeEmail is the userinfo.email scope URL.
	scopeEmail = base + "/auth/userinfo.email"
	// scopeProfile is the userinfo.profile scope URL.
	scopeProfile = base + "/auth/userinfo.profile"
)
|
2021-09-28 23:57:59 +07:00
|
|
|
|
|
|
|
type OAuthState struct {
|
2021-09-30 10:32:18 +07:00
|
|
|
*oauth2.Config
|
2021-09-28 23:57:59 +07:00
|
|
|
states map[string]oauthEntry
|
|
|
|
m sync.Mutex
|
|
|
|
}
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
// oauthEntry is the value kept for one state token in OAuthState.states.
type oauthEntry struct {
	// created is when the entry was stored. GarbageCycle compares it with
	// its duration argument to decide whether the entry has expired.
	created time.Time
	// callback is the string given to Create; Validate hands it back for a
	// known state token.
	callback string
}
|
|
|
|
|
|
|
|
func NewOAuthState(path string) (*OAuthState, error) {
|
|
|
|
key, err := os.ReadFile(path)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not open file: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
config, err := google.ConfigFromJSON(key, scopeEmail, scopeProfile)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not load config: %w", err)
|
|
|
|
}
|
|
|
|
|
2021-09-28 23:57:59 +07:00
|
|
|
return &OAuthState{
|
2021-09-30 10:32:18 +07:00
|
|
|
Config: config,
|
2021-09-28 23:57:59 +07:00
|
|
|
states: make(map[string]oauthEntry),
|
|
|
|
m: sync.Mutex{},
|
2021-09-30 10:32:18 +07:00
|
|
|
}, nil
|
2021-09-28 23:57:59 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
func (o *OAuthState) Create(callback string) string {
|
|
|
|
o.m.Lock()
|
|
|
|
defer o.m.Unlock()
|
|
|
|
|
|
|
|
var state string
|
|
|
|
|
|
|
|
ok := true
|
|
|
|
for ok {
|
|
|
|
state = createToken(32)
|
|
|
|
_, ok = o.states[state]
|
|
|
|
}
|
|
|
|
|
|
|
|
o.states[state] = oauthEntry{
|
|
|
|
created: time.Now(),
|
|
|
|
callback: callback,
|
|
|
|
}
|
2021-09-30 10:32:18 +07:00
|
|
|
return o.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
2021-09-28 23:57:59 +07:00
|
|
|
}
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
func (o *OAuthState) Validate(state string) (string, bool) {
|
2021-09-28 23:57:59 +07:00
|
|
|
o.m.Lock()
|
|
|
|
defer o.m.Unlock()
|
|
|
|
|
2021-09-30 10:32:18 +07:00
|
|
|
entry, ok := o.states[state]
|
|
|
|
if !ok {
|
|
|
|
return "", false
|
|
|
|
}
|
|
|
|
return entry.callback, true
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *OAuthState) Exchange(ctx context.Context, code string) (*oauth2.Token, error) {
|
|
|
|
tk, err := o.Config.Exchange(ctx, code, oauth2.AccessTypeOffline)
|
|
|
|
return tk, err
|
2021-09-28 23:57:59 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
func (o *OAuthState) Remove(code string) {
|
|
|
|
o.m.Lock()
|
|
|
|
defer o.m.Unlock()
|
|
|
|
|
|
|
|
delete(o.states, code)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *OAuthState) GarbageCycle(period time.Duration, duration time.Duration) {
|
|
|
|
tick := time.NewTicker(period)
|
|
|
|
|
|
|
|
clean := func() {
|
|
|
|
o.m.Lock()
|
|
|
|
defer o.m.Unlock()
|
|
|
|
|
|
|
|
now := time.Now()
|
|
|
|
for state, entry := range o.states {
|
|
|
|
if now.After(entry.created.Add(duration)) {
|
|
|
|
delete(o.states, state)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for {
|
|
|
|
select {
|
|
|
|
case <-tick.C:
|
|
|
|
clean()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|