141 lines
3.5 KiB
Go
141 lines
3.5 KiB
Go
package db
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"sync"
	"time"

	jwt "github.com/dgrijalva/jwt-go"
	"golang.org/x/oauth2"
)
|
|
|
|
var conf = &oauth2.Config{
|
|
ClientID: os.Getenv("OAUTH_ID"),
|
|
ClientSecret: os.Getenv("OAUTH_SECRET"),
|
|
Scopes: []string{"authorization_code"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: os.Getenv("GITEA_HOST") + "/login/oauth/authorize",
|
|
TokenURL: os.Getenv("GITEA_HOST") + "/login/oauth/access_token",
|
|
},
|
|
RedirectURL: os.Getenv("REDIRECT_URL"),
|
|
}
|
|
|
|
var (
|
|
validStates = make(map[string]string)
|
|
|
|
base64Encoding = base64.NewEncoding("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890+-")
|
|
)
|
|
|
|
// GenerateNewToken creates a new redirect URL for users.
|
|
func GenerateNewToken(redirect string) string {
|
|
state := randomToken()
|
|
validStates[state] = redirect
|
|
go func() {
|
|
time.Sleep(time.Minute * 2)
|
|
delete(validStates, state)
|
|
}()
|
|
|
|
return conf.AuthCodeURL(state)
|
|
}
|
|
|
|
// ConsumeOAuthCode takes the oauth response and returns a JWT Token.
|
|
func ConsumeOAuthCode(code, state string) (string, string, error) {
|
|
redirect, ok := validStates[state]
|
|
if !ok {
|
|
return "", "", errors.New("invalid state provided. please try again")
|
|
}
|
|
log.Printf("DEBUG :: Consuming OAuth code: code: %s, state: %s | redirect: %s\n", code, state, redirect)
|
|
|
|
token, err := conf.Exchange(context.Background(), code)
|
|
if err != nil {
|
|
log.Printf("DEBUG :: Could not exchange: %v\n", err)
|
|
return "", "", err
|
|
}
|
|
delete(validStates, state)
|
|
|
|
jwtToken := createJWT(token)
|
|
return jwtToken, redirect, nil
|
|
}
|
|
|
|
// Authorize checks if the provided JWT token and returns a HTTP Cliennt.
|
|
func Authorize(jwtToken string) (*http.Client, string, error) {
|
|
token, err := jwt.Parse(jwtToken, func(t *jwt.Token) (interface{}, error) {
|
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("invalid signing method found, expected HMAC")
|
|
}
|
|
return []byte(os.Getenv("JWT_SECRET")), nil
|
|
})
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, "", errors.New("token invalid")
|
|
}
|
|
|
|
tokenSource := conf.TokenSource(context.Background(), createOAuth(token))
|
|
oauthToken, err := tokenSource.Token()
|
|
if err != nil {
|
|
fmt.Printf("Could not regenerate OAuth token: %s", err)
|
|
return nil, "", err
|
|
}
|
|
|
|
client := conf.Client(context.Background(), oauthToken)
|
|
return client, createJWT(oauthToken), err
|
|
}
|
|
|
|
func randomToken() string {
|
|
session := ""
|
|
ok := true
|
|
|
|
for ok {
|
|
token := make([]byte, 18)
|
|
_, err := rand.Read(token)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
session = base64Encoding.EncodeToString(token)
|
|
_, ok = validStates[session]
|
|
}
|
|
return session
|
|
}
|
|
|
|
func createJWT(token *oauth2.Token) string {
|
|
jwtToken := jwt.New(jwt.SigningMethodHS256)
|
|
claims := jwtToken.Claims.(jwt.MapClaims)
|
|
claims["access_token"] = token.AccessToken
|
|
claims["refresh_token"] = token.RefreshToken
|
|
claims["token_type"] = token.TokenType
|
|
claims["expiry"] = token.Expiry.Format(time.ANSIC)
|
|
claims["exp"] = time.Now().Add(time.Hour * 24 * 7)
|
|
|
|
str, err := jwtToken.SignedString([]byte(os.Getenv("JWT_SECRET")))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return str
|
|
}
|
|
|
|
func createOAuth(jwtToken *jwt.Token) *oauth2.Token {
|
|
token := new(oauth2.Token)
|
|
claims := jwtToken.Claims.(jwt.MapClaims)
|
|
|
|
var err error
|
|
token.AccessToken = claims["access_token"].(string)
|
|
token.RefreshToken = claims["refresh_token"].(string)
|
|
token.TokenType = claims["token_type"].(string)
|
|
token.Expiry, err = time.Parse(time.ANSIC, claims["expiry"].(string))
|
|
if err != nil {
|
|
log.Fatalf("could not unmarshal token expiry: %v", err)
|
|
}
|
|
|
|
return token
|
|
}
|