package db import ( "context" "crypto/rand" "encoding/base64" "errors" "fmt" "log" "net/http" "os" "time" jwt "github.com/dgrijalva/jwt-go" "golang.org/x/oauth2" ) var conf = &oauth2.Config{ ClientID: os.Getenv("OAUTH_ID"), ClientSecret: os.Getenv("OAUTH_SECRET"), Scopes: []string{"authorization_code"}, Endpoint: oauth2.Endpoint{ AuthURL: os.Getenv("GITEA_HOST") + "/login/oauth/authorize", TokenURL: os.Getenv("GITEA_HOST") + "/login/oauth/access_token", }, RedirectURL: os.Getenv("REDIRECT_URL"), } var ( validStates = make(map[string]string) base64Encoding = base64.NewEncoding("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890+-") ) // GenerateNewToken creates a new redirect URL for users. func GenerateNewToken(redirect string) string { state := randomToken() validStates[state] = redirect go func() { time.Sleep(time.Minute * 2) delete(validStates, state) }() return conf.AuthCodeURL(state) } // ConsumeOAuthCode takes the oauth response and returns a JWT Token. func ConsumeOAuthCode(code, state string) (string, string, error) { redirect, ok := validStates[state] if !ok { return "", "", errors.New("invalid state provided. please try again") } log.Printf("DEBUG :: Consuming OAuth code: code: %s, state: %s | redirect: %s\n", code, state, redirect) token, err := conf.Exchange(context.Background(), code) if err != nil { log.Printf("DEBUG :: Could not exchange: %v\n", err) return "", "", err } delete(validStates, state) jwtToken := createJWT(token) return jwtToken, redirect, nil } // Authorize checks if the provided JWT token and returns a HTTP Cliennt. func Authorize(jwtToken string) (*http.Client, string, error) { token, err := jwt.Parse(jwtToken, func(t *jwt.Token) (interface{}, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("invalid signing method found, expected HMAC") } return []byte(os.Getenv("JWT_SECRET")), nil }) if err != nil { return nil, "", err } if !token.Valid { return nil, "", errors.New("token invalid") } tokenSource := conf.TokenSource(context.Background(), createOAuth(token)) oauthToken, err := tokenSource.Token() if err != nil { fmt.Printf("Could not regenerate OAuth token: %s", err) return nil, "", err } client := conf.Client(context.Background(), oauthToken) return client, createJWT(oauthToken), err } func randomToken() string { session := "" ok := true for ok { token := make([]byte, 18) _, err := rand.Read(token) if err != nil { panic(err) } session = base64Encoding.EncodeToString(token) _, ok = validStates[session] } return session } func createJWT(token *oauth2.Token) string { jwtToken := jwt.New(jwt.SigningMethodHS256) claims := jwtToken.Claims.(jwt.MapClaims) claims["access_token"] = token.AccessToken claims["refresh_token"] = token.RefreshToken claims["token_type"] = token.TokenType claims["expiry"] = token.Expiry.Format(time.ANSIC) claims["exp"] = time.Now().Add(time.Hour * 24 * 7) str, err := jwtToken.SignedString([]byte(os.Getenv("JWT_SECRET"))) if err != nil { panic(err) } return str } func createOAuth(jwtToken *jwt.Token) *oauth2.Token { token := new(oauth2.Token) claims := jwtToken.Claims.(jwt.MapClaims) var err error token.AccessToken = claims["access_token"].(string) token.RefreshToken = claims["refresh_token"].(string) token.TokenType = claims["token_type"].(string) token.Expiry, err = time.Parse(time.ANSIC, claims["expiry"].(string)) if err != nil { log.Fatalf("could not unmarshal token expiry: %v", err) } return token }