81 lines
1.5 KiB
Go
81 lines
1.5 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// defaultDSN is the PostgreSQL connection string Connect falls back to when
// the DATABASE_URL environment variable is unset or empty.
// NOTE(review): it embeds local-development credentials and disables TLS;
// presumably it is meant only for local use — confirm that deployed
// environments always set DATABASE_URL.
const defaultDSN = "postgres://writekit:writekit@localhost:5432/writekit?sslmode=disable"
|
|
|
|
type DB struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
func Connect(migrationsDir string) (*DB, error) {
|
|
dsn := os.Getenv("DATABASE_URL")
|
|
if dsn == "" {
|
|
dsn = defaultDSN
|
|
}
|
|
|
|
pool, err := pgxpool.New(context.Background(), dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("connect: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(context.Background()); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("ping: %w", err)
|
|
}
|
|
|
|
db := &DB{pool: pool}
|
|
|
|
if err := db.RunMigrations(migrationsDir); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("migrations: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func (db *DB) RunMigrations(dir string) error {
|
|
entries, err := os.ReadDir(dir)
|
|
if err != nil {
|
|
return fmt.Errorf("read migrations dir: %w", err)
|
|
}
|
|
|
|
var files []string
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" {
|
|
continue
|
|
}
|
|
files = append(files, entry.Name())
|
|
}
|
|
sort.Strings(files)
|
|
|
|
for _, name := range files {
|
|
content, err := os.ReadFile(filepath.Join(dir, name))
|
|
if err != nil {
|
|
return fmt.Errorf("read %s: %w", name, err)
|
|
}
|
|
|
|
if _, err := db.pool.Exec(context.Background(), string(content)); err != nil {
|
|
return fmt.Errorf("run %s: %w", name, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) Close() {
|
|
db.pool.Close()
|
|
}
|
|
|
|
func (db *DB) Pool() *pgxpool.Pool {
|
|
return db.pool
|
|
}
|
|
|