writekit/internal/server/ratelimit.go

127 lines
2.4 KiB
Go
Raw Normal View History

2026-01-09 00:16:46 +02:00
package server
import (
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/writekitapp/writekit/internal/config"
)
// bucket is the per-tenant token-bucket state tracked by RateLimiter.
type bucket struct {
	tokens    float64   // requests currently available; capped at rateLimit
	lastFill  time.Time // time of the last refill (updated on every Allow call)
	rateLimit int       // limit the bucket was sized for; a different limit resets it
}

// RateLimiter is an in-memory, per-tenant token-bucket limiter. A tenant with
// a budget of limit requests per hour may burst up to limit requests, and
// tokens refill continuously at limit per hour. All methods are safe for
// concurrent use.
//
// Construct it with NewRateLimiter, which starts a background goroutine that
// evicts idle buckets. Call Stop to end that goroutine when the limiter is no
// longer needed.
type RateLimiter struct {
	mu      sync.Mutex // only exclusive locking is needed; no read-only paths exist
	buckets map[string]*bucket

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once
}

// NewRateLimiter returns an empty RateLimiter. It also starts a goroutine that
// drops buckets idle for more than an hour, checking every five minutes, until
// Stop is called.
func NewRateLimiter() *RateLimiter {
	rl := &RateLimiter{
		buckets: make(map[string]*bucket),
		stop:    make(chan struct{}),
	}
	go rl.cleanup()
	return rl
}

// Stop terminates the background cleanup goroutine started by NewRateLimiter.
// It is safe to call more than once, and on a limiter not created by
// NewRateLimiter. Allow keeps working after Stop, but idle buckets are no
// longer evicted.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.stop != nil {
			close(rl.stop)
		}
	})
}

// Allow reports whether tenantID may make another request under a budget of
// limit requests per hour. If so, it consumes one token.
//
// A tenant's first request creates a full bucket. If limit differs from the
// value the bucket was last used with (for example after a tier change), the
// bucket is reset to full at the new limit. A limit <= 0 rejects every
// request.
func (rl *RateLimiter) Allow(tenantID string, limit int) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	capacity := float64(limit)

	if rl.buckets == nil {
		// Makes a zero-value RateLimiter usable instead of panicking on a nil map.
		rl.buckets = make(map[string]*bucket)
	}
	b, ok := rl.buckets[tenantID]
	if !ok {
		b = &bucket{tokens: capacity, lastFill: now, rateLimit: limit}
		rl.buckets[tenantID] = b
	}
	if b.rateLimit != limit {
		b.rateLimit = limit
		b.tokens = capacity
	}

	// Continuous refill: limit tokens per elapsed hour, never above capacity.
	b.tokens += now.Sub(b.lastFill).Hours() * capacity
	if b.tokens > capacity {
		b.tokens = capacity
	}
	b.lastFill = now

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// cleanup periodically evicts idle buckets until Stop is called. If the
// limiter was not built by NewRateLimiter (stop is nil), it never returns.
func (rl *RateLimiter) cleanup() {
	const (
		sweepInterval = 5 * time.Minute
		idleTTL       = time.Hour
	)
	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()
	for {
		select {
		case <-rl.stop:
			return
		case <-ticker.C:
			rl.evictIdle(time.Now().Add(-idleTTL))
		}
	}
}

// evictIdle deletes every bucket whose last refill is strictly before
// threshold.
//
// With a threshold an hour in the past this loses nothing. The refill rate is
// limit tokens per hour, so such a bucket would already be back at full
// capacity, which is exactly what Allow creates for an unknown tenant.
func (rl *RateLimiter) evictIdle(threshold time.Time) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	for id, b := range rl.buckets {
		if b.lastFill.Before(threshold) {
			delete(rl.buckets, id)
		}
	}
}
// apiRateLimitMiddleware wraps an API handler with the per-tenant request limit
// of the tenant's plan tier, using rl to track budgets.
//
// The request goes straight to next, with no limit check, in two cases:
//   - the request context holds no non-empty string under tenantIDKey;
//   - s.database.GetTenantByID returns an error (the middleware fails open).
//
// Otherwise the budget is config.GetTierInfo(premium).Config.APIRateLimit,
// where premium is the tenant's Premium flag (false for a nil tenant).
// Allowed requests get an X-RateLimit-Limit header before next runs. Rejected
// requests get X-RateLimit-Limit, X-RateLimit-Reset and Retry-After headers
// and a 429 JSON error.
//
// NOTE(review): X-RateLimit-Reset ("3600") and Retry-After ("60") are fixed
// strings, not derived from the tenant's bucket state.
func (s *Server) apiRateLimitMiddleware(rl *RateLimiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		limitRequest := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// A failed type assertion leaves tenantID empty, so one check covers both.
			tenantID, _ := ctx.Value(tenantIDKey).(string)
			if tenantID == "" {
				next.ServeHTTP(w, r)
				return
			}

			tenant, err := s.database.GetTenantByID(ctx, tenantID)
			if err != nil {
				// Fail open: on lookup error the request proceeds unlimited.
				next.ServeHTTP(w, r)
				return
			}

			isPremium := tenant != nil && tenant.Premium
			limit := config.GetTierInfo(isPremium).Config.APIRateLimit
			allowed := rl.Allow(tenantID, limit)

			headers := w.Header()
			headers.Set("X-RateLimit-Limit", itoa(limit))
			if allowed {
				next.ServeHTTP(w, r)
				return
			}
			headers.Set("X-RateLimit-Reset", "3600")
			headers.Set("Retry-After", "60")
			jsonError(w, http.StatusTooManyRequests, "rate limit exceeded")
		}
		return http.HandlerFunc(limitRequest)
	}
}
// itoa returns the base-10 representation of n, with a leading '-' for
// negative values. It is used for the X-RateLimit-Limit header value.
func itoa(n int) string {
	// strconv.Itoa handles every int, including negatives and math.MinInt;
	// a hand-rolled digit loop over n > 0 would yield "" for any n < 0.
	return strconv.Itoa(n)
}
// min returns a when a < b and b otherwise. As a result, equal values and any
// NaN operand both yield b.
//
// NOTE(review): this package-level func shadows the Go 1.21 built-in min
// throughout package server. The built-in returns NaN if either argument is
// NaN, so switching to it would change NaN handling.
func min(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}