117 lines
2.7 KiB
Go
117 lines
2.7 KiB
Go
package tenant
|
|
|
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"errors"
	"time"

	"github.com/google/uuid"
)
|
|
|
|
func (q *Queries) GetUserByID(ctx context.Context, id string) (*User, error) {
|
|
row := q.db.QueryRowContext(ctx, `SELECT id, email, name, avatar_url, created_at FROM users WHERE id = ?`, id)
|
|
|
|
u, err := scanUser(row)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
row := q.db.QueryRowContext(ctx, `SELECT id, email, name, avatar_url, created_at FROM users WHERE email = ?`, email)
|
|
|
|
u, err := scanUser(row)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
func (q *Queries) CreateUser(ctx context.Context, u *User) error {
|
|
if u.ID == "" {
|
|
u.ID = uuid.NewString()
|
|
}
|
|
_, err := q.db.ExecContext(ctx, `INSERT INTO users (id, email, name, avatar_url) VALUES (?, ?, ?, ?)`,
|
|
u.ID, u.Email, nullStr(u.Name), nullStr(u.AvatarURL))
|
|
return err
|
|
}
|
|
|
|
func (q *Queries) ValidateSession(ctx context.Context, token string) (*Session, error) {
|
|
var s Session
|
|
var expiresAt string
|
|
err := q.db.QueryRowContext(ctx, `SELECT token, user_id, expires_at FROM sessions WHERE token = ?`, token).
|
|
Scan(&s.Token, &s.UserID, &expiresAt)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.ExpiresAt, _ = time.Parse(time.RFC3339, expiresAt)
|
|
if time.Now().After(s.ExpiresAt) {
|
|
q.db.ExecContext(ctx, `DELETE FROM sessions WHERE token = ?`, token)
|
|
return nil, nil
|
|
}
|
|
|
|
return &s, nil
|
|
}
|
|
|
|
func (q *Queries) CreateSession(ctx context.Context, userID string) (*Session, error) {
|
|
token, err := generateToken()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
expires := time.Now().Add(30 * 24 * time.Hour)
|
|
expiresStr := expires.UTC().Format(time.RFC3339)
|
|
|
|
_, err = q.db.ExecContext(ctx, `INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)`,
|
|
token, userID, expiresStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Session{
|
|
Token: token,
|
|
UserID: userID,
|
|
ExpiresAt: expires,
|
|
}, nil
|
|
}
|
|
|
|
func (q *Queries) DeleteSession(ctx context.Context, token string) error {
|
|
_, err := q.db.ExecContext(ctx, `DELETE FROM sessions WHERE token = ?`, token)
|
|
return err
|
|
}
|
|
|
|
func scanUser(s scanner) (User, error) {
|
|
var u User
|
|
var name, avatarURL, createdAt sql.NullString
|
|
|
|
err := s.Scan(&u.ID, &u.Email, &name, &avatarURL, &createdAt)
|
|
if err != nil {
|
|
return u, err
|
|
}
|
|
|
|
u.Name = name.String
|
|
u.AvatarURL = avatarURL.String
|
|
u.CreatedAt = parseTime(createdAt.String)
|
|
return u, nil
|
|
}
|
|
|
|
// generateToken returns a new session token: 32 bytes read from crypto/rand,
// encoded as a 64-character lowercase hex string. If the random source fails,
// the error is returned with an empty token.
func generateToken() (string, error) {
	var raw [32]byte
	if _, readErr := rand.Read(raw[:]); readErr != nil {
		return "", readErr
	}
	return hex.EncodeToString(raw[:]), nil
}
|