init
This commit is contained in:
commit
d69342b2e9
160 changed files with 28681 additions and 0 deletions
54
internal/auth/middleware.go
Normal file
54
internal/auth/middleware.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/writekitapp/writekit/internal/db"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const userIDKey ctxKey = "userID"
|
||||
|
||||
func SessionMiddleware(database *db.DB) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractToken(r)
|
||||
if token == "" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := database.ValidateSession(r.Context(), token)
|
||||
if err != nil || session == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userIDKey, session.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserID(r *http.Request) string {
|
||||
if id, ok := r.Context().Value(userIDKey).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) string {
|
||||
if cookie, err := r.Cookie("writekit_session"); err == nil && cookie.Value != "" {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
|
||||
return r.URL.Query().Get("token")
|
||||
}
|
||||
535
internal/auth/oauth.go
Normal file
535
internal/auth/oauth.go
Normal file
|
|
@ -0,0 +1,535 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/writekitapp/writekit/internal/db"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
database *db.DB
|
||||
sessionSecret []byte
|
||||
baseURL string
|
||||
providers map[string]provider
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
Name string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type oauthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type oauthState struct {
|
||||
Provider string `json:"p"`
|
||||
TenantID string `json:"t,omitempty"`
|
||||
Redirect string `json:"r,omitempty"`
|
||||
Callback string `json:"c,omitempty"`
|
||||
Timestamp int64 `json:"ts"`
|
||||
}
|
||||
|
||||
type userInfo struct {
|
||||
ID string
|
||||
Email string
|
||||
Name string
|
||||
AvatarURL string
|
||||
}
|
||||
|
||||
func NewHandler(database *db.DB) *Handler {
|
||||
baseURL := os.Getenv("BASE_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://writekit.dev"
|
||||
}
|
||||
|
||||
secret := os.Getenv("SESSION_SECRET")
|
||||
if secret == "" {
|
||||
secret = "dev-secret-change-in-production"
|
||||
}
|
||||
|
||||
h := &Handler{
|
||||
database: database,
|
||||
sessionSecret: []byte(secret),
|
||||
baseURL: baseURL,
|
||||
providers: make(map[string]provider),
|
||||
}
|
||||
|
||||
if id := os.Getenv("GOOGLE_CLIENT_ID"); id != "" {
|
||||
h.providers["google"] = provider{
|
||||
Name: "Google",
|
||||
ClientID: id,
|
||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
Scopes: []string{"email", "profile"},
|
||||
}
|
||||
}
|
||||
|
||||
if id := os.Getenv("GITHUB_CLIENT_ID"); id != "" {
|
||||
h.providers["github"] = provider{
|
||||
Name: "GitHub",
|
||||
ClientID: id,
|
||||
ClientSecret: os.Getenv("GITHUB_CLIENT_SECRET"),
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
Scopes: []string{"user:email"},
|
||||
}
|
||||
}
|
||||
|
||||
if id := os.Getenv("DISCORD_CLIENT_ID"); id != "" {
|
||||
h.providers["discord"] = provider{
|
||||
Name: "Discord",
|
||||
ClientID: id,
|
||||
ClientSecret: os.Getenv("DISCORD_CLIENT_SECRET"),
|
||||
AuthURL: "https://discord.com/api/oauth2/authorize",
|
||||
TokenURL: "https://discord.com/api/oauth2/token",
|
||||
UserInfoURL: "https://discord.com/api/users/@me",
|
||||
Scopes: []string{"identify", "email"},
|
||||
}
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handler) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/google", h.initiate)
|
||||
r.Get("/github", h.initiate)
|
||||
r.Get("/discord", h.initiate)
|
||||
r.Get("/callback", h.callback)
|
||||
r.Get("/validate", h.validate)
|
||||
r.Get("/user", h.user)
|
||||
r.Get("/providers", h.listProviders)
|
||||
r.Post("/logout", h.logout)
|
||||
return r
|
||||
}
|
||||
|
||||
func (h *Handler) initiate(w http.ResponseWriter, r *http.Request) {
|
||||
providerName := strings.TrimPrefix(r.URL.Path, "/auth/")
|
||||
if _, ok := h.providers[providerName]; !ok {
|
||||
http.Error(w, "unknown provider", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
p := h.providers[providerName]
|
||||
state := oauthState{
|
||||
Provider: providerName,
|
||||
TenantID: r.URL.Query().Get("tenant"),
|
||||
Redirect: r.URL.Query().Get("redirect"),
|
||||
Callback: r.URL.Query().Get("callback"),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
stateStr, err := h.encodeState(state)
|
||||
if err != nil {
|
||||
slog.Error("encode state", "error", err)
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s",
|
||||
p.AuthURL,
|
||||
url.QueryEscape(p.ClientID),
|
||||
url.QueryEscape(h.baseURL+"/auth/callback"),
|
||||
url.QueryEscape(strings.Join(p.Scopes, " ")),
|
||||
url.QueryEscape(stateStr),
|
||||
)
|
||||
|
||||
if providerName == "discord" {
|
||||
authURL += "&prompt=consent"
|
||||
}
|
||||
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) callback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
stateStr := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || stateStr == "" {
|
||||
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := h.decodeState(stateStr)
|
||||
if err != nil {
|
||||
slog.Error("decode state", "error", err)
|
||||
http.Error(w, "invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if time.Now().Unix()-state.Timestamp > 600 {
|
||||
http.Error(w, "state expired", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
p, ok := h.providers[state.Provider]
|
||||
if !ok {
|
||||
http.Error(w, "unknown provider", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.exchangeCode(r.Context(), p, code)
|
||||
if err != nil {
|
||||
slog.Error("exchange code", "error", err)
|
||||
http.Error(w, "auth failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.getUserInfo(r.Context(), p, state.Provider, token)
|
||||
if err != nil {
|
||||
slog.Error("get user info", "error", err)
|
||||
http.Error(w, "failed to get user info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.findOrCreateUser(r.Context(), state.Provider, info)
|
||||
if err != nil {
|
||||
slog.Error("find or create user", "error", err)
|
||||
http.Error(w, "failed to create user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.database.CreateSession(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
slog.Error("create session", "error", err)
|
||||
http.Error(w, "failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if state.Callback != "" {
|
||||
callbackURL := state.Callback
|
||||
if strings.Contains(callbackURL, "?") {
|
||||
callbackURL += "&token=" + session.Token
|
||||
} else {
|
||||
callbackURL += "?token=" + session.Token
|
||||
}
|
||||
http.Redirect(w, r, callbackURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
redirect := state.Redirect
|
||||
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "writekit_session",
|
||||
Value: session.Token,
|
||||
Path: "/",
|
||||
Expires: session.ExpiresAt,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: !strings.Contains(h.baseURL, "localhost"),
|
||||
})
|
||||
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (h *Handler) validate(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.URL.Query().Get("token")
|
||||
if token == "" {
|
||||
http.Error(w, "missing token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.database.ValidateSession(r.Context(), token)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
http.Error(w, "invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *Handler) user(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.URL.Query().Get("token")
|
||||
if token == "" {
|
||||
http.Error(w, "missing token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.database.ValidateSession(r.Context(), token)
|
||||
if err != nil || session == nil {
|
||||
http.Error(w, "invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.database.GetUserByID(r.Context(), session.UserID)
|
||||
if err != nil || user == nil {
|
||||
http.Error(w, "user not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": user.ID,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"avatar_url": user.AvatarURL,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) listProviders(w http.ResponseWriter, r *http.Request) {
|
||||
providers := []map[string]string{}
|
||||
if _, ok := h.providers["google"]; ok {
|
||||
providers = append(providers, map[string]string{"id": "google", "name": "Google"})
|
||||
}
|
||||
if _, ok := h.providers["github"]; ok {
|
||||
providers = append(providers, map[string]string{"id": "github", "name": "GitHub"})
|
||||
}
|
||||
if _, ok := h.providers["discord"]; ok {
|
||||
providers = append(providers, map[string]string{"id": "discord", "name": "Discord"})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{"providers": providers})
|
||||
}
|
||||
|
||||
func (h *Handler) logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("writekit_session")
|
||||
if err == nil && cookie.Value != "" {
|
||||
h.database.DeleteSession(r.Context(), cookie.Value)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "writekit_session",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: !strings.Contains(h.baseURL, "localhost"),
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *Handler) exchangeCode(ctx context.Context, p provider, code string) (*oauthToken, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", p.ClientID)
|
||||
data.Set("client_secret", p.ClientSecret)
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", h.baseURL+"/auth/callback")
|
||||
data.Set("grant_type", "authorization_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange failed: %s", body)
|
||||
}
|
||||
|
||||
var token oauthToken
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getUserInfo(ctx context.Context, p provider, providerName string, token *oauthToken) (*userInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("user info failed: %s", body)
|
||||
}
|
||||
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := &userInfo{}
|
||||
|
||||
switch providerName {
|
||||
case "google":
|
||||
info.ID = getString(raw, "id")
|
||||
info.Email = getString(raw, "email")
|
||||
info.Name = getString(raw, "name")
|
||||
info.AvatarURL = getString(raw, "picture")
|
||||
|
||||
case "github":
|
||||
info.ID = fmt.Sprintf("%v", raw["id"])
|
||||
info.Email = getString(raw, "email")
|
||||
info.Name = getString(raw, "name")
|
||||
if info.Name == "" {
|
||||
info.Name = getString(raw, "login")
|
||||
}
|
||||
info.AvatarURL = getString(raw, "avatar_url")
|
||||
|
||||
if info.Email == "" {
|
||||
info.Email, _ = h.getGitHubEmail(ctx, token)
|
||||
}
|
||||
|
||||
case "discord":
|
||||
info.ID = getString(raw, "id")
|
||||
info.Email = getString(raw, "email")
|
||||
info.Name = getString(raw, "global_name")
|
||||
if info.Name == "" {
|
||||
info.Name = getString(raw, "username")
|
||||
}
|
||||
if avatar := getString(raw, "avatar"); avatar != "" {
|
||||
info.AvatarURL = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", info.ID, avatar)
|
||||
}
|
||||
}
|
||||
|
||||
if info.Email == "" {
|
||||
return nil, fmt.Errorf("no email from provider")
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getGitHubEmail(ctx context.Context, token *oauthToken) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var emails []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, e := range emails {
|
||||
if e.Primary && e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
for _, e := range emails {
|
||||
if e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (h *Handler) findOrCreateUser(ctx context.Context, providerName string, info *userInfo) (*db.User, error) {
|
||||
user, err := h.database.GetUserByIdentity(ctx, providerName, info.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user != nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
user, err = h.database.GetUserByEmail(ctx, info.Email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user != nil {
|
||||
h.database.AddUserIdentity(ctx, user.ID, providerName, info.ID, info.Email)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
user, err = h.database.CreateUser(ctx, info.Email, info.Name, info.AvatarURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.database.AddUserIdentity(ctx, user.ID, providerName, info.ID, info.Email)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (h *Handler) encodeState(state oauthState) (string, error) {
|
||||
data, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
mac := hmac.New(sha256.New, h.sessionSecret)
|
||||
mac.Write(data)
|
||||
sig := mac.Sum(nil)
|
||||
return base64.URLEncoding.EncodeToString(append(data, sig...)), nil
|
||||
}
|
||||
|
||||
func (h *Handler) decodeState(s string) (*oauthState, error) {
|
||||
payload, err := base64.URLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(payload) < 32 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
data := payload[:len(payload)-32]
|
||||
sig := payload[len(payload)-32:]
|
||||
|
||||
mac := hmac.New(sha256.New, h.sessionSecret)
|
||||
mac.Write(data)
|
||||
if !hmac.Equal(sig, mac.Sum(nil)) {
|
||||
return nil, fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
var state oauthState
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func getString(m map[string]any, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue