127 lines
2.4 KiB
Go
127 lines
2.4 KiB
Go
|
|
package server
|
||
|
|
|
||
|
|
import (
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/writekitapp/writekit/internal/config"
)
|
||
|
|
|
||
|
|
// bucket is the token-bucket state kept for one tenant by RateLimiter.
type bucket struct {
	// tokens is how many requests may currently be admitted. It is fractional
	// because refills accrue continuously between calls.
	tokens float64
	// lastFill is the moment tokens was last topped up.
	lastFill time.Time
	// rateLimit is the limit the bucket was last used with. A different limit
	// on a later call causes the bucket to be reset.
	rateLimit int
}
|
||
|
|
|
||
|
|
type RateLimiter struct {
|
||
|
|
mu sync.RWMutex
|
||
|
|
buckets map[string]*bucket
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewRateLimiter() *RateLimiter {
|
||
|
|
rl := &RateLimiter{
|
||
|
|
buckets: make(map[string]*bucket),
|
||
|
|
}
|
||
|
|
go rl.cleanup()
|
||
|
|
return rl
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rl *RateLimiter) Allow(tenantID string, limit int) bool {
|
||
|
|
rl.mu.Lock()
|
||
|
|
defer rl.mu.Unlock()
|
||
|
|
|
||
|
|
now := time.Now()
|
||
|
|
b, ok := rl.buckets[tenantID]
|
||
|
|
if !ok {
|
||
|
|
b = &bucket{
|
||
|
|
tokens: float64(limit),
|
||
|
|
lastFill: now,
|
||
|
|
rateLimit: limit,
|
||
|
|
}
|
||
|
|
rl.buckets[tenantID] = b
|
||
|
|
}
|
||
|
|
|
||
|
|
if b.rateLimit != limit {
|
||
|
|
b.rateLimit = limit
|
||
|
|
b.tokens = float64(limit)
|
||
|
|
}
|
||
|
|
|
||
|
|
elapsed := now.Sub(b.lastFill)
|
||
|
|
tokensToAdd := elapsed.Hours() * float64(limit)
|
||
|
|
b.tokens = min(b.tokens+tokensToAdd, float64(limit))
|
||
|
|
b.lastFill = now
|
||
|
|
|
||
|
|
if b.tokens >= 1 {
|
||
|
|
b.tokens--
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
func (rl *RateLimiter) cleanup() {
|
||
|
|
ticker := time.NewTicker(5 * time.Minute)
|
||
|
|
for range ticker.C {
|
||
|
|
rl.mu.Lock()
|
||
|
|
threshold := time.Now().Add(-1 * time.Hour)
|
||
|
|
for k, b := range rl.buckets {
|
||
|
|
if b.lastFill.Before(threshold) {
|
||
|
|
delete(rl.buckets, k)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
rl.mu.Unlock()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *Server) apiRateLimitMiddleware(rl *RateLimiter) func(http.Handler) http.Handler {
|
||
|
|
return func(next http.Handler) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
tenantID, ok := r.Context().Value(tenantIDKey).(string)
|
||
|
|
if !ok || tenantID == "" {
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
t, err := s.database.GetTenantByID(r.Context(), tenantID)
|
||
|
|
if err != nil {
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
premium := t != nil && t.Premium
|
||
|
|
tierInfo := config.GetTierInfo(premium)
|
||
|
|
limit := tierInfo.Config.APIRateLimit
|
||
|
|
|
||
|
|
if !rl.Allow(tenantID, limit) {
|
||
|
|
w.Header().Set("X-RateLimit-Limit", itoa(limit))
|
||
|
|
w.Header().Set("X-RateLimit-Reset", "3600")
|
||
|
|
w.Header().Set("Retry-After", "60")
|
||
|
|
jsonError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
w.Header().Set("X-RateLimit-Limit", itoa(limit))
|
||
|
|
next.ServeHTTP(w, r)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// itoa formats n as a base-10 string, with a leading '-' for negative values.
func itoa(n int) string {
	return strconv.Itoa(n)
}
|
||
|
|
|
||
|
|
// min returns a when a < b and b otherwise. This means ties, and any
// comparison involving NaN, yield b.
func min(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
|