init
This commit is contained in:
commit
d69342b2e9
160 changed files with 28681 additions and 0 deletions
126
internal/server/ratelimit.go
Normal file
126
internal/server/ratelimit.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/writekitapp/writekit/internal/config"
|
||||
)
|
||||
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
lastFill time.Time
|
||||
rateLimit int
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*bucket
|
||||
}
|
||||
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
buckets: make(map[string]*bucket),
|
||||
}
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Allow(tenantID string, limit int) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
b, ok := rl.buckets[tenantID]
|
||||
if !ok {
|
||||
b = &bucket{
|
||||
tokens: float64(limit),
|
||||
lastFill: now,
|
||||
rateLimit: limit,
|
||||
}
|
||||
rl.buckets[tenantID] = b
|
||||
}
|
||||
|
||||
if b.rateLimit != limit {
|
||||
b.rateLimit = limit
|
||||
b.tokens = float64(limit)
|
||||
}
|
||||
|
||||
elapsed := now.Sub(b.lastFill)
|
||||
tokensToAdd := elapsed.Hours() * float64(limit)
|
||||
b.tokens = min(b.tokens+tokensToAdd, float64(limit))
|
||||
b.lastFill = now
|
||||
|
||||
if b.tokens >= 1 {
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
threshold := time.Now().Add(-1 * time.Hour)
|
||||
for k, b := range rl.buckets {
|
||||
if b.lastFill.Before(threshold) {
|
||||
delete(rl.buckets, k)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) apiRateLimitMiddleware(rl *RateLimiter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, ok := r.Context().Value(tenantIDKey).(string)
|
||||
if !ok || tenantID == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
t, err := s.database.GetTenantByID(r.Context(), tenantID)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
premium := t != nil && t.Premium
|
||||
tierInfo := config.GetTierInfo(premium)
|
||||
limit := tierInfo.Config.APIRateLimit
|
||||
|
||||
if !rl.Allow(tenantID, limit) {
|
||||
w.Header().Set("X-RateLimit-Limit", itoa(limit))
|
||||
w.Header().Set("X-RateLimit-Reset", "3600")
|
||||
w.Header().Set("Retry-After", "60")
|
||||
jsonError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("X-RateLimit-Limit", itoa(limit))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
s := ""
|
||||
for n > 0 {
|
||||
s = string(rune('0'+n%10)) + s
|
||||
n /= 10
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func min(a, b float64) float64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue