- Add OptionalSessionMiddleware for non-required auth checks
- Add GetUserID helper function
- Update import paths in auth and main
- Update docker-compose with frontend build configuration
- Clean up go.mod and go.sum

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
69 lines
1.7 KiB
Go
69 lines
1.7 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"writekit/internal/db"
|
|
)
|
|
|
|
// ctxKey is a package-private type for request-context keys. Because it is a
// distinct named type rather than a plain string, values stored under it
// cannot collide with context keys set by other packages.
type ctxKey string

const (
	// userIDKey is the context key under which SessionMiddleware and
	// OptionalSessionMiddleware store the validated session's UserID.
	userIDKey = ctxKey("userID")
)
|
|
|
|
func SessionMiddleware(database *db.DB) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
token := extractToken(r)
|
|
if token == "" {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
session, err := database.ValidateSession(r.Context(), token)
|
|
if err != nil || session == nil {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userIDKey, session.UserID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
func GetUserID(r *http.Request) string {
|
|
if id, ok := r.Context().Value(userIDKey).(string); ok {
|
|
return id
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func OptionalSessionMiddleware(database *db.DB) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
token := extractToken(r)
|
|
if token != "" {
|
|
if session, err := database.ValidateSession(r.Context(), token); err == nil && session != nil {
|
|
ctx := context.WithValue(r.Context(), userIDKey, session.UserID)
|
|
r = r.WithContext(ctx)
|
|
}
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// extractToken returns the session token carried by r, or "" if none is found.
//
// Sources are checked in order, and the first non-empty one wins:
//  1. the "writekit_session" cookie;
//  2. an "Authorization: Bearer <token>" header, where the scheme name is
//     matched case-insensitively (RFC 7235 §2.1) and surrounding whitespace
//     around the credential is trimmed;
//  3. the "token" URL query parameter.
//
// NOTE(review): tokens in query strings can leak via server logs, browser
// history and Referer headers (RFC 6750 §5.3). Keep this fallback only if some
// client cannot send a cookie or header (e.g. WebSocket/EventSource); confirm.
func extractToken(r *http.Request) string {
	if cookie, err := r.Cookie("writekit_session"); err == nil && cookie.Value != "" {
		return cookie.Value
	}

	// The auth-scheme is case-insensitive, so "bearer x" must be accepted
	// just like "Bearer x".
	const bearerPrefix = "Bearer "
	auth := r.Header.Get("Authorization")
	if len(auth) >= len(bearerPrefix) && strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
		// An empty credential ("Bearer ") falls through to the query
		// parameter, mirroring how an empty cookie is treated above.
		if token := strings.TrimSpace(auth[len(bearerPrefix):]); token != "" {
			return token
		}
	}

	return r.URL.Query().Get("token")
}
|