55 lines
1.2 KiB
Go
55 lines
1.2 KiB
Go
|
|
package auth
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
|
||
|
|
"github.com/writekitapp/writekit/internal/db"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ctxKey is the unexported type used for context keys set by this package.
// context.WithValue compares keys by dynamic type as well as value, so keys
// of this private type cannot collide with values stored by other packages,
// even when the underlying strings are equal.
type ctxKey string

// userIDKey is the context key under which SessionMiddleware stores the
// authenticated user's ID; GetUserID reads it back.
const (
	userIDKey ctxKey = "userID"
)
|
||
|
|
|
||
|
|
func SessionMiddleware(database *db.DB) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
token := extractToken(r)
|
||
|
|
if token == "" {
|
||
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
session, err := database.ValidateSession(r.Context(), token)
|
||
|
|
if err != nil || session == nil {
|
||
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx := context.WithValue(r.Context(), userIDKey, session.UserID)
|
||
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetUserID(r *http.Request) string {
|
||
|
|
if id, ok := r.Context().Value(userIDKey).(string); ok {
|
||
|
|
return id
|
||
|
|
}
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
|
||
|
|
// extractToken returns the session token carried by r, or "" if none is found.
//
// Sources are checked in priority order:
//  1. the "writekit_session" cookie, when present and non-empty;
//  2. an "Authorization: Bearer <token>" header. The scheme name is matched
//     case-insensitively, as RFC 7235 requires, and whitespace around the
//     token is trimmed;
//  3. the "token" URL query parameter.
//
// When the Authorization header uses the Bearer scheme, its token is returned
// (even if empty) without consulting the query string. An empty result is
// rejected by the caller.
func extractToken(r *http.Request) string {
	if cookie, err := r.Cookie("writekit_session"); err == nil && cookie.Value != "" {
		return cookie.Value
	}

	const bearerPrefix = "Bearer "
	auth := r.Header.Get("Authorization")
	if len(auth) >= len(bearerPrefix) && strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
		return strings.TrimSpace(auth[len(bearerPrefix):])
	}

	// NOTE(security): tokens in URLs can leak through access logs, browser
	// history and Referer headers (RFC 6750 §5.3). This fallback is kept for
	// compatibility, presumably for clients that cannot set headers or cookies
	// (e.g. WebSocket/EventSource) — confirm it is still needed.
	return r.URL.Query().Get("token")
}
|