package server import ( "context" "database/sql" "log/slog" "net/http" "os" "strings" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/writekitapp/writekit/internal/auth" "github.com/writekitapp/writekit/internal/cloudflare" "github.com/writekitapp/writekit/internal/db" "github.com/writekitapp/writekit/internal/imaginary" "github.com/writekitapp/writekit/internal/storage" "github.com/writekitapp/writekit/internal/tenant" ) type Server struct { router chi.Router database *db.DB tenantPool *tenant.Pool tenantCache *tenant.Cache storage storage.Client imaginary *imaginary.Client cloudflare *cloudflare.Client rateLimiter *RateLimiter domain string jarvisURL string stopCleanup chan struct{} } func New(database *db.DB, pool *tenant.Pool, cache *tenant.Cache, storageClient storage.Client) *Server { domain := os.Getenv("DOMAIN") if domain == "" { domain = "writekit.dev" } jarvisURL := os.Getenv("JARVIS_URL") if jarvisURL == "" { jarvisURL = "http://localhost:8090" } var imgClient *imaginary.Client if url := os.Getenv("IMAGINARY_URL"); url != "" { imgClient = imaginary.New(url) } cfClient := cloudflare.NewClient() s := &Server{ router: chi.NewRouter(), database: database, tenantPool: pool, tenantCache: cache, storage: storageClient, imaginary: imgClient, cloudflare: cfClient, rateLimiter: NewRateLimiter(), domain: domain, jarvisURL: jarvisURL, stopCleanup: make(chan struct{}), } s.router.Use(middleware.Logger) s.router.Use(middleware.Recoverer) s.router.Use(middleware.Compress(5)) s.routes() go s.cleanupDemos() return s } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.router.ServeHTTP(w, r) } func (s *Server) routes() { s.router.HandleFunc("/*", s.route) } func (s *Server) route(w http.ResponseWriter, r *http.Request) { host := r.Host if idx := strings.Index(host, ":"); idx != -1 { host = host[:idx] } if host == s.domain || host == "www."+s.domain { s.servePlatform(w, r) return } if strings.HasSuffix(host, "."+s.domain) { subdomain := strings.TrimSuffix(host, "."+s.domain) s.serveBlog(w, r, subdomain) return } s.notFound(w, r) } func (s *Server) servePlatform(w http.ResponseWriter, r *http.Request) { mux := chi.NewRouter() mux.NotFound(s.notFound) mux.Get("/", s.platformHome) mux.Get("/login", s.platformLogin) mux.Get("/signup", s.platformSignup) mux.Get("/signup/complete", s.platformSignup) mux.Get("/dashboard", s.platformDashboard) mux.Handle("/assets/*", http.HandlerFunc(s.serveStaticAssets)) mux.Mount("/auth", auth.NewHandler(s.database).Routes()) mux.Route("/api", func(r chi.Router) { r.Get("/tenant/check", s.checkSubdomain) r.Post("/demo", s.createDemo) r.Group(func(r chi.Router) { r.Use(auth.SessionMiddleware(s.database)) r.Post("/tenant", s.createTenant) r.Get("/tenant", s.getTenant) }) }) mux.ServeHTTP(w, r) } func (s *Server) cleanupDemos() { ticker := time.NewTicker(15 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: expired, err := s.database.CleanupExpiredDemos(context.Background()) if err != nil { slog.Error("cleanup expired demos", "error", err) continue } for _, d := range expired { s.tenantPool.Evict(d.ID) s.tenantCache.Delete(d.Subdomain) slog.Info("cleaned up expired demo", "demo_id", d.ID, "subdomain", d.Subdomain) } case <-s.stopCleanup: return } } } func (s *Server) Close() { close(s.stopCleanup) } // getPluginRunner returns a PluginRunner for the given tenant func (s *Server) getPluginRunner(tenantID string, db *sql.DB) *tenant.PluginRunner { return tenant.NewPluginRunner(db, tenantID) }