writekit/internal/db/db.go
2026-01-09 00:16:46 +02:00

81 lines
1.5 KiB
Go

package db
import (
"context"
"fmt"
"os"
"path/filepath"
"sort"
"github.com/jackc/pgx/v5/pgxpool"
)
// defaultDSN is the PostgreSQL connection string used by Connect when the
// DATABASE_URL environment variable is unset or empty. It points at a local
// server on port 5432 with TLS disabled (sslmode=disable).
//
// NOTE(review): this embeds development credentials; confirm that every
// non-local deployment sets DATABASE_URL.
const (
	defaultDSN = "postgres://writekit:writekit@localhost:5432/writekit?sslmode=disable"
)
// DB wraps a pgx connection pool. Instances are built by Connect; the pool is
// exposed through Pool and released by Close.
type DB struct{ pool *pgxpool.Pool }
// Connect opens a pgx connection pool, pings it, and applies every .sql
// migration in migrationsDir before returning.
//
// The DSN comes from the DATABASE_URL environment variable, falling back to
// defaultDSN when that is unset or empty. Connect runs with
// context.Background(), so it cannot be cancelled or time-bounded; use
// ConnectContext when start-up must respect a deadline.
//
// On failure the pool (if created) is closed and a nil *DB is returned with an
// error prefixed by the failing step: "connect", "ping" or "migrations".
func Connect(migrationsDir string) (*DB, error) {
	return ConnectContext(context.Background(), migrationsDir)
}

// ConnectContext behaves like Connect but uses ctx for pool creation, the
// initial ping, and every migration, so callers can bound start-up with
// context.WithTimeout or cancel it.
func ConnectContext(ctx context.Context, migrationsDir string) (*DB, error) {
	dsn := os.Getenv("DATABASE_URL")
	if dsn == "" {
		dsn = defaultDSN
	}

	pool, err := pgxpool.New(ctx, dsn)
	if err != nil {
		return nil, fmt.Errorf("connect: %w", err)
	}

	// Verify the server is actually reachable before running migrations, so a
	// bad DSN or unreachable host is reported as a ping failure.
	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return nil, fmt.Errorf("ping: %w", err)
	}

	db := &DB{pool: pool}
	if err := db.RunMigrationsContext(ctx, migrationsDir); err != nil {
		pool.Close()
		return nil, fmt.Errorf("migrations: %w", err)
	}
	return db, nil
}

// RunMigrations executes the .sql files in dir using context.Background().
// It is equivalent to RunMigrationsContext(context.Background(), dir).
func (db *DB) RunMigrations(dir string) error {
	return db.RunMigrationsContext(context.Background(), dir)
}

// RunMigrationsContext executes each non-directory entry of dir whose name ends
// in ".sql" (subdirectories are not descended into), in lexical filename
// order. Each file's full contents are passed to a single Exec call using ctx.
//
// Nothing records which files have already been applied: every file runs on
// every call, so migration files must be safe to re-run (for example
// CREATE TABLE IF NOT EXISTS). Because ordering is lexical, numeric prefixes
// should be zero-padded ("002_" sorts before "010_", but "2_" sorts after
// "10_").
//
// Execution stops at the first file that cannot be read or fails to run, and
// the returned error names that file. Files that already ran are not rolled
// back by this function.
//
// NOTE(review): multi-statement files presumably rely on pgx using the simple
// query protocol for argument-less Exec — confirm against the pgx version in
// go.mod.
func (db *DB) RunMigrationsContext(ctx context.Context, dir string) error {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return fmt.Errorf("read migrations dir: %w", err)
	}

	files := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" {
			continue
		}
		files = append(files, entry.Name())
	}
	// os.ReadDir already returns entries sorted by name; sorting again keeps
	// the ordering guarantee explicit and local to this function.
	sort.Strings(files)

	for _, name := range files {
		content, err := os.ReadFile(filepath.Join(dir, name))
		if err != nil {
			return fmt.Errorf("read %s: %w", name, err)
		}
		if _, err := db.pool.Exec(ctx, string(content)); err != nil {
			return fmt.Errorf("run %s: %w", name, err)
		}
	}
	return nil
}
// Close releases the underlying connection pool. It is a no-op on a nil *DB
// (what Connect returns on failure) or on a DB with no pool, so it is safe to
// defer before the construction error has been checked.
func (db *DB) Close() {
	if db == nil || db.pool == nil {
		return
	}
	db.pool.Close()
}

// Pool returns the underlying pgx connection pool for issuing queries
// directly. Close on db closes this same pool, so callers should normally let
// db.Close release it rather than closing it themselves.
func (db *DB) Pool() *pgxpool.Pool {
	return db.pool
}