208 lines
4.8 KiB
Go
208 lines
4.8 KiB
Go
package server
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/writekitapp/writekit/internal/auth"
|
|
"github.com/writekitapp/writekit/internal/tenant"
|
|
)
|
|
|
|
func (s *Server) publicAPIRoutes() chi.Router {
|
|
r := chi.NewRouter()
|
|
r.Use(s.apiKeyMiddleware)
|
|
r.Use(s.apiRateLimitMiddleware(s.rateLimiter))
|
|
|
|
r.Get("/posts", s.apiListPosts)
|
|
r.Post("/posts", s.apiCreatePost)
|
|
r.Get("/posts/{slug}", s.apiGetPost)
|
|
r.Put("/posts/{slug}", s.apiUpdatePost)
|
|
r.Delete("/posts/{slug}", s.apiDeletePost)
|
|
|
|
r.Get("/settings", s.apiGetSettings)
|
|
|
|
return r
|
|
}
|
|
|
|
func (s *Server) apiKeyMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tenantID, ok := r.Context().Value(tenantIDKey).(string)
|
|
if !ok || tenantID == "" {
|
|
jsonError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
|
|
key := extractAPIKey(r)
|
|
if key != "" {
|
|
db, err := s.tenantPool.Get(tenantID)
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "database error")
|
|
return
|
|
}
|
|
|
|
q := tenant.NewQueries(db)
|
|
valid, err := q.ValidateAPIKey(r.Context(), key)
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "validation error")
|
|
return
|
|
}
|
|
if valid {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
jsonError(w, http.StatusUnauthorized, "invalid API key")
|
|
return
|
|
}
|
|
|
|
if GetDemoInfo(r).IsDemo {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
userID := auth.GetUserID(r)
|
|
if userID == "" {
|
|
jsonError(w, http.StatusUnauthorized, "API key required")
|
|
return
|
|
}
|
|
|
|
isOwner, err := s.database.IsUserTenantOwner(r.Context(), userID, tenantID)
|
|
if err != nil || !isOwner {
|
|
jsonError(w, http.StatusUnauthorized, "API key required")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// extractAPIKey returns the API key supplied with r, or "" if there is none.
//
// An "Authorization: Bearer <token>" header takes precedence. The scheme is
// matched case-insensitively (RFC 7235 §2.1) and surrounding whitespace is
// trimmed from the token. A Bearer header with an empty token yields "".
// When no Bearer header is present, the "api_key" query parameter is used.
//
// NOTE(review): keys passed in the URL can end up in access logs and proxy
// caches (RFC 6750 §2.3). The query fallback is kept for compatibility with
// existing clients.
func extractAPIKey(r *http.Request) string {
	const bearerScheme = "Bearer "

	// Named header rather than auth so the internal auth package stays
	// unshadowed.
	header := r.Header.Get("Authorization")
	if len(header) >= len(bearerScheme) && strings.EqualFold(header[:len(bearerScheme)], bearerScheme) {
		return strings.TrimSpace(header[len(bearerScheme):])
	}

	return r.URL.Query().Get("api_key")
}
|
|
|
|
type paginatedPostsResponse struct {
|
|
Posts []postResponse `json:"posts"`
|
|
Total int `json:"total"`
|
|
Limit int `json:"limit"`
|
|
Offset int `json:"offset"`
|
|
}
|
|
|
|
func (s *Server) apiListPosts(w http.ResponseWriter, r *http.Request) {
|
|
tenantID := r.Context().Value(tenantIDKey).(string)
|
|
|
|
db, err := s.tenantPool.Get(tenantID)
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "database error")
|
|
return
|
|
}
|
|
|
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
|
offset, _ := strconv.Atoi(r.URL.Query().Get("offset"))
|
|
tag := r.URL.Query().Get("tag")
|
|
includeContent := r.URL.Query().Get("include") == "content"
|
|
|
|
q := tenant.NewQueries(db)
|
|
result, err := q.ListPostsPaginated(r.Context(), tenant.ListPostsOptions{
|
|
Limit: limit,
|
|
Offset: offset,
|
|
Tag: tag,
|
|
})
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "failed to list posts")
|
|
return
|
|
}
|
|
|
|
posts := make([]postResponse, len(result.Posts))
|
|
for i, p := range result.Posts {
|
|
posts[i] = postToResponse(&p, includeContent)
|
|
}
|
|
|
|
jsonResponse(w, http.StatusOK, paginatedPostsResponse{
|
|
Posts: posts,
|
|
Total: result.Total,
|
|
Limit: limit,
|
|
Offset: offset,
|
|
})
|
|
}
|
|
|
|
func (s *Server) apiGetPost(w http.ResponseWriter, r *http.Request) {
|
|
tenantID := r.Context().Value(tenantIDKey).(string)
|
|
slug := chi.URLParam(r, "slug")
|
|
|
|
db, err := s.tenantPool.Get(tenantID)
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "database error")
|
|
return
|
|
}
|
|
|
|
q := tenant.NewQueries(db)
|
|
post, err := q.GetPost(r.Context(), slug)
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "failed to get post")
|
|
return
|
|
}
|
|
if post == nil {
|
|
jsonError(w, http.StatusNotFound, "post not found")
|
|
return
|
|
}
|
|
|
|
jsonResponse(w, http.StatusOK, postToResponse(post, true))
|
|
}
|
|
|
|
func (s *Server) apiCreatePost(w http.ResponseWriter, r *http.Request) {
|
|
s.createPost(w, r)
|
|
}
|
|
|
|
func (s *Server) apiUpdatePost(w http.ResponseWriter, r *http.Request) {
|
|
s.updatePost(w, r)
|
|
}
|
|
|
|
func (s *Server) apiDeletePost(w http.ResponseWriter, r *http.Request) {
|
|
s.deletePost(w, r)
|
|
}
|
|
|
|
// publicSettingsKeys is the allow-list of tenant settings that apiGetSettings
// returns. Any setting whose key is not listed here is filtered out of that
// response.
var publicSettingsKeys = []string{
	"site_name", "site_description",
	"author_name", "author_role", "author_bio", "author_photo",
	"twitter_handle", "github_handle", "linkedin_handle",
	"email",
	"accent_color", "font",
}
|
|
|
|
func (s *Server) apiGetSettings(w http.ResponseWriter, r *http.Request) {
|
|
tenantID := r.Context().Value(tenantIDKey).(string)
|
|
|
|
db, err := s.tenantPool.Get(tenantID)
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "database error")
|
|
return
|
|
}
|
|
|
|
q := tenant.NewQueries(db)
|
|
allSettings, err := q.GetSettings(r.Context())
|
|
if err != nil {
|
|
jsonError(w, http.StatusInternalServerError, "failed to get settings")
|
|
return
|
|
}
|
|
|
|
result := make(map[string]string)
|
|
for _, key := range publicSettingsKeys {
|
|
if val, ok := allSettings[key]; ok {
|
|
result[key] = val
|
|
}
|
|
}
|
|
|
|
jsonResponse(w, http.StatusOK, result)
|
|
}
|