package api

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/hhhapz/hackathon/models"
	"golang.org/x/oauth2"
)

// AuthService implements the authentication endpoints: starting the OAuth
// flow, handling its callback, and revoking issued session tokens.
type AuthService struct {
	oauthStore OAuthStore
	userStore  UserStore
}

// GenOauth creates a consent page for params.Callback and responds with it as
// JSON. In production, only the official frontend URL is accepted as the
// callback; any other value is rejected with 422.
func (as *AuthService) GenOauth(w http.ResponseWriter, r *http.Request, params GenOauthParams) {
	// We enforce callback field in production.
	if inProd && params.Callback != "https://hackathon.teamortix.com" {
		serverError(w, http.StatusUnprocessableEntity, "invalid callback provided")
		// Without this return, the rejected callback would still be stored and a
		// consent page written after the error response.
		return
	}

	page := as.oauthStore.Create(params.Callback)

	// Headers must be set before WriteHeader; later changes are ignored.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(ConsentPage{page}); err != nil {
		// The status is already sent, so all that is left to do is record it.
		log.Printf("error: could not encode consent page: %v", err)
	}
}

// couldNotValidate is the user-facing message returned when the OAuth code
// exchange fails.
var couldNotValidate = "Unable to authorize. Please try again.\nIf this issue persists, please contact an admin."

// AuthorizeCallback handles the redirect back from the OAuth provider.
//
// It validates params.State, exchanges params.Code for a token, and resolves
// (creating if needed) the matching user via consumeToken. On success it sets
// the "token" session cookie and redirects to <callback>/success. If
// consumeToken fails, it redirects to <callback>/fail?reason=email for a
// rejected email, or reason=serverError for anything else.
func (as *AuthService) AuthorizeCallback(w http.ResponseWriter, r *http.Request, params AuthorizeCallbackParams) {
	callback, ok := as.oauthStore.Validate(params.State)
	if !ok {
		serverError(w, http.StatusUnprocessableEntity, "Invalid state provided")
		return
	}

	token, err := as.oauthStore.Exchange(r.Context(), params.Code)
	if err != nil {
		log.Printf("error: could not perform token exchange: %v", err)
		serverError(w, http.StatusInternalServerError, couldNotValidate)
		return
	}

	accessToken, err := as.consumeToken(r.Context(), token)
	if err != nil {
		log.Printf("error: could not consume oauth token: %v", err)
		reason := "serverError"
		// NOTE(review): this errors.As target requires the value type parseError
		// (not *parseError) to implement error; confirm against newParseError.
		if errors.As(err, &parseError{}) {
			reason = "email"
		}
		uri := fmt.Sprintf("%s/fail?reason=%s", callback, reason)
		http.Redirect(w, r, uri, http.StatusFound)
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     "token",
		Value:    accessToken,
		Path:     "/",
		Expires:  time.Now().Add(time.Hour * 24 * 14), // 2 weeks
		SameSite: http.SameSiteNoneMode,                // handled via CORS
		Secure:   inProd,
		// NOTE(review): HttpOnly is not set, so the token is readable from
		// JavaScript; confirm whether the frontend relies on that before enabling.
	})
	http.Redirect(w, r, callback+"/success", http.StatusFound)
}

// userinfoEndpoint is Google's OAuth2 userinfo endpoint. The access token is
// sent in the Authorization header (not the query string) so it does not end
// up in URLs or intermediary request logs.
const userinfoEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo"

// consumeToken fetches the user's profile from userinfoEndpoint using token
// and returns a new session token for that user.
//
// A user that already exists (looked up by email) just receives a new token.
// Otherwise the email must pass parseEmail. If it does not, a parseError is
// returned. If it does, the user is stored as a student or teacher accordingly
// and given a token.
func (as *AuthService) consumeToken(ctx context.Context, token *oauth2.Token) (string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, userinfoEndpoint, nil)
	if err != nil {
		return "", fmt.Errorf("could not build userinfo request: %w", err)
	}
	token.SetAuthHeader(req)

	client := http.Client{Timeout: 5 * time.Second}
	res, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("could not get userinfo: %w", err)
	}
	defer res.Body.Close()

	// Only a small JSON document is expected; bound the read so a misbehaving
	// upstream cannot make us buffer an arbitrarily large body.
	buf, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("could not read request body: %w", err)
	}
	// A non-200 reply (e.g. an expired token) would otherwise decode into an
	// empty user and be misreported as an invalid email.
	if res.StatusCode != http.StatusOK {
		return "", fmt.Errorf("could not get userinfo: unexpected status %s", res.Status)
	}

	user := &models.User{
		CreatedAt: models.NewTime(time.Now()),
	}
	if err := json.Unmarshal(buf, user); err != nil {
		return "", fmt.Errorf("could not decode userinfo json: %w", err)
	}

	// Returning users only need a fresh token.
	// NOTE(review): any lookup error (not only "not found") falls through to
	// the creation path below; confirm UserStore.User's error contract.
	if u, err := as.userStore.User(ctx, user.Email); err == nil {
		tk, err := as.userStore.CreateToken(ctx, u)
		if err != nil {
			return "", fmt.Errorf("could not create user token: %w", err)
		}
		return tk.Token, nil
	}

	student, valid := parseEmail(user.Email)
	if !valid {
		return "", newParseError("email", user.Email, "invalid email domain")
	}
	user.Teacher = !student

	if err := as.userStore.UpdateUser(ctx, user); err != nil {
		return "", fmt.Errorf("could not create user: %w", err)
	}

	tk, err := as.userStore.CreateToken(ctx, user)
	if err != nil {
		return "", fmt.Errorf("could not create user token: %w", err)
	}
	return tk.Token, nil
}

// DeleteToken revokes params.Token. If params.All is true, it instead revokes
// every token of the user owning params.Token. Responds 401 if that token
// cannot be resolved to a user, 500 if revocation fails, and 204 otherwise.
func (as *AuthService) DeleteToken(w http.ResponseWriter, r *http.Request, params DeleteTokenParams) {
	all := params.All != nil && *params.All

	var err error
	if all {
		// A separate name for the lookup error keeps the revoke error below
		// assigned to the outer err. Declaring err with := here would shadow it,
		// and failures would be silently reported as success.
		user, lookupErr := as.userStore.UserByToken(r.Context(), params.Token)
		if lookupErr != nil {
			serverError(w, http.StatusUnauthorized, "Invalid token provided.")
			return
		}
		_, err = as.userStore.RevokeUserTokens(r.Context(), user)
	} else {
		err = as.userStore.RevokeToken(r.Context(), params.Token)
	}

	if err != nil {
		// params is deliberately not logged: it contains the bearer token.
		log.Printf("Could not revoke user token (all=%t): %v", all, err)
		serverError(w, http.StatusInternalServerError, "Could not revoke token.")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// domain is the email domain that parseEmail accepts as internal.
const domain = "jisedu.or.id"

// parseEmail parses the provided email and returns two bools:
// 1. if the email belongs to a student: it conforms to `\d+@[domain]`, and // 2. if it's a valid internal email (ends with `domain`). func parseEmail(email string) (bool, bool) { parts := strings.Split(email, "@") if len(parts) != 2 || parts[1] != domain { return false, false } _, err := strconv.Atoi(parts[1]) return err == nil, true }