// Package db opens a PostgreSQL connection pool via pgxpool and applies the
// SQL migration files found in a directory.
package db

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sort"

	"github.com/jackc/pgx/v5/pgxpool"
)

// defaultDSN is used when the DATABASE_URL environment variable is unset or
// empty. It targets a PostgreSQL server on localhost with TLS disabled
// (sslmode=disable) and embedded credentials — presumably meant for local
// development only; confirm production always sets DATABASE_URL.
const defaultDSN = "postgres://writekit:writekit@localhost:5432/writekit?sslmode=disable"

// DB wraps a pgx connection pool. Create one with Connect or ConnectContext
// and release it with Close.
type DB struct {
	pool *pgxpool.Pool
}

// Connect is ConnectContext using context.Background(), so the ping and the
// migrations cannot be cancelled or timed out by the caller.
func Connect(migrationsDir string) (*DB, error) {
	return ConnectContext(context.Background(), migrationsDir)
}

// ConnectContext opens a connection pool using the DSN from DATABASE_URL
// (falling back to defaultDSN), verifies connectivity with a ping, and then
// applies the migrations in migrationsDir via RunMigrationsContext.
//
// ctx bounds the ping and every migration statement, which lets callers put
// a deadline on startup. On any error the pool is closed and a nil *DB is
// returned; errors are wrapped with "connect:", "ping:" or "migrations:".
func ConnectContext(ctx context.Context, migrationsDir string) (*DB, error) {
	dsn := os.Getenv("DATABASE_URL")
	if dsn == "" {
		dsn = defaultDSN
	}

	// NOTE(review): pgxpool.New may hand its context to background work that
	// outlives this call (e.g. creating idle connections), so it keeps
	// context.Background() rather than a possibly short-lived caller ctx.
	pool, err := pgxpool.New(context.Background(), dsn)
	if err != nil {
		return nil, fmt.Errorf("connect: %w", err)
	}

	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return nil, fmt.Errorf("ping: %w", err)
	}

	db := &DB{pool: pool}
	if err := db.RunMigrationsContext(ctx, migrationsDir); err != nil {
		pool.Close()
		return nil, fmt.Errorf("migrations: %w", err)
	}
	return db, nil
}

// RunMigrations is RunMigrationsContext using context.Background().
func (db *DB) RunMigrations(dir string) error {
	return db.RunMigrationsContext(context.Background(), dir)
}

// RunMigrationsContext executes every non-directory entry of dir whose name
// ends in ".sql" (case-sensitive), in lexical filename order, stopping at
// the first file that fails to read or execute. An empty directory is not an
// error.
//
// No record of applied migrations is kept: every file is executed on every
// call, so each migration must be safe to re-run (for example
// CREATE TABLE IF NOT EXISTS). Because ordering is a plain string sort,
// numeric prefixes should be zero-padded ("002_x.sql", not "2_x.sql"),
// otherwise "10_..." runs before "2_...".
//
// Each file's full contents are sent in a single Exec call with no
// arguments. NOTE(review): presumably this relies on pgx using the simple
// query protocol for argument-less Exec so a file may hold several
// statements — confirm against the pgx version in use.
func (db *DB) RunMigrationsContext(ctx context.Context, dir string) error {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return fmt.Errorf("read migrations dir: %w", err)
	}

	files := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" {
			continue
		}
		files = append(files, entry.Name())
	}
	sort.Strings(files)

	for _, name := range files {
		content, err := os.ReadFile(filepath.Join(dir, name))
		if err != nil {
			return fmt.Errorf("read %s: %w", name, err)
		}
		if _, err := db.pool.Exec(ctx, string(content)); err != nil {
			return fmt.Errorf("run %s: %w", name, err)
		}
	}
	return nil
}

// Close closes the underlying connection pool.
func (db *DB) Close() {
	db.pool.Close()
}

// Pool returns the underlying pgx connection pool for issuing queries.
func (db *DB) Pool() *pgxpool.Pool {
	return db.pool
}