writekit/internal/server/api.go

209 lines
4.8 KiB
Go
Raw Normal View History

2026-01-09 00:16:46 +02:00
package server
import (
"net/http"
"strconv"
"strings"
"github.com/go-chi/chi/v5"
"github.com/writekitapp/writekit/internal/auth"
"github.com/writekitapp/writekit/internal/tenant"
)
// publicAPIRoutes builds the router for the public API. Every route first
// passes through apiKeyMiddleware and then through the rate limiter built by
// apiRateLimitMiddleware from s.rateLimiter, in that order.
func (s *Server) publicAPIRoutes() chi.Router {
	router := chi.NewRouter()
	router.Use(
		s.apiKeyMiddleware,
		s.apiRateLimitMiddleware(s.rateLimiter),
	)

	// Post collection.
	router.Get("/posts", s.apiListPosts)
	router.Post("/posts", s.apiCreatePost)

	// Individual posts, addressed by slug.
	router.Get("/posts/{slug}", s.apiGetPost)
	router.Put("/posts/{slug}", s.apiUpdatePost)
	router.Delete("/posts/{slug}", s.apiDeletePost)

	// Read-only subset of the tenant's settings.
	router.Get("/settings", s.apiGetSettings)

	return router
}
// apiKeyMiddleware authenticates public-API requests for the tenant whose ID
// is stored in the request context under tenantIDKey.
//
// A request without a non-empty tenant ID is rejected with 401. Otherwise it
// is let through when the first applicable check below succeeds:
//   - If the request carries an API key (see extractAPIKey), the key is
//     checked with the tenant's ValidateAPIKey. A rejected key fails with 401
//     and no other method is tried.
//   - If GetDemoInfo reports the request as a demo request, it is allowed.
//   - If auth.GetUserID yields a user who owns the tenant according to
//     IsUserTenantOwner, it is allowed.
//
// Failures opening the tenant database or running the lookups produce 500
// responses rather than being reported to the client as bad credentials.
func (s *Server) apiKeyMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tenantID, ok := r.Context().Value(tenantIDKey).(string)
		if !ok || tenantID == "" {
			jsonError(w, http.StatusUnauthorized, "unauthorized")
			return
		}

		// An explicitly supplied key is authoritative: if present, it alone
		// decides the outcome.
		if key := extractAPIKey(r); key != "" {
			db, err := s.tenantPool.Get(tenantID)
			if err != nil {
				jsonError(w, http.StatusInternalServerError, "database error")
				return
			}
			valid, err := tenant.NewQueries(db).ValidateAPIKey(r.Context(), key)
			if err != nil {
				jsonError(w, http.StatusInternalServerError, "validation error")
				return
			}
			if !valid {
				jsonError(w, http.StatusUnauthorized, "invalid API key")
				return
			}
			next.ServeHTTP(w, r)
			return
		}

		if GetDemoInfo(r).IsDemo {
			next.ServeHTTP(w, r)
			return
		}

		userID := auth.GetUserID(r)
		if userID == "" {
			jsonError(w, http.StatusUnauthorized, "API key required")
			return
		}
		isOwner, err := s.database.IsUserTenantOwner(r.Context(), userID, tenantID)
		if err != nil {
			// A failed ownership lookup is a server problem, not a credentials
			// problem; answering 401 here would tell clients their auth is wrong.
			jsonError(w, http.StatusInternalServerError, "database error")
			return
		}
		if !isOwner {
			jsonError(w, http.StatusUnauthorized, "API key required")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// extractAPIKey returns the API key presented by r, or "" if none was given.
//
// If the Authorization header uses the Bearer scheme, its token is returned.
// The scheme name is matched case-insensitively (RFC 7235 section 2.1) and
// whitespace around the token is trimmed. A Bearer header with an empty token
// yields "" and does not fall back to the query string. Otherwise the value
// of the "api_key" query parameter is returned.
//
// NOTE(review): keys sent in the query string can leak into access logs,
// proxy logs and Referer headers; clients should prefer the header form.
func extractAPIKey(r *http.Request) string {
	header := r.Header.Get("Authorization")
	if scheme, token, found := strings.Cut(header, " "); found && strings.EqualFold(scheme, "Bearer") {
		return strings.TrimSpace(token)
	}
	return r.URL.Query().Get("api_key")
}
// paginatedPostsResponse is the JSON body written by apiListPosts.
// Field declaration order is the key order of the encoded JSON object.
type paginatedPostsResponse struct {
	// Posts on the requested page, converted with postToResponse.
	Posts []postResponse `json:"posts"`
	// Total as reported by ListPostsPaginated; presumably the number of
	// matching posts across all pages — TODO confirm in tenant package.
	Total int `json:"total"`
	// Limit echoes the limit parsed from the request query string.
	Limit int `json:"limit"`
	// Offset echoes the offset parsed from the request query string.
	Offset int `json:"offset"`
}
// apiListPosts responds with one page of the tenant's posts as a
// paginatedPostsResponse.
//
// Query parameters:
//   - limit, offset: optional non-negative base-10 integers. Absent values
//     are passed to ListPostsPaginated as 0 (presumably "use the default" —
//     TODO confirm in the tenant package). Malformed or negative values are
//     rejected with 400 instead of being silently replaced, since a negative
//     limit can mean "no limit" to some SQL backends.
//   - tag: optional tag filter, passed through unchanged ("" when absent).
//   - include=content: include each post's content in the response.
//
// The limit and offset echoed in the response are the requested values.
func (s *Server) apiListPosts(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	limit, ok := parseNonNegativeQueryInt(query.Get("limit"))
	if !ok {
		jsonError(w, http.StatusBadRequest, "invalid limit")
		return
	}
	offset, ok := parseNonNegativeQueryInt(query.Get("offset"))
	if !ok {
		jsonError(w, http.StatusBadRequest, "invalid offset")
		return
	}
	tag := query.Get("tag")
	includeContent := query.Get("include") == "content"

	tenantID := r.Context().Value(tenantIDKey).(string)
	db, err := s.tenantPool.Get(tenantID)
	if err != nil {
		jsonError(w, http.StatusInternalServerError, "database error")
		return
	}

	result, err := tenant.NewQueries(db).ListPostsPaginated(r.Context(), tenant.ListPostsOptions{
		Limit:  limit,
		Offset: offset,
		Tag:    tag,
	})
	if err != nil {
		jsonError(w, http.StatusInternalServerError, "failed to list posts")
		return
	}

	// make with a length (not nil) so an empty page encodes as [] not null.
	posts := make([]postResponse, len(result.Posts))
	for i := range result.Posts {
		// Address the slice element itself rather than a per-iteration copy.
		posts[i] = postToResponse(&result.Posts[i], includeContent)
	}
	jsonResponse(w, http.StatusOK, paginatedPostsResponse{
		Posts:  posts,
		Total:  result.Total,
		Limit:  limit,
		Offset: offset,
	})
}

// parseNonNegativeQueryInt parses an optional integer query parameter.
// It returns (0, true) for an empty string, (n, true) for a base-10 integer
// n >= 0, and (0, false) for anything else.
func parseNonNegativeQueryInt(raw string) (int, bool) {
	if raw == "" {
		return 0, true
	}
	n, err := strconv.Atoi(raw)
	if err != nil || n < 0 {
		return 0, false
	}
	return n, true
}
// apiGetPost responds with the post named by the {slug} URL parameter,
// including its content. It answers 500 if the tenant database cannot be
// opened or GetPost fails, and 404 if GetPost returns no post.
func (s *Server) apiGetPost(w http.ResponseWriter, r *http.Request) {
	tenantID := r.Context().Value(tenantIDKey).(string)
	slug := chi.URLParam(r, "slug")

	db, err := s.tenantPool.Get(tenantID)
	if err != nil {
		jsonError(w, http.StatusInternalServerError, "database error")
		return
	}

	post, err := tenant.NewQueries(db).GetPost(r.Context(), slug)
	switch {
	case err != nil:
		jsonError(w, http.StatusInternalServerError, "failed to get post")
	case post == nil:
		jsonError(w, http.StatusNotFound, "post not found")
	default:
		jsonResponse(w, http.StatusOK, postToResponse(post, true))
	}
}
// apiCreatePost handles POST /posts on the public API by delegating to the
// shared createPost handler.
func (s *Server) apiCreatePost(w http.ResponseWriter, req *http.Request) { s.createPost(w, req) }

// apiUpdatePost handles PUT /posts/{slug} on the public API by delegating to
// the shared updatePost handler.
func (s *Server) apiUpdatePost(w http.ResponseWriter, req *http.Request) { s.updatePost(w, req) }

// apiDeletePost handles DELETE /posts/{slug} on the public API by delegating
// to the shared deletePost handler.
func (s *Server) apiDeletePost(w http.ResponseWriter, req *http.Request) { s.deletePost(w, req) }
// publicSettingsKeys is the allow-list of tenant setting keys that
// apiGetSettings returns through the public API. Settings not named here are
// filtered out of the response, so adding a key here makes it readable by
// any caller that passes apiKeyMiddleware (including demo requests).
//
// NOTE(review): "email" is exposed; confirm it is meant to be public.
var publicSettingsKeys = []string{
	"site_name",
	"site_description",
	"author_name",
	"author_role",
	"author_bio",
	"author_photo",
	"twitter_handle",
	"github_handle",
	"linkedin_handle",
	"email",
	"accent_color",
	"font",
}
// apiGetSettings responds with a JSON object containing only the tenant
// settings whose keys appear in publicSettingsKeys. Keys with no stored
// value are omitted rather than returned as empty strings. It answers 500
// if the tenant database cannot be opened or GetSettings fails.
func (s *Server) apiGetSettings(w http.ResponseWriter, r *http.Request) {
	tenantID := r.Context().Value(tenantIDKey).(string)
	db, err := s.tenantPool.Get(tenantID)
	if err != nil {
		jsonError(w, http.StatusInternalServerError, "database error")
		return
	}

	stored, err := tenant.NewQueries(db).GetSettings(r.Context())
	if err != nil {
		jsonError(w, http.StatusInternalServerError, "failed to get settings")
		return
	}

	public := make(map[string]string, len(publicSettingsKeys))
	for _, name := range publicSettingsKeys {
		value, present := stored[name]
		if !present {
			continue
		}
		public[name] = value
	}
	jsonResponse(w, http.StatusOK, public)
}