203 lines
4.3 KiB
Go
203 lines
4.3 KiB
Go
|
|
package tenant
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/aes"
|
||
|
|
"crypto/cipher"
|
||
|
|
"crypto/rand"
|
||
|
|
"crypto/sha256"
|
||
|
|
"database/sql"
|
||
|
|
"encoding/hex"
|
||
|
|
"fmt"
|
||
|
|
"os"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Secret is the metadata record for one stored secret. It carries the
// secret's name and timestamps only; the secret value itself is not part of
// this type.
type Secret struct {
	// Key is the name the secret is stored under.
	Key string

	// CreatedAt and UpdatedAt are filled from the created_at and updated_at
	// columns of the secrets table.
	CreatedAt, UpdatedAt time.Time
}
|
||
|
|
|
||
|
|
// masterKey is the process-wide root key that per-tenant keys are derived
// from. It is assigned once, in init.
var masterKey []byte

// init sets masterKey to the SHA-256 digest of the SECRETS_MASTER_KEY
// environment variable. When the variable is unset or empty, a fixed
// development passphrase is hashed instead.
//
// NOTE(review): the fallback means a deployment that forgets to set
// SECRETS_MASTER_KEY runs with a key that is visible in source — confirm
// that this is acceptable outside development.
func init() {
	const devPassphrase = "writekit-dev-key-change-in-prod"

	passphrase := devPassphrase
	if fromEnv := os.Getenv("SECRETS_MASTER_KEY"); fromEnv != "" {
		passphrase = fromEnv
	}

	// Slice the digest array directly so masterKey has length == capacity.
	digest := sha256.Sum256([]byte(passphrase))
	masterKey = digest[:]
}
|
||
|
|
|
||
|
|
func deriveKey(tenantID string) []byte {
|
||
|
|
combined := append(masterKey, []byte(tenantID)...)
|
||
|
|
hash := sha256.Sum256(combined)
|
||
|
|
return hash[:]
|
||
|
|
}
|
||
|
|
|
||
|
|
func encrypt(plaintext []byte, tenantID string) (ciphertext, nonce []byte, err error) {
|
||
|
|
key := deriveKey(tenantID)
|
||
|
|
block, err := aes.NewCipher(key)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
gcm, err := cipher.NewGCM(block)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
nonce = make([]byte, gcm.NonceSize())
|
||
|
|
if _, err := rand.Read(nonce); err != nil {
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
ciphertext = gcm.Seal(nil, nonce, plaintext, nil)
|
||
|
|
return ciphertext, nonce, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func decrypt(ciphertext, nonce []byte, tenantID string) ([]byte, error) {
|
||
|
|
key := deriveKey(tenantID)
|
||
|
|
block, err := aes.NewCipher(key)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
gcm, err := cipher.NewGCM(block)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return gcm.Open(nil, nonce, ciphertext, nil)
|
||
|
|
}
|
||
|
|
|
||
|
|
// ensureSecretsTable creates the secrets table if it does not already exist.
// The table is keyed by secret name and stores the encrypted value, its
// nonce, and created/updated timestamps that default to CURRENT_TIMESTAMP.
// It returns the error from executing the statement, if any.
func ensureSecretsTable(db *sql.DB) error {
	const createStmt = `
		CREATE TABLE IF NOT EXISTS secrets (
			key TEXT PRIMARY KEY,
			value BLOB NOT NULL,
			nonce BLOB NOT NULL,
			created_at TEXT DEFAULT CURRENT_TIMESTAMP,
			updated_at TEXT DEFAULT CURRENT_TIMESTAMP
		)
	`

	if _, err := db.Exec(createStmt); err != nil {
		return err
	}
	return nil
}
|
||
|
|
|
||
|
|
func SetSecret(db *sql.DB, tenantID, key, value string) error {
|
||
|
|
if err := ensureSecretsTable(db); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
ciphertext, nonce, err := encrypt([]byte(value), tenantID)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("encrypt: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err = db.Exec(`
|
||
|
|
INSERT INTO secrets (key, value, nonce, updated_at)
|
||
|
|
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
|
||
|
|
ON CONFLICT(key) DO UPDATE SET
|
||
|
|
value = excluded.value,
|
||
|
|
nonce = excluded.nonce,
|
||
|
|
updated_at = CURRENT_TIMESTAMP
|
||
|
|
`, key, ciphertext, nonce)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetSecret(db *sql.DB, tenantID, key string) (string, error) {
|
||
|
|
if err := ensureSecretsTable(db); err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
var ciphertext, nonce []byte
|
||
|
|
err := db.QueryRow(`SELECT value, nonce FROM secrets WHERE key = ?`, key).Scan(&ciphertext, &nonce)
|
||
|
|
if err == sql.ErrNoRows {
|
||
|
|
return "", nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
plaintext, err := decrypt(ciphertext, nonce, tenantID)
|
||
|
|
if err != nil {
|
||
|
|
return "", fmt.Errorf("decrypt: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return string(plaintext), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func DeleteSecret(db *sql.DB, key string) error {
|
||
|
|
if err := ensureSecretsTable(db); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err := db.Exec(`DELETE FROM secrets WHERE key = ?`, key)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func ListSecrets(db *sql.DB) ([]Secret, error) {
|
||
|
|
if err := ensureSecretsTable(db); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
rows, err := db.Query(`SELECT key, created_at, updated_at FROM secrets ORDER BY key`)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var secrets []Secret
|
||
|
|
for rows.Next() {
|
||
|
|
var s Secret
|
||
|
|
var createdAt, updatedAt string
|
||
|
|
if err := rows.Scan(&s.Key, &createdAt, &updatedAt); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
s.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
|
|
s.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||
|
|
secrets = append(secrets, s)
|
||
|
|
}
|
||
|
|
return secrets, rows.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
func GetSecretsMap(db *sql.DB, tenantID string) (map[string]string, error) {
|
||
|
|
if err := ensureSecretsTable(db); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
rows, err := db.Query(`SELECT key, value, nonce FROM secrets`)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
secrets := make(map[string]string)
|
||
|
|
for rows.Next() {
|
||
|
|
var key string
|
||
|
|
var ciphertext, nonce []byte
|
||
|
|
if err := rows.Scan(&key, &ciphertext, &nonce); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
plaintext, err := decrypt(ciphertext, nonce, tenantID)
|
||
|
|
if err != nil {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
secrets[key] = string(plaintext)
|
||
|
|
}
|
||
|
|
return secrets, rows.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// MaskSecret returns a redacted form of value that is safe to display.
//
// Values of eight characters or fewer are replaced entirely by eight
// bullets, so nothing about them (including their length) is revealed.
// Longer values keep their first four and last four characters with four
// bullets in between.
//
// Lengths and cut points are measured in Unicode code points rather than
// bytes, so a multi-byte character is never split into invalid UTF-8 and a
// short non-ASCII value is not partially revealed just because it occupies
// more than eight bytes.
func MaskSecret(value string) string {
	const visible = 4 // characters kept at each end

	chars := []rune(value)
	if len(chars) <= 2*visible {
		return "••••••••"
	}
	return string(chars[:visible]) + "••••" + string(chars[len(chars)-visible:])
}
|
||
|
|
|
||
|
|
// GenerateSecretID returns a new random identifier: 16 bytes from the
// cryptographic random source, encoded as 32 lowercase hexadecimal
// characters.
//
// It panics if the random source fails. The signature has no error result,
// and silently continuing would hand back the all-zero, fully predictable ID.
func GenerateSecretID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic("tenant: generating secret ID: " + err.Error())
	}
	return hex.EncodeToString(raw[:])
}
|