163 lines
3.7 KiB
Go
163 lines
3.7 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/writekitapp/writekit/internal/auth"
|
|
"github.com/writekitapp/writekit/internal/cloudflare"
|
|
"github.com/writekitapp/writekit/internal/db"
|
|
"github.com/writekitapp/writekit/internal/imaginary"
|
|
"github.com/writekitapp/writekit/internal/storage"
|
|
"github.com/writekitapp/writekit/internal/tenant"
|
|
)
|
|
|
|
type Server struct {
|
|
router chi.Router
|
|
database *db.DB
|
|
tenantPool *tenant.Pool
|
|
tenantCache *tenant.Cache
|
|
storage storage.Client
|
|
imaginary *imaginary.Client
|
|
cloudflare *cloudflare.Client
|
|
rateLimiter *RateLimiter
|
|
domain string
|
|
jarvisURL string
|
|
stopCleanup chan struct{}
|
|
}
|
|
|
|
func New(database *db.DB, pool *tenant.Pool, cache *tenant.Cache, storageClient storage.Client) *Server {
|
|
domain := os.Getenv("DOMAIN")
|
|
if domain == "" {
|
|
domain = "writekit.dev"
|
|
}
|
|
|
|
jarvisURL := os.Getenv("JARVIS_URL")
|
|
if jarvisURL == "" {
|
|
jarvisURL = "http://localhost:8090"
|
|
}
|
|
|
|
var imgClient *imaginary.Client
|
|
if url := os.Getenv("IMAGINARY_URL"); url != "" {
|
|
imgClient = imaginary.New(url)
|
|
}
|
|
|
|
cfClient := cloudflare.NewClient()
|
|
|
|
s := &Server{
|
|
router: chi.NewRouter(),
|
|
database: database,
|
|
tenantPool: pool,
|
|
tenantCache: cache,
|
|
storage: storageClient,
|
|
imaginary: imgClient,
|
|
cloudflare: cfClient,
|
|
rateLimiter: NewRateLimiter(),
|
|
domain: domain,
|
|
jarvisURL: jarvisURL,
|
|
stopCleanup: make(chan struct{}),
|
|
}
|
|
|
|
s.router.Use(middleware.Logger)
|
|
s.router.Use(middleware.Recoverer)
|
|
s.router.Use(middleware.Compress(5))
|
|
|
|
s.routes()
|
|
go s.cleanupDemos()
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
s.router.ServeHTTP(w, r)
|
|
}
|
|
|
|
func (s *Server) routes() {
|
|
s.router.HandleFunc("/*", s.route)
|
|
}
|
|
|
|
func (s *Server) route(w http.ResponseWriter, r *http.Request) {
|
|
host := r.Host
|
|
if idx := strings.Index(host, ":"); idx != -1 {
|
|
host = host[:idx]
|
|
}
|
|
|
|
if host == s.domain || host == "www."+s.domain {
|
|
s.servePlatform(w, r)
|
|
return
|
|
}
|
|
|
|
if strings.HasSuffix(host, "."+s.domain) {
|
|
subdomain := strings.TrimSuffix(host, "."+s.domain)
|
|
s.serveBlog(w, r, subdomain)
|
|
return
|
|
}
|
|
|
|
s.notFound(w, r)
|
|
}
|
|
|
|
func (s *Server) servePlatform(w http.ResponseWriter, r *http.Request) {
|
|
mux := chi.NewRouter()
|
|
mux.NotFound(s.notFound)
|
|
|
|
mux.Get("/", s.platformHome)
|
|
mux.Get("/login", s.platformLogin)
|
|
mux.Get("/signup", s.platformSignup)
|
|
mux.Get("/signup/complete", s.platformSignup)
|
|
mux.Get("/dashboard", s.platformDashboard)
|
|
mux.Handle("/assets/*", http.HandlerFunc(s.serveStaticAssets))
|
|
|
|
mux.Mount("/auth", auth.NewHandler(s.database).Routes())
|
|
|
|
mux.Route("/api", func(r chi.Router) {
|
|
r.Get("/tenant/check", s.checkSubdomain)
|
|
r.Post("/demo", s.createDemo)
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(auth.SessionMiddleware(s.database))
|
|
r.Post("/tenant", s.createTenant)
|
|
r.Get("/tenant", s.getTenant)
|
|
})
|
|
})
|
|
|
|
mux.ServeHTTP(w, r)
|
|
}
|
|
|
|
func (s *Server) cleanupDemos() {
|
|
ticker := time.NewTicker(15 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
expired, err := s.database.CleanupExpiredDemos(context.Background())
|
|
if err != nil {
|
|
slog.Error("cleanup expired demos", "error", err)
|
|
continue
|
|
}
|
|
for _, d := range expired {
|
|
s.tenantPool.Evict(d.ID)
|
|
s.tenantCache.Delete(d.Subdomain)
|
|
slog.Info("cleaned up expired demo", "demo_id", d.ID, "subdomain", d.Subdomain)
|
|
}
|
|
case <-s.stopCleanup:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) Close() {
|
|
close(s.stopCleanup)
|
|
}
|
|
|
|
// getPluginRunner returns a PluginRunner for the given tenant
|
|
func (s *Server) getPluginRunner(tenantID string, db *sql.DB) *tenant.PluginRunner {
|
|
return tenant.NewPluginRunner(db, tenantID)
|
|
}
|
|
|