writekit/internal/db/demos.go
2026-01-09 00:16:46 +02:00

109 lines
2.5 KiB
Go

package db
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"os"
	"strconv"
	"time"

	"github.com/jackc/pgx/v5"
)
// getDemoDuration reports how long a newly created demo should live.
//
// A positive integer in DEMO_DURATION_MINUTES always takes precedence.
// Otherwise the lifetime is 15 minutes when ENV is exactly "prod", and
// 100 years (effectively unlimited) in every other environment.
func getDemoDuration() time.Duration {
	const (
		prodDefault = 15 * time.Minute
		devDefault  = 100 * 365 * 24 * time.Hour // effectively never expires outside prod
	)

	// Non-numeric, zero, or negative overrides are ignored in favour of the defaults.
	if raw := os.Getenv("DEMO_DURATION_MINUTES"); raw != "" {
		minutes, err := strconv.Atoi(raw)
		if err == nil && minutes > 0 {
			return time.Duration(minutes) * time.Minute
		}
	}

	switch os.Getenv("ENV") {
	case "prod":
		return prodDefault
	default:
		return devDefault
	}
}
// GetDemoBySubdomain looks up the demo registered under subdomain whose
// expires_at is still after the database's NOW().
//
// It returns (nil, nil) when no such row exists, which covers both unknown
// and already-expired subdomains. Any other query or scan error is returned
// with a nil *Demo, so callers never see a partially populated value.
func (db *DB) GetDemoBySubdomain(ctx context.Context, subdomain string) (*Demo, error) {
	var d Demo
	err := db.pool.QueryRow(ctx,
		`SELECT id, subdomain, expires_at FROM demos WHERE subdomain = $1 AND expires_at > NOW()`,
		subdomain).Scan(&d.ID, &d.Subdomain, &d.ExpiresAt)
	// errors.Is rather than == so a wrapped no-rows error is still treated
	// as "not found" instead of a failure.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &d, nil
}
// GetDemoByID looks up the demo with the given id whose expires_at is still
// after the database's NOW().
//
// It returns (nil, nil) when no such row exists, which covers both unknown
// and already-expired ids. Any other query or scan error is returned with a
// nil *Demo, so callers never see a partially populated value.
func (db *DB) GetDemoByID(ctx context.Context, id string) (*Demo, error) {
	var d Demo
	err := db.pool.QueryRow(ctx,
		`SELECT id, subdomain, expires_at FROM demos WHERE id = $1 AND expires_at > NOW()`,
		id).Scan(&d.ID, &d.Subdomain, &d.ExpiresAt)
	// errors.Is rather than == so a wrapped no-rows error is still treated
	// as "not found" instead of a failure.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &d, nil
}
// CreateDemo inserts a new demo row and returns it as read back via RETURNING.
//
// The subdomain is "demo-" followed by 8 random hex characters, and the expiry
// is the current time plus getDemoDuration(). On any insert or scan error it
// returns (nil, err).
//
// NOTE(review): expires_at is computed from the application clock, while the
// read paths filter against the database's NOW(); clock skew between the two
// shifts the effective lifetime.
func (db *DB) CreateDemo(ctx context.Context) (*Demo, error) {
	const insertDemo = `INSERT INTO demos (subdomain, expires_at) VALUES ($1, $2) RETURNING id, subdomain, expires_at`

	name := "demo-" + randomHex(4)
	expiry := time.Now().Add(getDemoDuration())

	created := new(Demo)
	row := db.pool.QueryRow(ctx, insertDemo, name, expiry)
	if err := row.Scan(&created.ID, &created.Subdomain, &created.ExpiresAt); err != nil {
		return nil, err
	}
	return created, nil
}
// ExpiredDemo holds the ID and subdomain of a demo row deleted by
// CleanupExpiredDemos.
type ExpiredDemo struct {
	ID, Subdomain string
}
// CleanupExpiredDemos deletes every demo whose expires_at is strictly before
// the database's NOW(), and returns the ID and subdomain of each deleted row.
//
// The result is nil when nothing was deleted. A scan error aborts with
// (nil, err); otherwise any error reported by the row iterator is returned
// together with the rows collected so far. A row whose expires_at equals NOW()
// is left in place by this call, because the comparison is strict.
func (db *DB) CleanupExpiredDemos(ctx context.Context) ([]ExpiredDemo, error) {
	const deleteExpired = `DELETE FROM demos WHERE expires_at < NOW() RETURNING id, subdomain`

	rows, err := db.pool.Query(ctx, deleteExpired)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var removed []ExpiredDemo
	for rows.Next() {
		var e ExpiredDemo
		if scanErr := rows.Scan(&e.ID, &e.Subdomain); scanErr != nil {
			return nil, scanErr
		}
		removed = append(removed, e)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return removed, iterErr
	}
	return removed, nil
}
// ListActiveDemos returns every demo whose expires_at is strictly after the
// database's NOW(). The query has no ORDER BY, so rows come back in whatever
// order the database yields them.
//
// The result is nil when no demo is active. A scan error aborts with
// (nil, err); otherwise any error reported by the row iterator is returned
// together with the rows collected so far.
func (db *DB) ListActiveDemos(ctx context.Context) ([]Demo, error) {
	const selectActive = `SELECT id, subdomain, expires_at FROM demos WHERE expires_at > NOW()`

	rows, err := db.pool.Query(ctx, selectActive)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var active []Demo
	for rows.Next() {
		var demo Demo
		if scanErr := rows.Scan(&demo.ID, &demo.Subdomain, &demo.ExpiresAt); scanErr != nil {
			return nil, scanErr
		}
		active = append(active, demo)
	}
	return active, rows.Err()
}
// randomHex returns n bytes read from crypto/rand, encoded as 2n lowercase
// hexadecimal characters. n must be non-negative.
//
// It panics if the secure random source fails. The result is used to build
// identifiers such as demo subdomains. Ignoring the error would let a partly
// or fully zeroed buffer through as a predictable, colliding value. Newer Go
// releases (1.24+) already crash inside rand.Read on failure, so this only
// matters on older toolchains.
func randomHex(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		panic("db: crypto/rand.Read failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}