package server

import (
	"bytes"
	"context"
	"crypto/md5"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"html/template"
	"log/slog"
	"mime"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"path"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/writekitapp/writekit/internal/auth"
	"github.com/writekitapp/writekit/internal/build/assets"
	"github.com/writekitapp/writekit/internal/build/templates"
	"github.com/writekitapp/writekit/internal/config"
	"github.com/writekitapp/writekit/internal/markdown"
	"github.com/writekitapp/writekit/internal/tenant"
	"github.com/writekitapp/writekit/studio"
)

// blogHTMLContentType is the Content-Type sent with every HTML response
// produced by the handlers in this file.
const blogHTMLContentType = "text/html; charset=utf-8"

// blogCacheControl makes clients revalidate rendered pages on every request so
// that ETag comparison decides whether the body is re-sent.
const blogCacheControl = "public, max-age=0, must-revalidate"

// serveBlog resolves subdomain to a tenant or demo, stores the tenant ID and
// demo information in the request context (tenantIDKey, demoInfoKey) and then
// dispatches the request to the blog, studio and API routes.
//
// Subdomains that match neither a tenant nor a demo are handed to s.notFound.
func (s *Server) serveBlog(w http.ResponseWriter, r *http.Request, subdomain string) {
	var demoInfo DemoInfo

	tenantID, cached := s.tenantCache.Get(subdomain)
	if cached {
		// The cache only stores the ID, so demo status is looked up again on
		// every request. A lookup failure is treated as "not a demo".
		if d, _ := s.database.GetDemoBySubdomain(r.Context(), subdomain); d != nil {
			demoInfo = DemoInfo{IsDemo: true, ExpiresAt: d.ExpiresAt}
			s.tenantPool.MarkAsDemo(tenantID)
		}
	} else {
		t, err := s.database.GetTenantBySubdomain(r.Context(), subdomain)
		if err != nil || t == nil {
			// Not a regular tenant: fall back to the demo table.
			d, err := s.database.GetDemoBySubdomain(r.Context(), subdomain)
			if err != nil || d == nil {
				s.notFound(w, r)
				return
			}
			tenantID = d.ID
			demoInfo = DemoInfo{IsDemo: true, ExpiresAt: d.ExpiresAt}
			s.tenantPool.MarkAsDemo(tenantID)
			s.ensureDemoSeeded(tenantID)
		} else {
			tenantID = t.ID
		}
		s.tenantCache.Set(subdomain, tenantID)
	}

	ctx := context.WithValue(r.Context(), tenantIDKey, tenantID)
	ctx = context.WithValue(ctx, demoInfoKey, demoInfo)
	r = r.WithContext(ctx)

	// A fresh router is assembled for each request.
	mux := chi.NewRouter()
	mux.Get("/", s.blogHome)
	mux.Get("/posts", s.blogList)
	mux.Get("/posts/{slug}", s.blogPost)
	mux.Handle("/static/*", http.StripPrefix("/static/", assets.Handler()))
	mux.Route("/api/studio", func(sr chi.Router) {
		sr.Use(demoAwareSessionMiddleware(s.database))
		sr.Use(s.ownerMiddleware)
		sr.Mount("/", s.studioRoutes())
	})
	mux.Mount("/api/v1", s.publicAPIRoutes())
	mux.Mount("/api/reader", s.readerRoutes())
	mux.Get("/studio", s.serveStudio)
	mux.Get("/studio/*", s.serveStudio)
	mux.Get("/sitemap.xml", s.sitemap)
	mux.Get("/robots.txt", s.robots)
	mux.ServeHTTP(w, r)
}

// tenantQueriesFor reads the tenant ID from the request context and opens the
// tenant's query handle. On failure it logs under the given handler name,
// writes a 500 response and returns ok == false; the caller must then return
// without writing anything else.
func (s *Server) tenantQueriesFor(w http.ResponseWriter, r *http.Request, handler string) (q *tenant.Queries, tenantID string, ok bool) {
	tenantID, ok = r.Context().Value(tenantIDKey).(string)
	if !ok {
		slog.Error(handler + ": tenant ID missing from request context")
		http.Error(w, "internal error", http.StatusInternalServerError)
		return nil, "", false
	}
	db, err := s.tenantPool.Get(tenantID)
	if err != nil {
		slog.Error(handler+": get tenant pool", "error", err, "tenantID", tenantID)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return nil, "", false
	}
	return tenant.NewQueries(db), tenantID, true
}

// showBadgeFor reports whether pages of the given tenant must show the badge.
// It defaults to true when the tenant cannot be loaded.
func (s *Server) showBadgeFor(ctx context.Context, tenantID string) bool {
	t, err := s.database.GetTenantByID(ctx, tenantID)
	if err != nil || t == nil {
		return true
	}
	return config.GetTierInfo(t.Premium).Config.BadgeRequired
}

// blogHome serves the blog front page: the stored pre-rendered page for "/"
// when one exists, otherwise a page rendered on the fly listing at most ten
// posts.
func (s *Server) blogHome(w http.ResponseWriter, r *http.Request) {
	q, tenantID, ok := s.tenantQueriesFor(w, r, "blogHome")
	if !ok {
		return
	}
	ctx := r.Context()
	s.recordPageView(q, r, "/", "")

	if html, etag, err := q.GetPage(ctx, "/"); err == nil && html != nil {
		s.servePreRendered(w, r, html, etag, blogCacheControl)
		return
	}

	posts, err := q.ListPosts(ctx, false)
	if err != nil {
		slog.Error("blogHome: list posts", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	// A settings failure is tolerated; the fallbacks below apply instead.
	settings, _ := q.GetSettings(ctx)
	siteName := getSettingOr(settings, "site_name", "My Blog")
	siteDesc := getSettingOr(settings, "site_description", "")
	baseURL := getBaseURL(r.Host)

	const maxHomePosts = 10
	shown := posts
	if len(shown) > maxHomePosts {
		shown = shown[:maxHomePosts]
	}
	summaries := make([]templates.PostSummary, 0, len(shown))
	for _, p := range shown {
		summaries = append(summaries, templates.PostSummary{
			Slug:        p.Slug,
			Title:       p.Title,
			Description: p.Description,
			Date:        timeOrZero(p.PublishedAt),
		})
	}

	data := templates.HomeData{
		PageData: templates.PageData{
			Title:        siteName,
			Description:  siteDesc,
			CanonicalURL: baseURL + "/",
			OGType:       "website",
			SiteName:     siteName,
			Year:         time.Now().Year(),
			Settings:     settingsToMap(settings),
			NoIndex:      GetDemoInfo(r).IsDemo,
			ShowBadge:    s.showBadgeFor(ctx, tenantID),
		},
		Posts:   summaries,
		HasMore: len(posts) > maxHomePosts,
	}

	html, err := templates.RenderHome(data)
	if err != nil {
		slog.Error("blogHome: render template", "error", err)
		http.Error(w, "render error", http.StatusInternalServerError)
		return
	}
	s.servePreRendered(w, r, html, computeETag(html), blogCacheControl)
}

// blogList serves the full post index at /posts: the stored pre-rendered page
// when one exists, otherwise a page rendered on the fly from all listed posts.
func (s *Server) blogList(w http.ResponseWriter, r *http.Request) {
	q, tenantID, ok := s.tenantQueriesFor(w, r, "blogList")
	if !ok {
		return
	}
	ctx := r.Context()
	s.recordPageView(q, r, "/posts", "")

	if html, etag, err := q.GetPage(ctx, "/posts"); err == nil && html != nil {
		s.servePreRendered(w, r, html, etag, blogCacheControl)
		return
	}

	posts, err := q.ListPosts(ctx, false)
	if err != nil {
		slog.Error("blogList: list posts", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	// A settings failure is tolerated; the fallbacks below apply instead.
	settings, _ := q.GetSettings(ctx)
	siteName := getSettingOr(settings, "site_name", "My Blog")
	baseURL := getBaseURL(r.Host)

	summaries := make([]templates.PostSummary, 0, len(posts))
	for _, p := range posts {
		summaries = append(summaries, templates.PostSummary{
			Slug:        p.Slug,
			Title:       p.Title,
			Description: p.Description,
			Date:        timeOrZero(p.PublishedAt),
		})
	}

	data := templates.BlogData{
		PageData: templates.PageData{
			Title:        "Posts - " + siteName,
			Description:  "All posts",
			CanonicalURL: baseURL + "/posts",
			OGType:       "website",
			SiteName:     siteName,
			Year:         time.Now().Year(),
			Settings:     settingsToMap(settings),
			NoIndex:      GetDemoInfo(r).IsDemo,
			ShowBadge:    s.showBadgeFor(ctx, tenantID),
		},
		Posts: summaries,
	}

	html, err := templates.RenderBlog(data)
	if err != nil {
		slog.Error("blogList: render template", "error", err)
		http.Error(w, "render error", http.StatusInternalServerError)
		return
	}
	s.servePreRendered(w, r, html, computeETag(html), blogCacheControl)
}

// blogPost serves a single post at /posts/{slug}.
//
// With ?preview=true the caller must pass canPreview; the draft (when one
// exists) replaces the published fields, the response is never cached and no
// page view is recorded. Without preview, a stored pre-rendered page is used
// when available, unknown slugs that match a published alias are redirected
// permanently, and unpublished posts yield 404.
func (s *Server) blogPost(w http.ResponseWriter, r *http.Request) {
	slug := chi.URLParam(r, "slug")
	isPreview := r.URL.Query().Get("preview") == "true"

	q, tenantID, ok := s.tenantQueriesFor(w, r, "blogPost")
	if !ok {
		return
	}
	ctx := r.Context()

	if isPreview && !s.canPreview(r, tenantID) {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	if !isPreview {
		pagePath := "/posts/" + slug
		s.recordPageView(q, r, pagePath, slug)
		if html, etag, err := q.GetPage(ctx, pagePath); err == nil && html != nil {
			s.servePreRendered(w, r, html, etag, blogCacheControl)
			return
		}
	}

	post, err := q.GetPost(ctx, slug)
	if err != nil {
		slog.Error("blogPost: get post", "error", err, "slug", slug)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if post == nil {
		// The slug may be a former name of a post that has since been renamed.
		aliasPost, _ := q.GetPostByAlias(ctx, slug)
		if aliasPost != nil && aliasPost.IsPublished {
			http.Redirect(w, r, "/posts/"+aliasPost.Slug, http.StatusMovedPermanently)
			return
		}
		http.NotFound(w, r)
		return
	}
	if !post.IsPublished && !isPreview {
		http.NotFound(w, r)
		return
	}

	title := post.Title
	description := post.Description
	contentMD := post.ContentMD
	tags := post.Tags
	coverImage := post.CoverImage
	if isPreview {
		// Previews show unsaved draft content when a draft exists.
		if draft, _ := q.GetDraft(ctx, post.ID); draft != nil {
			title = draft.Title
			description = draft.Description
			contentMD = draft.ContentMD
			tags = draft.Tags
			coverImage = draft.CoverImage
		}
	}

	// A settings failure is tolerated; the fallbacks below apply instead.
	settings, _ := q.GetSettings(ctx)
	siteName := getSettingOr(settings, "site_name", "My Blog")
	baseURL := getBaseURL(r.Host)
	codeTheme := getSettingOr(settings, "code_theme", "github")

	contentHTML := ""
	if contentMD != "" {
		// A render failure leaves the body empty rather than failing the page.
		contentHTML, _ = markdown.RenderWithTheme(contentMD, codeTheme)
	}

	data := templates.PostData{
		PageData: templates.PageData{
			Title:          title + " - " + siteName,
			Description:    description,
			CanonicalURL:   baseURL + "/posts/" + post.Slug,
			OGType:         "article",
			OGImage:        coverImage,
			SiteName:       siteName,
			Year:           time.Now().Year(),
			StructuredData: template.JS(buildArticleSchema(post, siteName, baseURL)),
			Settings:       settingsToMap(settings),
			NoIndex:        GetDemoInfo(r).IsDemo || isPreview,
			ShowBadge:      s.showBadgeFor(ctx, tenantID),
		},
		Post: templates.PostDetail{
			Slug:        post.Slug,
			Title:       title,
			Description: description,
			CoverImage:  coverImage,
			Date:        timeOrZero(post.PublishedAt),
			Tags:        tags,
		},
		ContentHTML:       template.HTML(contentHTML),
		InteractionConfig: q.GetInteractionConfig(ctx),
	}

	html, err := templates.RenderPost(data)
	if err != nil {
		slog.Error("blogPost: render template", "error", err, "slug", slug)
		http.Error(w, "render error", http.StatusInternalServerError)
		return
	}

	if isPreview {
		// NOTE(review): this banner contains text only; its markup appears to
		// be missing and should be confirmed against the intended design.
		previewScript := `
Preview Mode Viewing as author Back to Editor
Rebuilding...
`
		// Insert the banner just before the closing body tag. Replacing an
		// empty needle would instead prepend it ahead of the doctype.
		html = bytes.Replace(html, []byte("</body>"), []byte(previewScript+"</body>"), 1)
		w.Header().Set("Cache-Control", "no-store")
		w.Header().Set("Content-Type", blogHTMLContentType)
		w.Write(html)
		return
	}

	s.servePreRendered(w, r, html, computeETag(html), blogCacheControl)
}

// canPreview reports whether the request may view unpublished content of the
// tenant: always for demos, otherwise only for an authenticated user who owns
// the tenant. Ownership lookup errors count as "not allowed".
func (s *Server) canPreview(r *http.Request, tenantID string) bool {
	if GetDemoInfo(r).IsDemo {
		return true
	}
	userID := auth.GetUserID(r)
	if userID == "" {
		return false
	}
	isOwner, err := s.database.IsUserTenantOwner(r.Context(), userID, tenantID)
	return err == nil && isOwner
}

// serveStudio serves the studio single-page application.
//
// When ENV is "local" and VITE_URL is set, requests are reverse-proxied to that
// URL. Otherwise the requested file is read from the studio package; any path
// that cannot be read falls back to index.html. HTML responses are sent with
// "no-cache" (and get the demo banner for demos), everything else is cached as
// immutable.
func (s *Server) serveStudio(w http.ResponseWriter, r *http.Request) {
	if viteURL := os.Getenv("VITE_URL"); viteURL != "" && os.Getenv("ENV") == "local" {
		target, err := url.Parse(viteURL)
		if err != nil {
			slog.Error("invalid VITE_URL", "error", err)
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		proxy := httputil.NewSingleHostReverseProxy(target)
		proxy.Director = func(req *http.Request) {
			req.URL.Scheme = target.Scheme
			req.URL.Host = target.Host
			req.Host = target.Host
		}
		proxy.ServeHTTP(w, r)
		return
	}

	const indexFile = "index.html"

	name := chi.URLParam(r, "*")
	if name == "" {
		name = indexFile
	}
	data, err := studio.Read(name)
	if err != nil {
		// Serve the app shell instead. Resetting name ensures the fallback is
		// labelled and cached as HTML rather than as the asset that was asked
		// for.
		name = indexFile
		data, err = studio.Read(name)
		if err != nil {
			slog.Error("serveStudio: read index.html", "error", err)
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
	}

	contentType := studioContentType(name, data)
	isHTML := contentType == blogHTMLContentType
	if isHTML {
		if demoInfo := GetDemoInfo(r); demoInfo.IsDemo {
			data = s.injectDemoBanner(data, demoInfo.ExpiresAt)
		}
	}

	w.Header().Set("Content-Type", contentType)
	if isHTML {
		w.Header().Set("Cache-Control", "no-cache")
	} else {
		w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
	}
	w.Write(data)
}

// studioContentType picks the Content-Type for a studio file from its
// extension: HTML for files without an extension, fixed values for .js and
// .css, then the mime package's table, and finally content sniffing.
func studioContentType(name string, data []byte) string {
	switch ext := strings.ToLower(path.Ext(name)); ext {
	case "", ".html", ".htm":
		return blogHTMLContentType
	case ".js":
		return "application/javascript"
	case ".css":
		return "text/css"
	default:
		if ct := mime.TypeByExtension(ext); ct != "" {
			return ct
		}
		return http.DetectContentType(data)
	}
}

// sitemap writes an XML sitemap containing the home page and one entry per
// listed post, with lastmod taken from UpdatedAt when set and ModifiedAt
// otherwise.
func (s *Server) sitemap(w http.ResponseWriter, r *http.Request) {
	q, _, ok := s.tenantQueriesFor(w, r, "sitemap")
	if !ok {
		return
	}
	posts, err := q.ListPosts(r.Context(), false)
	if err != nil {
		slog.Error("sitemap: list posts", "error", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	baseURL := getBaseURL(r.Host)

	w.Header().Set("Content-Type", "application/xml; charset=utf-8")
	w.Header().Set("Cache-Control", blogCacheControl)

	// The base URL comes from the Host header and slugs from stored data, so
	// both are escaped before being placed inside XML elements.
	fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?>
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
  <url><loc>%s</loc></url>
`, template.HTMLEscapeString(baseURL+"/"))

	for _, p := range posts {
		lastmod := p.ModifiedAt.Format("2006-01-02")
		if p.UpdatedAt != nil {
			lastmod = p.UpdatedAt.Format("2006-01-02")
		}
		loc := fmt.Sprintf("%s/posts/%s", baseURL, p.Slug)
		fmt.Fprintf(w, "  <url><loc>%s</loc><lastmod>%s</lastmod></url>\n",
			template.HTMLEscapeString(loc), lastmod)
	}
	fmt.Fprint(w, "</urlset>")
}

// robots writes robots.txt: everything is disallowed for demos, otherwise
// everything is allowed and the sitemap URL is advertised.
func (s *Server) robots(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.Header().Set("Cache-Control", "public, max-age=86400")

	if GetDemoInfo(r).IsDemo {
		w.Write([]byte("User-agent: *\nDisallow: /\n"))
		return
	}
	fmt.Fprintf(w, "User-agent: *\nAllow: /\n\nSitemap: %s/sitemap.xml\n", getBaseURL(r.Host))
}

// ownerMiddleware lets a request through only when it targets a demo or when
// the authenticated user owns the tenant found in the request context.
//
// It answers 401 when there is no user or no tenant ID, 500 when the ownership
// lookup fails and 403 when the user is not the owner.
func (s *Server) ownerMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if GetDemoInfo(r).IsDemo {
			next.ServeHTTP(w, r)
			return
		}

		userID := auth.GetUserID(r)
		if userID == "" {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		tenantID, ok := r.Context().Value(tenantIDKey).(string)
		if !ok || tenantID == "" {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}

		isOwner, err := s.database.IsUserTenantOwner(r.Context(), userID, tenantID)
		switch {
		case err != nil:
			http.Error(w, "internal error", http.StatusInternalServerError)
		case !isOwner:
			http.Error(w, "forbidden", http.StatusForbidden)
		default:
			next.ServeHTTP(w, r)
		}
	})
}

// getSettingOr returns settings[key], or fallback when the key is absent or
// its value is empty.
func getSettingOr(settings tenant.Settings, key, fallback string) string {
	if v, ok := settings[key]; ok && v != "" {
		return v
	}
	return fallback
}

// settingsToMap copies settings into a map[string]any. The result is never
// nil, even for nil settings.
func settingsToMap(settings tenant.Settings) map[string]any {
	m := make(map[string]any, len(settings))
	for k, v := range settings {
		m[k] = v
	}
	return m
}

// getBaseURL returns scheme://host, using https only when the ENV environment
// variable is "prod" and http otherwise.
func getBaseURL(host string) string {
	scheme := "http"
	if os.Getenv("ENV") == "prod" {
		scheme = "https"
	}
	return scheme + "://" + host
}

// computeETag returns a quoted strong ETag: the hex MD5 digest of data.
// MD5 serves as a content fingerprint here, not as a security measure.
func computeETag(data []byte) string {
	sum := md5.Sum(data)
	return `"` + hex.EncodeToString(sum[:]) + `"`
}

// servePreRendered writes an HTML page with the given ETag and Cache-Control,
// or an empty 304 when the request's If-None-Match already lists that ETag.
//
// For demos the demo banner is injected first and the ETag is recomputed from
// the modified body.
func (s *Server) servePreRendered(w http.ResponseWriter, r *http.Request, html []byte, etag, cacheControl string) {
	if demoInfo := GetDemoInfo(r); demoInfo.IsDemo {
		html = s.injectDemoBanner(html, demoInfo.ExpiresAt)
		etag = computeETag(html)
	}

	// Validators are set before the conditional check so that a 304 carries
	// the same ETag and Cache-Control a 200 would.
	h := w.Header()
	h.Set("Cache-Control", cacheControl)
	if etag != "" {
		h.Set("ETag", etag)
	}

	if ifNoneMatchSatisfied(r.Header.Get("If-None-Match"), etag) {
		w.WriteHeader(http.StatusNotModified)
		return
	}

	h.Set("Content-Type", blogHTMLContentType)
	w.Write(html)
}

// ifNoneMatchSatisfied reports whether an If-None-Match header value matches
// etag. It accepts "*", comma-separated lists and weak ("W/") validators, and
// never matches when either the header or etag is empty.
func ifNoneMatchSatisfied(header, etag string) bool {
	if header == "" || etag == "" {
		return false
	}
	if header == etag {
		return true
	}
	want := strings.TrimPrefix(etag, "W/")
	for _, candidate := range strings.Split(header, ",") {
		candidate = strings.TrimSpace(candidate)
		if candidate == "*" || strings.TrimPrefix(candidate, "W/") == want {
			return true
		}
	}
	return false
}

// buildArticleSchema returns the schema.org Article description of post as a
// JSON string. dateModified is UpdatedAt when set and the publication date
// otherwise; the site name doubles as author and publisher name.
func buildArticleSchema(post *tenant.Post, siteName, baseURL string) string {
	publishedAt := timeOrZero(post.PublishedAt)
	modifiedAt := publishedAt
	if post.UpdatedAt != nil {
		modifiedAt = *post.UpdatedAt
	}

	schema := map[string]any{
		"@context":      "https://schema.org",
		"@type":         "Article",
		"headline":      post.Title,
		"datePublished": publishedAt.Format(time.RFC3339),
		"dateModified":  modifiedAt.Format(time.RFC3339),
		"author": map[string]any{
			"@type": "Person",
			"name":  siteName,
		},
		"publisher": map[string]any{
			"@type": "Organization",
			"name":  siteName,
		},
		"mainEntityOfPage": map[string]any{
			"@type": "WebPage",
			"@id":   baseURL + "/posts/" + post.Slug,
		},
	}
	if post.Description != "" {
		schema["description"] = post.Description
	}

	// The map holds only strings and nested string maps, which always marshal.
	b, _ := json.Marshal(schema)
	return string(b)
}

// recordPageView records a page view in a separate goroutine so the response
// is not delayed. The outcome of the write is not observed.
func (s *Server) recordPageView(q *tenant.Queries, r *http.Request, path, postSlug string) {
	// Header values are read up front; the goroutine may outlive the request.
	referrer := r.Header.Get("Referer")
	userAgent := r.Header.Get("User-Agent")
	go func() {
		// context.Background is used because the request context is cancelled
		// once the handler returns.
		q.RecordPageView(context.Background(), path, postSlug, referrer, userAgent)
	}()
}

// timeOrZero dereferences t, returning the zero time.Time for nil.
func timeOrZero(t *time.Time) time.Time {
	if t == nil {
		return time.Time{}
	}
	return *t
}