package server

import (
	"net/http"
	"strconv"
	"sync"
	"time"

	"writekit/internal/config"
)

const (
	// rateLimitWindow is the time it takes an empty bucket to refill to its
	// full limit, so the limit passed to Allow is requests per hour.
	rateLimitWindow = time.Hour

	// rateLimitCleanupInterval is how often idle buckets are swept.
	rateLimitCleanupInterval = 5 * time.Minute
)

// bucket is the token-bucket state for a single tenant. tokens refill
// continuously at rateLimit tokens per rateLimitWindow and are capped at
// rateLimit.
type bucket struct {
	tokens    float64
	lastFill  time.Time
	rateLimit int
}

// RateLimiter is an in-memory, per-tenant token-bucket rate limiter. All
// bucket access is serialized by mu. A background goroutine started by
// NewRateLimiter evicts idle buckets until Stop is called.
type RateLimiter struct {
	mu      sync.Mutex
	buckets map[string]*bucket

	stopOnce sync.Once
	stop     chan struct{}
}

// NewRateLimiter returns an empty RateLimiter and starts its background
// cleanup goroutine. Call Stop to terminate that goroutine.
func NewRateLimiter() *RateLimiter {
	rl := &RateLimiter{
		buckets: make(map[string]*bucket),
		stop:    make(chan struct{}),
	}
	go rl.cleanup()
	return rl
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once. The limiter keeps working after Stop, but idle buckets are no
// longer evicted.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.stop != nil {
			close(rl.stop)
		}
	})
}

// Allow reports whether tenantID may make one more request under a budget of
// limit requests per rateLimitWindow, consuming a token if so. A limit <= 0
// never allows a request.
func (rl *RateLimiter) Allow(tenantID string, limit int) bool {
	allowed, _ := rl.reserve(tenantID, limit)
	return allowed
}

// reserve tries to take one token from tenantID's bucket. It returns true
// and a zero duration if a token was taken. Otherwise it returns false and
// the time until the next token becomes available. The wait is one full
// window when limit <= 0, because no token will ever be granted.
func (rl *RateLimiter) reserve(tenantID string, limit int) (bool, time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.buckets[tenantID]
	if !ok {
		b = &bucket{
			tokens:    float64(limit),
			lastFill:  now,
			rateLimit: limit,
		}
		rl.buckets[tenantID] = b
	}
	// A changed limit (e.g. a tier change) restarts the tenant with a full
	// bucket at the new limit.
	if b.rateLimit != limit {
		b.rateLimit = limit
		b.tokens = float64(limit)
	}

	capacity := float64(limit)
	refill := float64(now.Sub(b.lastFill)) / float64(rateLimitWindow) * capacity
	b.tokens = min(b.tokens+refill, capacity)
	b.lastFill = now

	if b.tokens >= 1 {
		b.tokens--
		return true, 0
	}
	if limit <= 0 {
		return false, rateLimitWindow
	}
	// Time for the fractional deficit to refill at capacity tokens per window.
	deficit := 1 - b.tokens
	return false, time.Duration(deficit / capacity * float64(rateLimitWindow))
}

// cleanup periodically evicts buckets that have been idle for at least one
// window, until Stop is called. Eviction does not change limiting behavior
// for a positive limit: after a full idle window the bucket has refilled to
// capacity, which is exactly the state a freshly created bucket starts in.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(rateLimitCleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-rl.stop:
			return
		case <-ticker.C:
			rl.evictIdle(time.Now().Add(-rateLimitWindow))
		}
	}
}

// evictIdle removes every bucket whose last refill happened before threshold.
func (rl *RateLimiter) evictIdle(threshold time.Time) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	for k, b := range rl.buckets {
		if b.lastFill.Before(threshold) {
			delete(rl.buckets, k)
		}
	}
}

// apiRateLimitMiddleware enforces the per-tenant API rate limit for the
// tenant's tier.
//
// Requests with no tenant ID in the context (tenantIDKey) pass through
// unlimited. If the tenant lookup fails, the request also passes through,
// so the middleware fails open rather than rejecting traffic on a database
// error. Rejected requests get a 429 JSON error and a Retry-After header,
// rounded up to whole seconds, for when the next token will be available.
func (s *Server) apiRateLimitMiddleware(rl *RateLimiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tenantID, ok := r.Context().Value(tenantIDKey).(string)
			if !ok || tenantID == "" {
				next.ServeHTTP(w, r)
				return
			}

			t, err := s.database.GetTenantByID(r.Context(), tenantID)
			if err != nil {
				// Deliberate fail-open: do not block the API on a lookup error.
				next.ServeHTTP(w, r)
				return
			}

			// A missing tenant (nil) is treated as the non-premium tier.
			premium := t != nil && t.Premium
			limit := config.GetTierInfo(premium).Config.APIRateLimit
			w.Header().Set("X-RateLimit-Limit", itoa(limit))

			allowed, wait := rl.reserve(tenantID, limit)
			if !allowed {
				w.Header().Set("X-RateLimit-Reset", itoa(int(rateLimitWindow/time.Second)))
				w.Header().Set("Retry-After", retryAfterSeconds(wait))
				jsonError(w, http.StatusTooManyRequests, "rate limit exceeded")
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// retryAfterSeconds formats d as a whole number of seconds for the
// Retry-After header. It rounds up so clients never retry early, and the
// result is at least 1.
func retryAfterSeconds(d time.Duration) string {
	secs := (d + time.Second - 1) / time.Second
	if secs < 1 {
		secs = 1
	}
	return strconv.FormatInt(int64(secs), 10)
}

// itoa formats n in base 10, including negative values. It is a thin
// wrapper over strconv.Itoa, kept for existing callers in this package.
func itoa(n int) string {
	return strconv.Itoa(n)
}

// min returns the smaller of a and b.
//
// NOTE(review): this package-level function shadows the Go 1.21 builtin min
// throughout package server. Consider removing it once the module targets
// Go 1.21+.
func min(a, b float64) float64 {
	if a < b {
		return a
	}
	return b
}