coverage/db/auth.go

141 lines
3.5 KiB
Go

package db
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"sync"
	"time"

	jwt "github.com/dgrijalva/jwt-go"
	"golang.org/x/oauth2"
)
// conf is the OAuth2 client configuration for the provider at GITEA_HOST.
// All values are read from the environment once, at package initialisation.
//
// NOTE(review): "authorization_code" is an OAuth2 grant type rather than a
// scope name — confirm the provider actually expects/ignores it here.
var conf = &oauth2.Config{
	RedirectURL:  os.Getenv("REDIRECT_URL"),
	ClientID:     os.Getenv("OAUTH_ID"),
	ClientSecret: os.Getenv("OAUTH_SECRET"),
	Endpoint: oauth2.Endpoint{
		AuthURL:  os.Getenv("GITEA_HOST") + "/login/oauth/authorize",
		TokenURL: os.Getenv("GITEA_HOST") + "/login/oauth/access_token",
	},
	Scopes: []string{"authorization_code"},
}
var (
	// validStates maps each outstanding OAuth "state" value to the redirect
	// URL supplied when it was generated. Entries are removed when consumed
	// or two minutes after creation, whichever comes first.
	validStates = make(map[string]string)

	// statesMu guards validStates. The map is touched by callers of
	// GenerateNewToken/ConsumeOAuthCode and by expiry timers that run on
	// their own goroutines. randomToken also reads validStates, so it must
	// only be called while statesMu is held.
	statesMu sync.Mutex

	// base64Encoding renders random state bytes using a 64-symbol alphabet
	// of ASCII letters, digits, '+' and '-'.
	base64Encoding = base64.NewEncoding("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890+-")
)

// GenerateNewToken registers a fresh, unguessable OAuth state bound to
// redirect and returns the provider authorization URL carrying that state.
// The state stays valid for two minutes or until ConsumeOAuthCode uses it.
func GenerateNewToken(redirect string) string {
	state := func() string {
		statesMu.Lock()
		// defer keeps the lock from being held forever if randomToken panics.
		defer statesMu.Unlock()
		s := randomToken()
		validStates[s] = redirect
		return s
	}()

	// Expire unused states without parking a goroutine for the whole TTL.
	time.AfterFunc(2*time.Minute, func() {
		statesMu.Lock()
		delete(validStates, state)
		statesMu.Unlock()
	})
	return conf.AuthCodeURL(state)
}

// ConsumeOAuthCode validates the state echoed back by the OAuth provider,
// exchanges code for an OAuth token and returns a signed JWT wrapping that
// token along with the redirect URL registered for the state.
//
// States are single-use. The state is removed atomically with its lookup, so
// a replayed or concurrent callback carrying the same state is rejected. A
// failed exchange therefore requires starting over with GenerateNewToken.
func ConsumeOAuthCode(code, state string) (string, string, error) {
	statesMu.Lock()
	redirect, ok := validStates[state]
	delete(validStates, state)
	statesMu.Unlock()
	if !ok {
		return "", "", errors.New("invalid state provided. please try again")
	}

	// Never log the authorization code or state: both are bearer secrets.
	log.Printf("DEBUG :: Consuming OAuth code for redirect: %q\n", redirect)
	token, err := conf.Exchange(context.Background(), code)
	if err != nil {
		log.Printf("DEBUG :: Could not exchange: %v\n", err)
		return "", "", err
	}
	return createJWT(token), redirect, nil
}
// Authorize validates jwtToken, an HMAC-signed JWT previously issued by
// createJWT. It returns an *http.Client authorized with the OAuth token
// stored in the JWT's claims, plus a newly signed JWT for the current OAuth
// token.
//
// The stored OAuth token is passed through conf.TokenSource, which may
// refresh it if it is no longer valid. Tokens are rejected when JWT_SECRET is
// empty, because an empty HMAC key would let anyone mint a token that
// verifies.
func Authorize(jwtToken string) (*http.Client, string, error) {
	token, err := jwt.Parse(jwtToken, func(t *jwt.Token) (interface{}, error) {
		// Pin the algorithm family so a token cannot choose a different
		// verification scheme than the one we sign with.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("invalid signing method found, expected HMAC")
		}
		secret := os.Getenv("JWT_SECRET")
		if secret == "" {
			return nil, errors.New("jwt secret is not configured")
		}
		return []byte(secret), nil
	})
	if err != nil {
		return nil, "", err
	}
	if !token.Valid {
		return nil, "", errors.New("token invalid")
	}

	ctx := context.Background()
	oauthToken, err := conf.TokenSource(ctx, createOAuth(token)).Token()
	if err != nil {
		log.Printf("Could not regenerate OAuth token: %s\n", err)
		return nil, "", err
	}
	return conf.Client(ctx, oauthToken), createJWT(oauthToken), nil
}
// randomToken returns a new 24-character state string encoded from 18 bytes
// of crypto/rand output. It retries until the result is not already a key of
// validStates. It panics if the system randomness source fails.
//
// It reads validStates without any locking of its own; callers are
// responsible for serializing access to that map.
func randomToken() string {
	for {
		raw := make([]byte, 18)
		if _, err := rand.Read(raw); err != nil {
			panic(err)
		}
		candidate := base64Encoding.EncodeToString(raw)
		if _, taken := validStates[candidate]; !taken {
			return candidate
		}
	}
}
// createJWT packs an OAuth token into an HS256-signed JWT keyed by
// JWT_SECRET, so the client can hold it and createOAuth can later rebuild
// the oauth2.Token from its claims. The JWT itself expires seven days after
// issue. It panics if signing fails.
func createJWT(token *oauth2.Token) string {
	claims := jwt.MapClaims{
		"access_token":  token.AccessToken,
		"refresh_token": token.RefreshToken,
		"token_type":    token.TokenType,
		// ANSIC carries no time zone and time.Parse reads it back as UTC, so
		// normalize to UTC first; otherwise the expiry shifts by the server's
		// UTC offset on the round trip.
		"expiry": token.Expiry.UTC().Format(time.ANSIC),
		// jwt-go only enforces "exp" when it is numeric (Unix seconds). A
		// time.Time JSON-encodes as a string and silently disables the check.
		"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
		SignedString([]byte(os.Getenv("JWT_SECRET")))
	if err != nil {
		panic(err)
	}
	return signed
}
// createOAuth rebuilds an oauth2.Token from the MapClaims of a JWT produced
// by createJWT. The "expiry" claim is parsed as an ANSIC timestamp; having no
// zone, it is interpreted as UTC.
//
// Missing or non-string claims become empty strings instead of panicking. If
// the expiry cannot be parsed, the error is logged and the token is given an
// expiry in the past. oauth2 then treats it as invalid (and a token source may
// try to refresh it) instead of the whole process exiting.
func createOAuth(jwtToken *jwt.Token) *oauth2.Token {
	// A failed assertion leaves claims nil; reading a nil map is safe.
	claims, _ := jwtToken.Claims.(jwt.MapClaims)
	claim := func(key string) string {
		s, _ := claims[key].(string)
		return s
	}

	token := &oauth2.Token{
		AccessToken:  claim("access_token"),
		RefreshToken: claim("refresh_token"),
		TokenType:    claim("token_type"),
	}
	expiry, err := time.Parse(time.ANSIC, claim("expiry"))
	if err != nil {
		log.Printf("could not unmarshal token expiry: %v", err)
		// Non-zero past instant: a zero Expiry would mean "never expires".
		expiry = time.Unix(0, 0)
	}
	token.Expiry = expiry
	return token
}