This commit is contained in:
Josh 2026-01-09 00:16:46 +02:00
commit d69342b2e9
160 changed files with 28681 additions and 0 deletions

View file

@ -0,0 +1,54 @@
package auth
import (
"context"
"net/http"
"strings"
"github.com/writekitapp/writekit/internal/db"
)
type ctxKey string
const userIDKey ctxKey = "userID"
func SessionMiddleware(database *db.DB) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := extractToken(r)
if token == "" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
session, err := database.ValidateSession(r.Context(), token)
if err != nil || session == nil {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), userIDKey, session.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func GetUserID(r *http.Request) string {
if id, ok := r.Context().Value(userIDKey).(string); ok {
return id
}
return ""
}
func extractToken(r *http.Request) string {
if cookie, err := r.Cookie("writekit_session"); err == nil && cookie.Value != "" {
return cookie.Value
}
auth := r.Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
}
return r.URL.Query().Get("token")
}

535
internal/auth/oauth.go Normal file
View file

@ -0,0 +1,535 @@
package auth
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/writekitapp/writekit/internal/db"
)
type Handler struct {
database *db.DB
sessionSecret []byte
baseURL string
providers map[string]provider
}
type provider struct {
Name string
ClientID string
ClientSecret string
AuthURL string
TokenURL string
UserInfoURL string
Scopes []string
}
type oauthToken struct {
AccessToken string `json:"access_token"`
}
type oauthState struct {
Provider string `json:"p"`
TenantID string `json:"t,omitempty"`
Redirect string `json:"r,omitempty"`
Callback string `json:"c,omitempty"`
Timestamp int64 `json:"ts"`
}
type userInfo struct {
ID string
Email string
Name string
AvatarURL string
}
func NewHandler(database *db.DB) *Handler {
baseURL := os.Getenv("BASE_URL")
if baseURL == "" {
baseURL = "https://writekit.dev"
}
secret := os.Getenv("SESSION_SECRET")
if secret == "" {
secret = "dev-secret-change-in-production"
}
h := &Handler{
database: database,
sessionSecret: []byte(secret),
baseURL: baseURL,
providers: make(map[string]provider),
}
if id := os.Getenv("GOOGLE_CLIENT_ID"); id != "" {
h.providers["google"] = provider{
Name: "Google",
ClientID: id,
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
Scopes: []string{"email", "profile"},
}
}
if id := os.Getenv("GITHUB_CLIENT_ID"); id != "" {
h.providers["github"] = provider{
Name: "GitHub",
ClientID: id,
ClientSecret: os.Getenv("GITHUB_CLIENT_SECRET"),
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
Scopes: []string{"user:email"},
}
}
if id := os.Getenv("DISCORD_CLIENT_ID"); id != "" {
h.providers["discord"] = provider{
Name: "Discord",
ClientID: id,
ClientSecret: os.Getenv("DISCORD_CLIENT_SECRET"),
AuthURL: "https://discord.com/api/oauth2/authorize",
TokenURL: "https://discord.com/api/oauth2/token",
UserInfoURL: "https://discord.com/api/users/@me",
Scopes: []string{"identify", "email"},
}
}
return h
}
func (h *Handler) Routes() chi.Router {
r := chi.NewRouter()
r.Get("/google", h.initiate)
r.Get("/github", h.initiate)
r.Get("/discord", h.initiate)
r.Get("/callback", h.callback)
r.Get("/validate", h.validate)
r.Get("/user", h.user)
r.Get("/providers", h.listProviders)
r.Post("/logout", h.logout)
return r
}
func (h *Handler) initiate(w http.ResponseWriter, r *http.Request) {
providerName := strings.TrimPrefix(r.URL.Path, "/auth/")
if _, ok := h.providers[providerName]; !ok {
http.Error(w, "unknown provider", http.StatusBadRequest)
return
}
p := h.providers[providerName]
state := oauthState{
Provider: providerName,
TenantID: r.URL.Query().Get("tenant"),
Redirect: r.URL.Query().Get("redirect"),
Callback: r.URL.Query().Get("callback"),
Timestamp: time.Now().Unix(),
}
stateStr, err := h.encodeState(state)
if err != nil {
slog.Error("encode state", "error", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
authURL := fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s",
p.AuthURL,
url.QueryEscape(p.ClientID),
url.QueryEscape(h.baseURL+"/auth/callback"),
url.QueryEscape(strings.Join(p.Scopes, " ")),
url.QueryEscape(stateStr),
)
if providerName == "discord" {
authURL += "&prompt=consent"
}
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
}
func (h *Handler) callback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
stateStr := r.URL.Query().Get("state")
if code == "" || stateStr == "" {
http.Error(w, "missing code or state", http.StatusBadRequest)
return
}
state, err := h.decodeState(stateStr)
if err != nil {
slog.Error("decode state", "error", err)
http.Error(w, "invalid state", http.StatusBadRequest)
return
}
if time.Now().Unix()-state.Timestamp > 600 {
http.Error(w, "state expired", http.StatusBadRequest)
return
}
p, ok := h.providers[state.Provider]
if !ok {
http.Error(w, "unknown provider", http.StatusBadRequest)
return
}
token, err := h.exchangeCode(r.Context(), p, code)
if err != nil {
slog.Error("exchange code", "error", err)
http.Error(w, "auth failed", http.StatusInternalServerError)
return
}
info, err := h.getUserInfo(r.Context(), p, state.Provider, token)
if err != nil {
slog.Error("get user info", "error", err)
http.Error(w, "failed to get user info", http.StatusInternalServerError)
return
}
user, err := h.findOrCreateUser(r.Context(), state.Provider, info)
if err != nil {
slog.Error("find or create user", "error", err)
http.Error(w, "failed to create user", http.StatusInternalServerError)
return
}
session, err := h.database.CreateSession(r.Context(), user.ID)
if err != nil {
slog.Error("create session", "error", err)
http.Error(w, "failed to create session", http.StatusInternalServerError)
return
}
if state.Callback != "" {
callbackURL := state.Callback
if strings.Contains(callbackURL, "?") {
callbackURL += "&token=" + session.Token
} else {
callbackURL += "?token=" + session.Token
}
http.Redirect(w, r, callbackURL, http.StatusTemporaryRedirect)
return
}
redirect := state.Redirect
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
redirect = "/"
}
http.SetCookie(w, &http.Cookie{
Name: "writekit_session",
Value: session.Token,
Path: "/",
Expires: session.ExpiresAt,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: !strings.Contains(h.baseURL, "localhost"),
})
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
func (h *Handler) validate(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
http.Error(w, "missing token", http.StatusBadRequest)
return
}
session, err := h.database.ValidateSession(r.Context(), token)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
if session == nil {
http.Error(w, "invalid session", http.StatusUnauthorized)
return
}
w.WriteHeader(http.StatusOK)
}
func (h *Handler) user(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
http.Error(w, "missing token", http.StatusBadRequest)
return
}
session, err := h.database.ValidateSession(r.Context(), token)
if err != nil || session == nil {
http.Error(w, "invalid session", http.StatusUnauthorized)
return
}
user, err := h.database.GetUserByID(r.Context(), session.UserID)
if err != nil || user == nil {
http.Error(w, "user not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"id": user.ID,
"email": user.Email,
"name": user.Name,
"avatar_url": user.AvatarURL,
})
}
func (h *Handler) listProviders(w http.ResponseWriter, r *http.Request) {
providers := []map[string]string{}
if _, ok := h.providers["google"]; ok {
providers = append(providers, map[string]string{"id": "google", "name": "Google"})
}
if _, ok := h.providers["github"]; ok {
providers = append(providers, map[string]string{"id": "github", "name": "GitHub"})
}
if _, ok := h.providers["discord"]; ok {
providers = append(providers, map[string]string{"id": "discord", "name": "Discord"})
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{"providers": providers})
}
func (h *Handler) logout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("writekit_session")
if err == nil && cookie.Value != "" {
h.database.DeleteSession(r.Context(), cookie.Value)
}
http.SetCookie(w, &http.Cookie{
Name: "writekit_session",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: !strings.Contains(h.baseURL, "localhost"),
})
w.WriteHeader(http.StatusOK)
}
func (h *Handler) exchangeCode(ctx context.Context, p provider, code string) (*oauthToken, error) {
data := url.Values{}
data.Set("client_id", p.ClientID)
data.Set("client_secret", p.ClientSecret)
data.Set("code", code)
data.Set("redirect_uri", h.baseURL+"/auth/callback")
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, "POST", p.TokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", body)
}
var token oauthToken
if err := json.Unmarshal(body, &token); err != nil {
return nil, err
}
return &token, nil
}
func (h *Handler) getUserInfo(ctx context.Context, p provider, providerName string, token *oauthToken) (*userInfo, error) {
req, err := http.NewRequestWithContext(ctx, "GET", p.UserInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("user info failed: %s", body)
}
var raw map[string]any
if err := json.Unmarshal(body, &raw); err != nil {
return nil, err
}
info := &userInfo{}
switch providerName {
case "google":
info.ID = getString(raw, "id")
info.Email = getString(raw, "email")
info.Name = getString(raw, "name")
info.AvatarURL = getString(raw, "picture")
case "github":
info.ID = fmt.Sprintf("%v", raw["id"])
info.Email = getString(raw, "email")
info.Name = getString(raw, "name")
if info.Name == "" {
info.Name = getString(raw, "login")
}
info.AvatarURL = getString(raw, "avatar_url")
if info.Email == "" {
info.Email, _ = h.getGitHubEmail(ctx, token)
}
case "discord":
info.ID = getString(raw, "id")
info.Email = getString(raw, "email")
info.Name = getString(raw, "global_name")
if info.Name == "" {
info.Name = getString(raw, "username")
}
if avatar := getString(raw, "avatar"); avatar != "" {
info.AvatarURL = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", info.ID, avatar)
}
}
if info.Email == "" {
return nil, fmt.Errorf("no email from provider")
}
return info, nil
}
func (h *Handler) getGitHubEmail(ctx context.Context, token *oauthToken) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
var emails []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
return "", err
}
for _, e := range emails {
if e.Primary && e.Verified {
return e.Email, nil
}
}
for _, e := range emails {
if e.Verified {
return e.Email, nil
}
}
return "", nil
}
func (h *Handler) findOrCreateUser(ctx context.Context, providerName string, info *userInfo) (*db.User, error) {
user, err := h.database.GetUserByIdentity(ctx, providerName, info.ID)
if err != nil {
return nil, err
}
if user != nil {
return user, nil
}
user, err = h.database.GetUserByEmail(ctx, info.Email)
if err != nil {
return nil, err
}
if user != nil {
h.database.AddUserIdentity(ctx, user.ID, providerName, info.ID, info.Email)
return user, nil
}
user, err = h.database.CreateUser(ctx, info.Email, info.Name, info.AvatarURL)
if err != nil {
return nil, err
}
h.database.AddUserIdentity(ctx, user.ID, providerName, info.ID, info.Email)
return user, nil
}
func (h *Handler) encodeState(state oauthState) (string, error) {
data, err := json.Marshal(state)
if err != nil {
return "", err
}
mac := hmac.New(sha256.New, h.sessionSecret)
mac.Write(data)
sig := mac.Sum(nil)
return base64.URLEncoding.EncodeToString(append(data, sig...)), nil
}
func (h *Handler) decodeState(s string) (*oauthState, error) {
payload, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return nil, err
}
if len(payload) < 32 {
return nil, fmt.Errorf("invalid state")
}
data := payload[:len(payload)-32]
sig := payload[len(payload)-32:]
mac := hmac.New(sha256.New, h.sessionSecret)
mac.Write(data)
if !hmac.Equal(sig, mac.Sum(nil)) {
return nil, fmt.Errorf("invalid signature")
}
var state oauthState
if err := json.Unmarshal(data, &state); err != nil {
return nil, err
}
return &state, nil
}
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}