// Package auth implements OAuth sign-in (Google, GitHub, Discord) and the
// session endpoints that go with it.
package auth

import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/writekitapp/writekit/internal/db"
)

const (
	// oauthCallbackPath is appended to the base URL to form the redirect_uri
	// sent to every provider.
	oauthCallbackPath = "/auth/callback"

	// oauthStateMaxAgeSeconds is how long a signed state value is accepted.
	oauthStateMaxAgeSeconds = 600

	// oauthRequestTimeout bounds each outbound call to a provider.
	oauthRequestTimeout = 10 * time.Second

	// oauthMaxResponseBytes caps how much of a provider response is read.
	oauthMaxResponseBytes = 1 << 20
)

// Handler serves the OAuth and session HTTP endpoints.
type Handler struct {
	database      *db.DB
	sessionSecret []byte
	baseURL       string
	providers     map[string]provider
}

// provider holds the OAuth endpoints and credentials of one identity provider.
type provider struct {
	Name         string
	ClientID     string
	ClientSecret string
	AuthURL      string
	TokenURL     string
	UserInfoURL  string
	Scopes       []string
}

// oauthToken is the subset of a token endpoint response that is used here.
type oauthToken struct {
	AccessToken string `json:"access_token"`
}

// oauthState is the payload that is signed and round-tripped through the
// provider in the "state" parameter.
type oauthState struct {
	Provider  string `json:"p"`
	TenantID  string `json:"t,omitempty"`
	Redirect  string `json:"r,omitempty"`
	Callback  string `json:"c,omitempty"`
	Timestamp int64  `json:"ts"`
}

// userInfo is the provider-independent view of a signed-in user.
type userInfo struct {
	ID        string
	Email     string
	Name      string
	AvatarURL string
}

// NewHandler builds a Handler from the environment. BASE_URL and
// SESSION_SECRET fall back to defaults when unset; a provider is enabled only
// when its *_CLIENT_ID variable is non-empty.
func NewHandler(database *db.DB) *Handler {
	// Trim trailing slashes so baseURL+"/auth/callback" never contains "//".
	baseURL := strings.TrimRight(os.Getenv("BASE_URL"), "/")
	if baseURL == "" {
		baseURL = "https://writekit.dev"
	}

	secret := os.Getenv("SESSION_SECRET")
	if secret == "" {
		slog.Warn("SESSION_SECRET is not set; falling back to the development secret")
		secret = "dev-secret-change-in-production"
	}

	h := &Handler{
		database:      database,
		sessionSecret: []byte(secret),
		baseURL:       baseURL,
		providers:     make(map[string]provider),
	}

	if id := os.Getenv("GOOGLE_CLIENT_ID"); id != "" {
		h.providers["google"] = provider{
			Name:         "Google",
			ClientID:     id,
			ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
			AuthURL:      "https://accounts.google.com/o/oauth2/v2/auth",
			TokenURL:     "https://oauth2.googleapis.com/token",
			UserInfoURL:  "https://www.googleapis.com/oauth2/v2/userinfo",
			Scopes:       []string{"email", "profile"},
		}
	}

	if id := os.Getenv("GITHUB_CLIENT_ID"); id != "" {
		h.providers["github"] = provider{
			Name:         "GitHub",
			ClientID:     id,
			ClientSecret: os.Getenv("GITHUB_CLIENT_SECRET"),
			AuthURL:      "https://github.com/login/oauth/authorize",
			TokenURL:     "https://github.com/login/oauth/access_token",
			UserInfoURL:  "https://api.github.com/user",
			Scopes:       []string{"user:email"},
		}
	}

	if id := os.Getenv("DISCORD_CLIENT_ID"); id != "" {
		h.providers["discord"] = provider{
			Name:         "Discord",
			ClientID:     id,
			ClientSecret: os.Getenv("DISCORD_CLIENT_SECRET"),
			AuthURL:      "https://discord.com/api/oauth2/authorize",
			TokenURL:     "https://discord.com/api/oauth2/token",
			UserInfoURL:  "https://discord.com/api/users/@me",
			Scopes:       []string{"identify", "email"},
		}
	}

	return h
}

// Routes returns the router for the auth endpoints. The redirect_uri sent to
// providers is baseURL+"/auth/callback", so the router is expected to be
// mounted under /auth.
func (h *Handler) Routes() chi.Router {
	r := chi.NewRouter()
	r.Get("/google", h.initiate)
	r.Get("/github", h.initiate)
	r.Get("/discord", h.initiate)
	r.Get("/callback", h.callback)
	r.Get("/validate", h.validate)
	r.Get("/user", h.user)
	r.Get("/providers", h.listProviders)
	r.Post("/logout", h.logout)
	return r
}

// initiate starts the OAuth flow for the provider named by the last path
// segment and redirects the browser to that provider's authorization URL.
func (h *Handler) initiate(w http.ResponseWriter, r *http.Request) {
	// Use the last path segment rather than stripping a fixed prefix, so the
	// lookup does not depend on where the router is mounted.
	providerName := r.URL.Path[strings.LastIndex(r.URL.Path, "/")+1:]
	p, ok := h.providers[providerName]
	if !ok {
		http.Error(w, "unknown provider", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()
	stateStr, err := h.encodeState(oauthState{
		Provider:  providerName,
		TenantID:  query.Get("tenant"),
		Redirect:  query.Get("redirect"),
		Callback:  query.Get("callback"),
		Timestamp: time.Now().Unix(),
	})
	if err != nil {
		slog.Error("encode state", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	params := url.Values{}
	params.Set("client_id", p.ClientID)
	params.Set("redirect_uri", h.baseURL+oauthCallbackPath)
	params.Set("response_type", "code")
	params.Set("scope", strings.Join(p.Scopes, " "))
	params.Set("state", stateStr)
	if providerName == "discord" {
		params.Set("prompt", "consent")
	}

	http.Redirect(w, r, p.AuthURL+"?"+params.Encode(), http.StatusTemporaryRedirect)
}

// callback completes the OAuth flow: it verifies the signed state, exchanges
// the code for a token, resolves the user, creates a session, and then either
// redirects to the state's callback URL with the token as a query parameter
// or sets the session cookie and redirects to a local path.
func (h *Handler) callback(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	code, stateStr := query.Get("code"), query.Get("state")
	if code == "" || stateStr == "" {
		http.Error(w, "missing code or state", http.StatusBadRequest)
		return
	}

	state, err := h.decodeState(stateStr)
	if err != nil {
		slog.Error("decode state", "error", err)
		http.Error(w, "invalid state", http.StatusBadRequest)
		return
	}
	if time.Now().Unix()-state.Timestamp > oauthStateMaxAgeSeconds {
		http.Error(w, "state expired", http.StatusBadRequest)
		return
	}

	// Parse the callback before doing any work so a malformed value is
	// rejected without creating a session.
	var callbackURL *url.URL
	if state.Callback != "" {
		callbackURL, err = url.Parse(state.Callback)
		if err != nil {
			slog.Error("parse callback url", "error", err)
			http.Error(w, "invalid callback", http.StatusBadRequest)
			return
		}
	}

	p, ok := h.providers[state.Provider]
	if !ok {
		http.Error(w, "unknown provider", http.StatusBadRequest)
		return
	}

	token, err := h.exchangeCode(r.Context(), p, code)
	if err != nil {
		slog.Error("exchange code", "error", err)
		http.Error(w, "auth failed", http.StatusInternalServerError)
		return
	}

	info, err := h.getUserInfo(r.Context(), p, state.Provider, token)
	if err != nil {
		slog.Error("get user info", "error", err)
		http.Error(w, "failed to get user info", http.StatusInternalServerError)
		return
	}

	user, err := h.findOrCreateUser(r.Context(), state.Provider, info)
	if err != nil {
		slog.Error("find or create user", "error", err)
		http.Error(w, "failed to create user", http.StatusInternalServerError)
		return
	}

	session, err := h.database.CreateSession(r.Context(), user.ID)
	if err != nil {
		slog.Error("create session", "error", err)
		http.Error(w, "failed to create session", http.StatusInternalServerError)
		return
	}

	if callbackURL != nil {
		// NOTE(review): the callback destination comes from the unauthenticated
		// "callback" query parameter of the initiate request and is not checked
		// against an allow-list, so the session token is delivered to whatever
		// URL was supplied. Restrict this to known hosts once they are defined.
		//
		// Setting the parameter through net/url keeps the token in the query
		// string (not inside a fragment) and escapes it properly.
		cbQuery := callbackURL.Query()
		cbQuery.Set("token", session.Token)
		callbackURL.RawQuery = cbQuery.Encode()
		http.Redirect(w, r, callbackURL.String(), http.StatusTemporaryRedirect)
		return
	}

	// Only same-site absolute paths are allowed. "//host" and "/\host" are
	// both interpreted by browsers as a different host.
	redirect := state.Redirect
	if !strings.HasPrefix(redirect, "/") ||
		strings.HasPrefix(redirect, "//") ||
		strings.HasPrefix(redirect, "/\\") {
		redirect = "/"
	}

	cookie := h.sessionCookie(session.Token)
	cookie.Expires = session.ExpiresAt
	http.SetCookie(w, cookie)
	http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}

// validate responds 200 when the "token" query parameter names a valid
// session, 401 when it does not, and 500 when the lookup itself fails.
func (h *Handler) validate(w http.ResponseWriter, r *http.Request) {
	token := r.URL.Query().Get("token")
	if token == "" {
		http.Error(w, "missing token", http.StatusBadRequest)
		return
	}

	session, err := h.database.ValidateSession(r.Context(), token)
	if err != nil {
		slog.Error("validate session", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if session == nil {
		http.Error(w, "invalid session", http.StatusUnauthorized)
		return
	}

	w.WriteHeader(http.StatusOK)
}

// user returns the profile of the user that owns the session named by the
// "token" query parameter. Lookup failures are reported as 500, matching
// validate, instead of being reported as an invalid session or missing user.
func (h *Handler) user(w http.ResponseWriter, r *http.Request) {
	token := r.URL.Query().Get("token")
	if token == "" {
		http.Error(w, "missing token", http.StatusBadRequest)
		return
	}

	session, err := h.database.ValidateSession(r.Context(), token)
	if err != nil {
		slog.Error("validate session", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if session == nil {
		http.Error(w, "invalid session", http.StatusUnauthorized)
		return
	}

	user, err := h.database.GetUserByID(r.Context(), session.UserID)
	if err != nil {
		slog.Error("get user by id", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if user == nil {
		http.Error(w, "user not found", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	resp := map[string]any{
		"id":         user.ID,
		"email":      user.Email,
		"name":       user.Name,
		"avatar_url": user.AvatarURL,
	}
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		slog.Error("encode user response", "error", err)
	}
}

// listProviders reports the configured providers in a fixed order.
func (h *Handler) listProviders(w http.ResponseWriter, r *http.Request) {
	// Non-nil so that an empty list encodes as [] rather than null.
	providers := make([]map[string]string, 0, len(h.providers))
	for _, id := range []string{"google", "github", "discord"} {
		if p, ok := h.providers[id]; ok {
			providers = append(providers, map[string]string{"id": id, "name": p.Name})
		}
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]any{"providers": providers}); err != nil {
		slog.Error("encode providers response", "error", err)
	}
}

// logout deletes the session named by the session cookie, if any, and clears
// the cookie.
func (h *Handler) logout(w http.ResponseWriter, r *http.Request) {
	if cookie, err := r.Cookie("writekit_session"); err == nil && cookie.Value != "" {
		// NOTE(review): the result of DeleteSession is discarded; the cookie is
		// cleared regardless, so a failed delete leaves the session row behind.
		h.database.DeleteSession(r.Context(), cookie.Value)
	}

	cookie := h.sessionCookie("")
	cookie.MaxAge = -1
	http.SetCookie(w, cookie)
	w.WriteHeader(http.StatusOK)
}

// sessionCookie returns the session cookie with the attributes shared by
// login and logout. Secure is set unless the base URL contains "localhost".
func (h *Handler) sessionCookie(value string) *http.Cookie {
	return &http.Cookie{
		Name:     "writekit_session",
		Value:    value,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   !strings.Contains(h.baseURL, "localhost"),
	}
}

// fetchProviderBody sends req with a bounded timeout and returns the response
// body. A non-200 status is returned as an error of the form
// "<what> failed: <body>".
func fetchProviderBody(req *http.Request, what string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(req.Context(), oauthRequestTimeout)
	defer cancel()

	resp, err := http.DefaultClient.Do(req.WithContext(ctx))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, oauthMaxResponseBytes))
	if err != nil {
		return nil, fmt.Errorf("%s: read response: %w", what, err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%s failed: %s", what, body)
	}
	return body, nil
}

// exchangeCode trades an authorization code for an access token at the
// provider's token endpoint.
func (h *Handler) exchangeCode(ctx context.Context, p provider, code string) (*oauthToken, error) {
	data := url.Values{}
	data.Set("client_id", p.ClientID)
	data.Set("client_secret", p.ClientSecret)
	data.Set("code", code)
	data.Set("redirect_uri", h.baseURL+oauthCallbackPath)
	data.Set("grant_type", "authorization_code")

	req, err := http.NewRequestWithContext(ctx, "POST", p.TokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	body, err := fetchProviderBody(req, "token exchange")
	if err != nil {
		return nil, err
	}

	var token oauthToken
	if err := json.Unmarshal(body, &token); err != nil {
		return nil, fmt.Errorf("decode token response: %w", err)
	}
	// A 200 response can still carry an error object instead of a token
	// (GitHub does this), which leaves AccessToken empty.
	if token.AccessToken == "" {
		return nil, fmt.Errorf("token exchange failed: %s", body)
	}
	return &token, nil
}

// getUserInfo fetches the signed-in user's profile from the provider and maps
// it onto userInfo. It fails when the provider yields no email, no user ID,
// or explicitly marks the email as unverified.
func (h *Handler) getUserInfo(ctx context.Context, p provider, providerName string, token *oauthToken) (*userInfo, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", p.UserInfoURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	req.Header.Set("Accept", "application/json")

	body, err := fetchProviderBody(req, "user info")
	if err != nil {
		return nil, err
	}

	var raw map[string]any
	if err := json.Unmarshal(body, &raw); err != nil {
		return nil, err
	}

	info := &userInfo{}
	// findOrCreateUser links accounts by email, so an address the provider
	// reports as unverified must not be trusted.
	verified := true

	switch providerName {
	case "google":
		info.ID = getString(raw, "id")
		info.Email = getString(raw, "email")
		info.Name = getString(raw, "name")
		info.AvatarURL = getString(raw, "picture")
		verified = !explicitlyFalse(raw, "verified_email")

	case "github":
		// NOTE(review): "id" decodes as float64, so %v renders IDs of seven or
		// more digits in exponent form (e.g. "1.2345678e+07"). The format is
		// kept because changing it would no longer match identities that are
		// already stored; fixing it needs a data migration.
		if id := raw["id"]; id != nil {
			info.ID = fmt.Sprintf("%v", id)
		}
		info.Email = getString(raw, "email")
		info.Name = getString(raw, "name")
		if info.Name == "" {
			info.Name = getString(raw, "login")
		}
		info.AvatarURL = getString(raw, "avatar_url")
		if info.Email == "" {
			email, err := h.getGitHubEmail(ctx, token)
			if err != nil {
				return nil, fmt.Errorf("github emails: %w", err)
			}
			info.Email = email
		}

	case "discord":
		info.ID = getString(raw, "id")
		info.Email = getString(raw, "email")
		info.Name = getString(raw, "global_name")
		if info.Name == "" {
			info.Name = getString(raw, "username")
		}
		if avatar := getString(raw, "avatar"); avatar != "" {
			info.AvatarURL = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", info.ID, avatar)
		}
		verified = !explicitlyFalse(raw, "verified")
	}

	if info.Email == "" {
		return nil, fmt.Errorf("no email from provider")
	}
	// An empty ID would make every such user share one stored identity.
	if info.ID == "" {
		return nil, fmt.Errorf("no user id from provider")
	}
	if !verified {
		return nil, fmt.Errorf("email not verified by provider")
	}
	return info, nil
}

// getGitHubEmail returns the user's primary verified email from GitHub's
// emails endpoint, falling back to the first verified one. It returns "" with
// a nil error when no verified address exists.
func (h *Handler) getGitHubEmail(ctx context.Context, token *oauthToken) (string, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	req.Header.Set("Accept", "application/json")

	body, err := fetchProviderBody(req, "user emails")
	if err != nil {
		return "", err
	}

	var emails []struct {
		Email    string `json:"email"`
		Primary  bool   `json:"primary"`
		Verified bool   `json:"verified"`
	}
	if err := json.Unmarshal(body, &emails); err != nil {
		return "", err
	}

	firstVerified := ""
	for _, e := range emails {
		if !e.Verified {
			continue
		}
		if e.Primary {
			return e.Email, nil
		}
		if firstVerified == "" {
			firstVerified = e.Email
		}
	}
	return firstVerified, nil
}

// findOrCreateUser resolves the local user for a provider identity. It looks
// up the identity first, then an existing user with the same email (linking
// the new identity to it), and otherwise creates a new user.
func (h *Handler) findOrCreateUser(ctx context.Context, providerName string, info *userInfo) (*db.User, error) {
	user, err := h.database.GetUserByIdentity(ctx, providerName, info.ID)
	if err != nil {
		return nil, err
	}
	if user != nil {
		return user, nil
	}

	user, err = h.database.GetUserByEmail(ctx, info.Email)
	if err != nil {
		return nil, err
	}
	if user == nil {
		user, err = h.database.CreateUser(ctx, info.Email, info.Name, info.AvatarURL)
		if err != nil {
			return nil, err
		}
	}

	// NOTE(review): the result of AddUserIdentity is discarded. If linking
	// fails, the next sign-in falls back to the email lookup above and tries
	// again.
	h.database.AddUserIdentity(ctx, user.ID, providerName, info.ID, info.Email)
	return user, nil
}

// sign returns the HMAC-SHA256 of data under the session secret.
func (h *Handler) sign(data []byte) []byte {
	mac := hmac.New(sha256.New, h.sessionSecret)
	mac.Write(data)
	return mac.Sum(nil)
}

// encodeState serializes state as JSON, appends its HMAC-SHA256 signature and
// returns the result base64url-encoded.
func (h *Handler) encodeState(state oauthState) (string, error) {
	data, err := json.Marshal(state)
	if err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(append(data, h.sign(data)...)), nil
}

// decodeState reverses encodeState and rejects values whose trailing
// signature does not match the payload.
func (h *Handler) decodeState(s string) (*oauthState, error) {
	payload, err := base64.URLEncoding.DecodeString(s)
	if err != nil {
		return nil, err
	}
	if len(payload) < sha256.Size {
		return nil, fmt.Errorf("invalid state")
	}

	split := len(payload) - sha256.Size
	data, sig := payload[:split], payload[split:]
	if !hmac.Equal(sig, h.sign(data)) {
		return nil, fmt.Errorf("invalid signature")
	}

	var state oauthState
	if err := json.Unmarshal(data, &state); err != nil {
		return nil, err
	}
	return &state, nil
}

// explicitlyFalse reports whether m[key] is present and is the boolean false.
// A missing or non-boolean value is not treated as false.
func explicitlyFalse(m map[string]any, key string) bool {
	v, ok := m[key].(bool)
	return ok && !v
}

// getString returns m[key] when it is a string and "" otherwise.
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}