535 lines
13 KiB
Go
535 lines
13 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/writekitapp/writekit/internal/db"
|
|
)
|
|
|
|
// Handler implements the OAuth login flow and the session endpoints listed in
// Routes. Construct it with NewHandler.
type Handler struct {
	// database stores users, provider identities and sessions.
	database *db.DB
	// sessionSecret is the HMAC key used to sign the OAuth state parameter.
	sessionSecret []byte
	// baseURL is this service's public origin. The OAuth redirect URI is
	// baseURL + "/auth/callback", and session cookies are marked Secure
	// unless it contains "localhost".
	baseURL string
	// providers holds the configured OAuth providers keyed by short name
	// ("google", "github", "discord").
	providers map[string]provider
}
|
|
|
|
// provider describes one OAuth 2.0 authorization-code provider: the
// application credentials plus the endpoints the login flow talks to.
type provider struct {
	// Name is the human-readable label for the provider.
	Name string

	// ClientID and ClientSecret are the OAuth application credentials.
	ClientID, ClientSecret string

	// AuthURL is the consent page the browser is sent to, TokenURL is where
	// the authorization code is exchanged for an access token, and
	// UserInfoURL returns the signed-in account's profile.
	AuthURL, TokenURL, UserInfoURL string

	// Scopes are requested space-separated in the authorization URL.
	Scopes []string
}
|
|
|
|
// oauthToken holds the part of a provider's token-endpoint JSON response that
// this package uses: the access token, later sent as a Bearer credential when
// fetching the user's profile.
type oauthToken struct {
	AccessToken string `json:"access_token"`
}
|
|
|
|
// oauthState is the data round-tripped through the OAuth provider in the
// "state" parameter. encodeState serialises it as JSON and signs it with
// HMAC-SHA256. The one/two-letter JSON keys presumably keep the resulting URL
// parameter short.
type oauthState struct {
	// Provider is the provider key (e.g. "google") the flow was started with.
	Provider string `json:"p"`

	// TenantID is copied from initiate's "tenant" query parameter.
	// NOTE(review): callback never reads it — confirm whether it is needed.
	TenantID string `json:"t,omitempty"`

	// Redirect is the path to land on after a cookie-based login; callback
	// only honours it when it is a same-origin path.
	Redirect string `json:"r,omitempty"`

	// Callback, when set, is the URL that receives the session token as a
	// query parameter instead of the session cookie being set.
	Callback string `json:"c,omitempty"`

	// Timestamp is the Unix time (seconds) when the flow started; callback
	// rejects states older than 600 seconds.
	Timestamp int64 `json:"ts"`
}
|
|
|
|
// userInfo is the provider-independent subset of an OAuth profile extracted by
// getUserInfo: the provider's account ID, the email address, a display name,
// and an avatar URL. Values the provider does not supply stay empty.
type userInfo struct {
	ID, Email, Name, AvatarURL string
}
|
|
|
|
// NewHandler builds a Handler from environment variables:
//
//   - BASE_URL: public origin used to build the OAuth redirect URI
//     (default "https://writekit.dev"; a trailing slash is removed so the
//     redirect URI never contains "//auth/callback").
//   - SESSION_SECRET: HMAC key for the OAuth state parameter. When unset, a
//     hard-coded development secret is used and a warning is logged, since
//     anyone reading the source can sign states with it.
//   - GOOGLE_, GITHUB_ and DISCORD_ CLIENT_ID / CLIENT_SECRET: a provider is
//     enabled only when its client ID is set; a missing secret is warned about.
func NewHandler(database *db.DB) *Handler {
	baseURL := strings.TrimSuffix(os.Getenv("BASE_URL"), "/")
	if baseURL == "" {
		baseURL = "https://writekit.dev"
	}

	secret := os.Getenv("SESSION_SECRET")
	if secret == "" {
		slog.Warn("SESSION_SECRET is not set; using insecure development secret")
		secret = "dev-secret-change-in-production"
	}

	h := &Handler{
		database:      database,
		sessionSecret: []byte(secret),
		baseURL:       baseURL,
		providers:     make(map[string]provider),
	}

	// Static endpoint configuration per provider; credentials are filled in
	// from <envPrefix>_CLIENT_ID / <envPrefix>_CLIENT_SECRET below.
	known := []struct {
		key, envPrefix string
		cfg            provider
	}{
		{"google", "GOOGLE", provider{
			Name:        "Google",
			AuthURL:     "https://accounts.google.com/o/oauth2/v2/auth",
			TokenURL:    "https://oauth2.googleapis.com/token",
			UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
			Scopes:      []string{"email", "profile"},
		}},
		{"github", "GITHUB", provider{
			Name:        "GitHub",
			AuthURL:     "https://github.com/login/oauth/authorize",
			TokenURL:    "https://github.com/login/oauth/access_token",
			UserInfoURL: "https://api.github.com/user",
			Scopes:      []string{"user:email"},
		}},
		{"discord", "DISCORD", provider{
			Name:        "Discord",
			AuthURL:     "https://discord.com/api/oauth2/authorize",
			TokenURL:    "https://discord.com/api/oauth2/token",
			UserInfoURL: "https://discord.com/api/users/@me",
			Scopes:      []string{"identify", "email"},
		}},
	}

	for _, k := range known {
		id := os.Getenv(k.envPrefix + "_CLIENT_ID")
		if id == "" {
			continue
		}
		p := k.cfg
		p.ClientID = id
		p.ClientSecret = os.Getenv(k.envPrefix + "_CLIENT_SECRET")
		if p.ClientSecret == "" {
			slog.Warn("OAuth client secret is not set", "provider", k.key)
		}
		h.providers[k.key] = p
	}

	return h
}
|
|
|
|
func (h *Handler) Routes() chi.Router {
|
|
r := chi.NewRouter()
|
|
r.Get("/google", h.initiate)
|
|
r.Get("/github", h.initiate)
|
|
r.Get("/discord", h.initiate)
|
|
r.Get("/callback", h.callback)
|
|
r.Get("/validate", h.validate)
|
|
r.Get("/user", h.user)
|
|
r.Get("/providers", h.listProviders)
|
|
r.Post("/logout", h.logout)
|
|
return r
|
|
}
|
|
|
|
// initiate starts the OAuth authorization-code flow for the provider named by
// the last segment of the request path (e.g. ".../github" -> "github").
// Providers that are not configured get 400.
//
// The optional "tenant", "redirect" and "callback" query parameters are
// carried through the provider round-trip inside an HMAC-signed state value,
// together with the current time, for callback to act on. The browser is then
// sent to the provider's consent page with a 307 redirect.
func (h *Handler) initiate(w http.ResponseWriter, r *http.Request) {
	// Taking the last path segment, rather than stripping a fixed "/auth/"
	// prefix, keeps the lookup independent of where the router is mounted.
	providerName := r.URL.Path[strings.LastIndexByte(r.URL.Path, '/')+1:]
	p, ok := h.providers[providerName]
	if !ok {
		http.Error(w, "unknown provider", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()
	stateStr, err := h.encodeState(oauthState{
		Provider:  providerName,
		TenantID:  query.Get("tenant"),
		Redirect:  query.Get("redirect"),
		Callback:  query.Get("callback"),
		Timestamp: time.Now().Unix(),
	})
	if err != nil {
		slog.Error("encode state", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	params := url.Values{}
	params.Set("client_id", p.ClientID)
	params.Set("redirect_uri", h.baseURL+"/auth/callback")
	params.Set("response_type", "code")
	params.Set("scope", strings.Join(p.Scopes, " "))
	params.Set("state", stateStr)
	if providerName == "discord" {
		// Discord-specific: always show the consent prompt.
		params.Set("prompt", "consent")
	}

	http.Redirect(w, r, p.AuthURL+"?"+params.Encode(), http.StatusTemporaryRedirect)
}
|
|
|
|
func (h *Handler) callback(w http.ResponseWriter, r *http.Request) {
|
|
code := r.URL.Query().Get("code")
|
|
stateStr := r.URL.Query().Get("state")
|
|
|
|
if code == "" || stateStr == "" {
|
|
http.Error(w, "missing code or state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
state, err := h.decodeState(stateStr)
|
|
if err != nil {
|
|
slog.Error("decode state", "error", err)
|
|
http.Error(w, "invalid state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if time.Now().Unix()-state.Timestamp > 600 {
|
|
http.Error(w, "state expired", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
p, ok := h.providers[state.Provider]
|
|
if !ok {
|
|
http.Error(w, "unknown provider", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
token, err := h.exchangeCode(r.Context(), p, code)
|
|
if err != nil {
|
|
slog.Error("exchange code", "error", err)
|
|
http.Error(w, "auth failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
info, err := h.getUserInfo(r.Context(), p, state.Provider, token)
|
|
if err != nil {
|
|
slog.Error("get user info", "error", err)
|
|
http.Error(w, "failed to get user info", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
user, err := h.findOrCreateUser(r.Context(), state.Provider, info)
|
|
if err != nil {
|
|
slog.Error("find or create user", "error", err)
|
|
http.Error(w, "failed to create user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
session, err := h.database.CreateSession(r.Context(), user.ID)
|
|
if err != nil {
|
|
slog.Error("create session", "error", err)
|
|
http.Error(w, "failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if state.Callback != "" {
|
|
callbackURL := state.Callback
|
|
if strings.Contains(callbackURL, "?") {
|
|
callbackURL += "&token=" + session.Token
|
|
} else {
|
|
callbackURL += "?token=" + session.Token
|
|
}
|
|
http.Redirect(w, r, callbackURL, http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
|
|
redirect := state.Redirect
|
|
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
|
redirect = "/"
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "writekit_session",
|
|
Value: session.Token,
|
|
Path: "/",
|
|
Expires: session.ExpiresAt,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Secure: !strings.Contains(h.baseURL, "localhost"),
|
|
})
|
|
|
|
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
|
}
|
|
|
|
// validate checks the session token given in the "token" query parameter.
// It responds 200 with an empty body for a valid session, 401 when
// ValidateSession finds none, 400 when the parameter is missing, and 500 when
// the lookup itself fails.
func (h *Handler) validate(w http.ResponseWriter, r *http.Request) {
	token := r.URL.Query().Get("token")
	if token == "" {
		http.Error(w, "missing token", http.StatusBadRequest)
		return
	}

	session, err := h.database.ValidateSession(r.Context(), token)
	switch {
	case err != nil:
		http.Error(w, "internal error", http.StatusInternalServerError)
	case session == nil:
		http.Error(w, "invalid session", http.StatusUnauthorized)
	default:
		w.WriteHeader(http.StatusOK)
	}
}
|
|
|
|
// user returns the profile of the user owning the session token given in the
// "token" query parameter, as JSON {"id", "email", "name", "avatar_url"}.
//
// Responses: 400 when the token is missing, 401 when no session matches it,
// 404 when the session's user no longer exists, and 500 on database errors.
// Database errors are reported as 500 (matching validate) rather than
// masquerading as an invalid session or a missing user.
func (h *Handler) user(w http.ResponseWriter, r *http.Request) {
	token := r.URL.Query().Get("token")
	if token == "" {
		http.Error(w, "missing token", http.StatusBadRequest)
		return
	}

	session, err := h.database.ValidateSession(r.Context(), token)
	if err != nil {
		slog.Error("validate session", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if session == nil {
		http.Error(w, "invalid session", http.StatusUnauthorized)
		return
	}

	user, err := h.database.GetUserByID(r.Context(), session.UserID)
	if err != nil {
		slog.Error("get user by id", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if user == nil {
		http.Error(w, "user not found", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]any{
		"id":         user.ID,
		"email":      user.Email,
		"name":       user.Name,
		"avatar_url": user.AvatarURL,
	}); err != nil {
		// Headers are already sent; all that is left is to record it.
		slog.Error("encode user", "error", err)
	}
}
|
|
|
|
// listProviders responds with JSON {"providers": [{"id": ..., "name": ...}]}
// naming each configured provider in a fixed order (Google, GitHub, Discord).
// The list encodes as [] rather than null when nothing is configured.
func (h *Handler) listProviders(w http.ResponseWriter, r *http.Request) {
	known := [...]struct{ id, name string }{
		{"google", "Google"},
		{"github", "GitHub"},
		{"discord", "Discord"},
	}

	providers := make([]map[string]string, 0, len(known))
	for _, k := range known {
		if _, enabled := h.providers[k.id]; !enabled {
			continue
		}
		providers = append(providers, map[string]string{"id": k.id, "name": k.name})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]any{"providers": providers})
}
|
|
|
|
// logout deletes the session named by the writekit_session cookie (when the
// cookie is present and non-empty) and tells the browser to drop the cookie.
// Deletion is best-effort: whatever DeleteSession returns is discarded, and the
// response is always 200.
func (h *Handler) logout(w http.ResponseWriter, r *http.Request) {
	if c, err := r.Cookie("writekit_session"); err == nil && c.Value != "" {
		h.database.DeleteSession(r.Context(), c.Value)
	}

	// MaxAge < 0 makes the browser delete the cookie immediately.
	expired := &http.Cookie{
		Name:     "writekit_session",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   !strings.Contains(h.baseURL, "localhost"),
	}
	http.SetCookie(w, expired)

	w.WriteHeader(http.StatusOK)
}
|
|
|
|
// exchangeCode trades an authorization code for an access token by POSTing a
// form-encoded authorization_code grant to p.TokenURL.
//
// It returns an error when the request fails or does not complete within 10
// seconds, when the response is not 200, when the body is not JSON, or when
// it carries no access_token. The last case matters because GitHub reports
// errors such as bad_verification_code in a 200 response.
func (h *Handler) exchangeCode(ctx context.Context, p provider, code string) (*oauthToken, error) {
	// http.DefaultClient has no timeout of its own; bound the provider call.
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	form := url.Values{}
	form.Set("client_id", p.ClientID)
	form.Set("client_secret", p.ClientSecret)
	form.Set("code", code)
	form.Set("redirect_uri", h.baseURL+"/auth/callback")
	form.Set("grant_type", "authorization_code")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.TokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("build token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token request: %w", err)
	}
	defer resp.Body.Close()

	// Cap the read so a misbehaving endpoint cannot make us buffer unbounded data.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("read token response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed: %s", body)
	}

	var token oauthToken
	if err := json.Unmarshal(body, &token); err != nil {
		return nil, fmt.Errorf("decode token response: %w", err)
	}
	if token.AccessToken == "" {
		return nil, fmt.Errorf("token exchange returned no access token: %s", body)
	}
	return &token, nil
}
|
|
|
|
// getUserInfo fetches the signed-in account's profile from p.UserInfoURL with
// the access token and maps the provider-specific JSON (selected by
// providerName) onto a userInfo.
//
// For GitHub, when the profile has no public email, the account's verified
// addresses are queried via getGitHubEmail. An error is returned in these cases:
//   - the request fails or takes longer than 10 seconds, or the response is
//     not 200;
//   - the provider explicitly marks the email as unverified;
//   - no account ID or email can be determined.
//
// The verified check matters because findOrCreateUser links accounts by email:
// an unverified address would let anyone claim an existing user.
func (h *Handler) getUserInfo(ctx context.Context, p provider, providerName string, token *oauthToken) (*userInfo, error) {
	// http.DefaultClient has no timeout; this bounds both profile requests.
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.UserInfoURL, nil)
	if err != nil {
		return nil, fmt.Errorf("build user info request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("user info request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("read user info response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("user info failed: %s", body)
	}

	var raw map[string]any
	if err := json.Unmarshal(body, &raw); err != nil {
		return nil, fmt.Errorf("decode user info: %w", err)
	}

	info := &userInfo{}

	switch providerName {
	case "google":
		info.ID = getString(raw, "id")
		info.Email = getString(raw, "email")
		info.Name = getString(raw, "name")
		info.AvatarURL = getString(raw, "picture")
		if verified, ok := raw["verified_email"].(bool); ok && !verified {
			return nil, fmt.Errorf("google email is not verified")
		}

	case "github":
		// encoding/json decodes JSON numbers as float64, and %v prints float64
		// values of 1e6 and above in exponent form ("1.2345678e+07"), so format
		// the numeric ID as plain digits instead.
		// NOTE(review): identities stored earlier hold the exponent form; those
		// users will miss the identity lookup in findOrCreateUser and fall back
		// to matching by email — consider a one-off data migration.
		switch id := raw["id"].(type) {
		case float64:
			info.ID = fmt.Sprintf("%.0f", id)
		case string:
			info.ID = id
		}
		info.Email = getString(raw, "email")
		info.Name = getString(raw, "name")
		if info.Name == "" {
			info.Name = getString(raw, "login")
		}
		info.AvatarURL = getString(raw, "avatar_url")

		if info.Email == "" {
			email, err := h.getGitHubEmail(ctx, token)
			if err != nil {
				return nil, fmt.Errorf("fetch github emails: %w", err)
			}
			info.Email = email
		}

	case "discord":
		info.ID = getString(raw, "id")
		info.Email = getString(raw, "email")
		info.Name = getString(raw, "global_name")
		if info.Name == "" {
			info.Name = getString(raw, "username")
		}
		if avatar := getString(raw, "avatar"); avatar != "" {
			info.AvatarURL = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", info.ID, avatar)
		}
		if verified, ok := raw["verified"].(bool); ok && !verified {
			return nil, fmt.Errorf("discord email is not verified")
		}
	}

	if info.ID == "" {
		return nil, fmt.Errorf("no account id from provider")
	}
	if info.Email == "" {
		return nil, fmt.Errorf("no email from provider")
	}
	return info, nil
}
|
|
|
|
// getGitHubEmail lists the authenticated GitHub user's email addresses and
// returns the primary verified one, otherwise the first verified one, and ""
// with a nil error when none is verified.
//
// A non-200 response is returned as an error that includes the response body.
// Without that check, an API error object would only surface as a confusing
// "cannot unmarshal object" decode error.
func (h *Handler) getGitHubEmail(ctx context.Context, token *oauthToken) (string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user/emails", nil)
	if err != nil {
		return "", fmt.Errorf("build github emails request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("github emails request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("read github emails: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("github emails failed: %s", body)
	}

	var emails []struct {
		Email    string `json:"email"`
		Primary  bool   `json:"primary"`
		Verified bool   `json:"verified"`
	}
	if err := json.Unmarshal(body, &emails); err != nil {
		return "", fmt.Errorf("decode github emails: %w", err)
	}

	// A single pass: return the primary verified address as soon as it is
	// seen, and remember the first verified address as the fallback.
	fallback := ""
	for _, e := range emails {
		if !e.Verified {
			continue
		}
		if e.Primary {
			return e.Email, nil
		}
		if fallback == "" {
			fallback = e.Email
		}
	}
	return fallback, nil
}
|
|
|
|
// findOrCreateUser resolves the local user for a provider login by trying, in
// order:
//  1. an existing identity for (providerName, info.ID);
//  2. an existing user with info.Email, which this identity is then attached to;
//  3. a newly created user with the provider's email, name and avatar, which
//     this identity is then attached to.
//
// NOTE(review): whatever AddUserIdentity returns is discarded, so a failed link
// goes unnoticed; presumably the next login then takes the email path again —
// confirm this is intended. Matching by email also assumes the provider
// verified the address.
func (h *Handler) findOrCreateUser(ctx context.Context, providerName string, info *userInfo) (*db.User, error) {
	byIdentity, err := h.database.GetUserByIdentity(ctx, providerName, info.ID)
	switch {
	case err != nil:
		return nil, err
	case byIdentity != nil:
		return byIdentity, nil
	}

	target, err := h.database.GetUserByEmail(ctx, info.Email)
	if err != nil {
		return nil, err
	}

	if target == nil {
		created, err := h.database.CreateUser(ctx, info.Email, info.Name, info.AvatarURL)
		if err != nil {
			return nil, err
		}
		target = created
	}

	h.database.AddUserIdentity(ctx, target.ID, providerName, info.ID, info.Email)
	return target, nil
}
|
|
|
|
// encodeState serialises state to JSON and appends an HMAC-SHA256 tag
// computed with the session secret. It returns the padded URL-safe base64
// encoding of JSON||tag; decodeState reverses this.
func (h *Handler) encodeState(state oauthState) (string, error) {
	payload, err := json.Marshal(state)
	if err != nil {
		return "", err
	}

	signer := hmac.New(sha256.New, h.sessionSecret)
	signer.Write(payload)
	// Sum appends the tag to its argument, producing payload||tag directly.
	signed := signer.Sum(payload)

	return base64.URLEncoding.EncodeToString(signed), nil
}
|
|
|
|
// decodeState reverses encodeState. It base64-decodes s, splits off the
// trailing HMAC-SHA256 tag, verifies it in constant time against the session
// secret, and only then unmarshals the JSON payload. It fails if s is not valid
// base64, is too short to hold a tag, carries a bad signature, or holds
// malformed JSON.
func (h *Handler) decodeState(s string) (*oauthState, error) {
	raw, err := base64.URLEncoding.DecodeString(s)
	if err != nil {
		return nil, err
	}

	const tagLen = sha256.Size
	if len(raw) < tagLen {
		return nil, fmt.Errorf("invalid state")
	}
	body, tag := raw[:len(raw)-tagLen], raw[len(raw)-tagLen:]

	verifier := hmac.New(sha256.New, h.sessionSecret)
	verifier.Write(body)
	if expected := verifier.Sum(nil); !hmac.Equal(tag, expected) {
		return nil, fmt.Errorf("invalid signature")
	}

	state := new(oauthState)
	if err := json.Unmarshal(body, state); err != nil {
		return nil, err
	}
	return state, nil
}
|
|
|
|
// getString returns m[key] when it is present and holds a string. Otherwise —
// missing key, nil map, or a non-string value such as a JSON number — it
// returns "".
func getString(m map[string]any, key string) string {
	// A missing key yields a nil interface, for which the assertion simply fails.
	s, _ := m[key].(string)
	return s
}
|