writekit/internal/tenant/runner.go
2026-01-09 00:16:46 +02:00

621 lines
16 KiB
Go

package tenant
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"time"
extism "github.com/extism/go-sdk"
)
// PluginRunner runs a single tenant's WASM plugins (through Extism) for
// lifecycle hooks. Compiled plugin instances are cached per plugin ID and
// reused across calls until InvalidateCache or Close drops them.
type PluginRunner struct {
	db       *sql.DB
	q        *Queries
	tenantID string

	// cache maps plugin ID to its compiled Extism instance; mu guards it.
	cache map[string]*extism.Plugin
	mu    sync.RWMutex
}
// NewPluginRunner builds a runner bound to one tenant's database. The plugin
// cache starts empty and is filled lazily the first time each plugin runs.
func NewPluginRunner(db *sql.DB, tenantID string) *PluginRunner {
	runner := &PluginRunner{tenantID: tenantID, db: db}
	runner.q = NewQueries(db)
	runner.cache = make(map[string]*extism.Plugin)
	return runner
}
// HookEvent pairs a hook name with its payload, using the JSON keys "hook"
// and "data". It is not referenced elsewhere in this file; presumably callers
// use it as a wire shape for events — confirm against callers.
type HookEvent struct {
	Hook string         `json:"hook"`
	Data map[string]any `json:"data"`
}
// PluginResult records the outcome of one plugin invocation.
type PluginResult struct {
	// PluginID is the ID of the plugin that was run.
	PluginID string `json:"plugin_id"`
	// Success is true when the plugin call returned without error.
	Success bool `json:"success"`
	// Output is the raw bytes returned by the plugin, as a string.
	Output string `json:"output,omitempty"`
	// Error holds the creation or call error message when Success is false.
	Error string `json:"error,omitempty"`
	// Duration is the wall-clock time of the whole attempt, in milliseconds.
	Duration int64 `json:"duration_ms"`
}
// TriggerHook runs every plugin registered for an event hook and returns one
// PluginResult per plugin, in the order GetPluginsByHook returned them.
//
// Although event hooks are conceptually fire-and-forget, this call is
// synchronous. Each plugin runs to completion before the next one starts, and
// TriggerHook blocks until all have finished. Callers that must not wait
// should invoke it from their own goroutine.
//
// Plugin failures are reported in the returned results and do not stop the
// loop. A lookup error or an empty plugin list yields nil. Secrets are loaded
// best-effort; a load error is ignored.
func (r *PluginRunner) TriggerHook(ctx context.Context, hook string, data map[string]any) []PluginResult {
	plugins, err := r.q.GetPluginsByHook(ctx, hook)
	if err != nil || len(plugins) == 0 {
		return nil
	}
	// Best-effort: missing secrets must not block event delivery.
	secrets, _ := GetSecretsMap(r.db, r.tenantID)

	results := make([]PluginResult, 0, len(plugins))
	for i := range plugins {
		// Index instead of &loopVar, so the pointer refers to the slice element.
		results = append(results, r.runPlugin(ctx, &plugins[i], hook, data, secrets))
	}
	return results
}
// ValidationResult is the JSON object a validation-hook plugin is expected to
// return. TriggerValidation decodes plugin output into this type. A decoded
// Allowed of false, including when the "allowed" key is absent, rejects the
// action, and Reason is passed back to the caller.
type ValidationResult struct {
	Allowed bool   `json:"allowed"`
	Reason  string `json:"reason,omitempty"`
}
// TriggerValidation asks the plugins registered for a validation hook whether
// an action may proceed. It returns (allowed, reason, err).
//
// The policy is fail-open:
//   - a plugin lookup error yields allowed=true together with that error;
//   - no registered plugins means the action is allowed;
//   - a plugin that fails to run, or whose output is not valid JSON for a
//     ValidationResult, is skipped.
//
// Every plugin is consulted in order, not only the first. The first one whose
// decoded output has Allowed == false vetoes the action: its Reason is
// returned and the remaining plugins are not run. If nothing vetoes, the
// action is allowed.
//
// Output is decoded into a zero-valued ValidationResult. As a result, a JSON
// object without "allowed", or the literal null, counts as a rejection.
func (r *PluginRunner) TriggerValidation(ctx context.Context, hook string, data map[string]any) (bool, string, error) {
	validators, lookupErr := r.q.GetPluginsByHook(ctx, hook)
	switch {
	case lookupErr != nil:
		return true, "", lookupErr
	case len(validators) == 0:
		return true, "", nil
	}

	secrets, _ := GetSecretsMap(r.db, r.tenantID)
	for _, v := range validators {
		run := r.runPlugin(ctx, &v, hook, data, secrets)
		if !run.Success {
			continue // could not run: fail open
		}
		var verdict ValidationResult
		if json.Unmarshal([]byte(run.Output), &verdict) != nil {
			continue // unparseable output: ignore this plugin
		}
		if !verdict.Allowed {
			return false, verdict.Reason, nil
		}
	}
	return true, "", nil
}
// TriggerTransform threads data through every plugin registered for a
// transform hook, in order. Each plugin receives the previous plugin's output
// and must return a replacement JSON object.
//
// A plugin is skipped, and the chain continues with the last good value, when:
//   - it fails to run;
//   - its output is not a JSON object;
//   - its output is the literal null.
//
// If the plugin lookup fails or no plugins are registered, data is returned
// unchanged. The error result is currently always nil.
func (r *PluginRunner) TriggerTransform(ctx context.Context, hook string, data map[string]any) (map[string]any, error) {
	plugins, err := r.q.GetPluginsByHook(ctx, hook)
	if err != nil || len(plugins) == 0 {
		return data, nil
	}
	secrets, _ := GetSecretsMap(r.db, r.tenantID)

	current := data
	for _, p := range plugins {
		result := r.runPlugin(ctx, &p, hook, current, secrets)
		if !result.Success {
			continue // skip failed plugins
		}
		var transformed map[string]any
		if err := json.Unmarshal([]byte(result.Output), &transformed); err != nil {
			continue // not a JSON object: skip
		}
		// `null` decodes without error into a nil map. Treat it as invalid
		// output rather than letting one plugin erase the payload for the rest
		// of the chain.
		if transformed == nil {
			continue
		}
		current = transformed
	}
	return current, nil
}
// runPlugin executes one plugin's entry point for hook and times the attempt.
//
// The instance comes from the runner's cache. On first use it is compiled with
// secrets as its Extism config. data is JSON-encoded and passed as the call
// input, and the raw output is returned as a string.
//
// Failures at any stage (instance creation, input encoding, the call itself)
// are recorded in PluginResult.Error rather than returned. Duration always
// covers the whole attempt.
//
// ctx is handed to the plugin call, so host functions observe the caller's
// cancellation or deadline. For example, it aborts in-flight http_request
// calls.
func (r *PluginRunner) runPlugin(ctx context.Context, p *Plugin, hook string, data map[string]any, secrets map[string]string) PluginResult {
	start := time.Now()
	result := PluginResult{PluginID: p.ID}
	// finish stamps the elapsed time so every return path reports a Duration.
	finish := func() PluginResult {
		result.Duration = time.Since(start).Milliseconds()
		return result
	}

	plugin, err := r.getOrCreatePlugin(p, secrets)
	if err != nil {
		result.Error = err.Error()
		return finish()
	}

	input, err := json.Marshal(data)
	if err != nil {
		// An unencodable payload (e.g. a func or channel value) must not reach
		// the plugin as a silent nil input.
		result.Error = fmt.Sprintf("encode input: %v", err)
		return finish()
	}

	// Plain Call would run under a background context. CallWithContext lets
	// host functions see the caller's ctx.
	_, output, err := plugin.CallWithContext(ctx, hookToFunction(hook), input)
	if err != nil {
		result.Error = err.Error()
		return finish()
	}
	result.Success = true
	result.Output = string(output)
	return finish()
}
// getOrCreatePlugin returns the cached Extism instance for p, compiling and
// caching one on a miss.
//
// Locking works in two stages:
//   - the hit path takes only the read lock;
//   - a miss takes the write lock and re-checks the cache, so concurrent
//     misses for the same ID compile at most once.
//
// The instance is created with context.Background() because it outlives any
// single request.
//
// NOTE(review): secrets are baked into the manifest Config only at creation
// time. Later secret changes are not seen until InvalidateCache(p.ID) or Close
// drops the instance.
//
// NOTE(review): AllowedHosts is "*". The custom "http_request" host function
// does its own networking and, as far as this file shows, is not restricted by
// that list. Confirm against the Extism docs.
func (r *PluginRunner) getOrCreatePlugin(p *Plugin, secrets map[string]string) (*extism.Plugin, error) {
	r.mu.RLock()
	hit, found := r.cache[p.ID]
	r.mu.RUnlock()
	if found {
		return hit, nil
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	// Another goroutine may have filled the entry between the two locks.
	if hit, found = r.cache[p.ID]; found {
		return hit, nil
	}

	instance, err := extism.NewPlugin(
		context.Background(),
		extism.Manifest{
			Wasm:         []extism.Wasm{extism.WasmData{Data: p.Wasm}},
			AllowedHosts: []string{"*"},
			Config:       secrets,
		},
		extism.PluginConfig{EnableWasi: true},
		r.hostFunctions(),
	)
	if err != nil {
		return nil, fmt.Errorf("create plugin: %w", err)
	}
	r.cache[p.ID] = instance
	return instance, nil
}
// hostFunctions returns the host imports given to production plugins:
// outbound HTTP, the tenant-scoped key/value store, and logging.
func (r *PluginRunner) hostFunctions() []extism.HostFunction {
	fns := make([]extism.HostFunction, 0, 4)
	fns = append(fns, r.httpRequestHost(), r.kvGetHost(), r.kvSetHost(), r.logHost())
	return fns
}
// httpRequestHost exposes "http_request" to plugins.
//
// The single i64 argument is the offset of a JSON request with keys "url",
// "method", "headers" and "body"; method defaults to GET. The single i64
// result is the offset of a JSON reply. On success the reply has "status",
// "headers" and "body". On any failure it is {"error": "..."}.
//
// Every exit path writes a reply. Leaving stack[0] untouched would hand the
// plugin its own request offset back as if it were the response.
//
// Response headers are flattened to their first value. The body is read fully
// into memory.
//
// NOTE(review): the target URL is not restricted and the body size is
// unbounded. Consider an allow-list and an io.LimitReader if plugins are
// untrusted.
func (r *PluginRunner) httpRequestHost() extism.HostFunction {
	// Shared by every call through this host function. http.Client is safe for
	// concurrent use, and reusing it lets keep-alive connections be pooled.
	client := &http.Client{Timeout: 30 * time.Second}
	return extism.NewHostFunctionWithStack(
		"http_request",
		func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
			// reply JSON-encodes v into plugin memory and returns its offset.
			// Encoding cannot fail for the string/int/map values used here. A
			// failed write leaves offset 0.
			reply := func(v map[string]any) {
				encoded, _ := json.Marshal(v)
				offset, _ := p.WriteBytes(encoded)
				stack[0] = offset
			}
			fail := func(err error) {
				reply(map[string]any{"error": err.Error()})
			}

			input, err := p.ReadBytes(stack[0])
			if err != nil {
				fail(fmt.Errorf("read request: %w", err))
				return
			}
			var req struct {
				URL     string            `json:"url"`
				Method  string            `json:"method"`
				Headers map[string]string `json:"headers"`
				Body    string            `json:"body"`
			}
			if err := json.Unmarshal(input, &req); err != nil {
				fail(fmt.Errorf("decode request: %w", err))
				return
			}
			if req.Method == "" {
				req.Method = "GET"
			}

			httpReq, err := http.NewRequestWithContext(ctx, req.Method, req.URL, bytes.NewBufferString(req.Body))
			if err != nil {
				fail(err)
				return
			}
			for k, v := range req.Headers {
				httpReq.Header.Set(k, v)
			}

			resp, err := client.Do(httpReq)
			if err != nil {
				fail(err)
				return
			}
			defer resp.Body.Close()

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				// Don't report a truncated body as a successful response.
				fail(fmt.Errorf("read response body: %w", err))
				return
			}
			headers := make(map[string]string, len(resp.Header))
			for k := range resp.Header {
				headers[k] = resp.Header.Get(k)
			}
			reply(map[string]any{
				"status":  resp.StatusCode,
				"headers": headers,
				"body":    string(body),
			})
		},
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// kvGetHost exposes "kv_get" to plugins.
//
// The argument is the offset of the key string. The result is the offset of
// the stored value string. Keys are prefixed with "kv:" before being looked up
// in the tenant's secret store.
//
// If the key cannot be read, "" is returned. A lookup error yields whatever
// GetSecret returns alongside it (presumably "").
func (r *PluginRunner) kvGetHost() extism.HostFunction {
	return extism.NewHostFunctionWithStack(
		"kv_get",
		func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
			var value string
			if key, err := p.ReadString(stack[0]); err == nil {
				value, _ = GetSecret(r.db, r.tenantID, "kv:"+key)
			}
			// Always write a result. Otherwise the plugin would get its key
			// offset back and read the key as if it were the value.
			offset, _ := p.WriteString(value)
			stack[0] = offset
		},
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// kvSetHost exposes "kv_set" to plugins.
//
// The argument is the offset of a JSON object with keys "key" and "value". The
// value is stored under "kv:"+key in the tenant's secret store.
//
// The result is 0 on success and 1 on any failure: unreadable input, malformed
// JSON, or a failed write. Write failures are also printed to stdout.
func (r *PluginRunner) kvSetHost() extism.HostFunction {
	return extism.NewHostFunctionWithStack(
		"kv_set",
		func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
			input, err := p.ReadBytes(stack[0])
			// Report failure unless the write below succeeds. Leaving stack[0]
			// alone would echo the input offset back as a meaningless status.
			stack[0] = 1
			if err != nil {
				return
			}
			var kv struct {
				Key   string `json:"key"`
				Value string `json:"value"`
			}
			if err := json.Unmarshal(input, &kv); err != nil {
				return
			}
			if err := SetSecret(r.db, r.tenantID, "kv:"+kv.Key, kv.Value); err != nil {
				fmt.Printf("[plugin] kv_set %q failed: %v\n", kv.Key, err)
				return
			}
			stack[0] = 0
		},
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// logHost exposes "log" to plugins.
//
// The argument is the offset of a message string, which is written to stdout
// prefixed with "[plugin] ". The result is always 0. A read failure is not
// reported: the message is printed as whatever ReadString returned
// (presumably "").
func (r *PluginRunner) logHost() extism.HostFunction {
	emit := func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
		text, _ := p.ReadString(stack[0])
		fmt.Printf("[plugin] %s\n", text)
		stack[0] = 0
	}
	return extism.NewHostFunctionWithStack(
		"log",
		emit,
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// InvalidateCache closes and forgets the cached instance for pluginID, if
// there is one. The next run recompiles it from the current Wasm and secrets.
//
// NOTE(review): a goroutine that already fetched this instance from the cache
// may still be mid-call when it is closed here. Confirm that callers
// serialize invalidation with execution.
func (r *PluginRunner) InvalidateCache(pluginID string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	plugin, ok := r.cache[pluginID]
	if !ok {
		return
	}
	plugin.Close(context.Background())
	delete(r.cache, pluginID)
}
// Close releases every cached plugin instance and resets the cache to empty.
// The runner remains usable afterwards; plugins are recompiled on demand.
func (r *PluginRunner) Close() {
	r.mu.Lock()
	defer r.mu.Unlock()
	previous := r.cache
	r.cache = make(map[string]*extism.Plugin)
	for id := range previous {
		previous[id].Close(context.Background())
	}
}
// HookPattern classifies how a hook is executed and how plugin output is
// treated for it.
type HookPattern string

// Hook execution patterns.
const (
	// PatternEvent hooks notify plugins; their outputs are collected but not
	// interpreted.
	PatternEvent HookPattern = "event"
	// PatternValidation hooks return an allowed/rejected decision.
	PatternValidation HookPattern = "validation"
	// PatternTransform hooks return a modified version of the payload.
	PatternTransform HookPattern = "transform"
)

// HookInfo describes one supported hook.
type HookInfo struct {
	Name        string      // dotted hook identifier, e.g. "post.published"
	Pattern     HookPattern // how the hook is executed
	Description string      // human-readable summary
}

// AvailableHooks is the registry of supported hooks, grouped by area.
// GetHookNames preserves this order.
var AvailableHooks = []HookInfo{
	// Content
	{Name: "post.published", Pattern: PatternEvent, Description: "Triggered when a post is published"},
	{Name: "post.updated", Pattern: PatternEvent, Description: "Triggered when a post is updated"},
	{Name: "content.render", Pattern: PatternTransform, Description: "Transform HTML before display"},

	// Engagement
	{Name: "comment.validate", Pattern: PatternValidation, Description: "Validate comment before creation"},
	{Name: "comment.created", Pattern: PatternEvent, Description: "Triggered when a comment is created"},
	{Name: "member.subscribed", Pattern: PatternEvent, Description: "Triggered when a member subscribes"},

	// Utility
	{Name: "asset.uploaded", Pattern: PatternEvent, Description: "Triggered when an asset is uploaded"},
	{Name: "analytics.sync", Pattern: PatternEvent, Description: "Triggered during analytics sync"},
}

// GetHookPattern reports the pattern registered for hook in AvailableHooks.
// It uses the first matching entry, and unknown hooks default to PatternEvent.
func GetHookPattern(hook string) HookPattern {
	pattern := PatternEvent
	for i := 0; i < len(AvailableHooks); i++ {
		if AvailableHooks[i].Name == hook {
			pattern = AvailableHooks[i].Pattern
			break
		}
	}
	return pattern
}

// GetHookNames lists the names from AvailableHooks, in registry order, for use
// in API responses.
func GetHookNames() []string {
	names := make([]string, 0, len(AvailableHooks))
	for _, info := range AvailableHooks {
		names = append(names, info.Name)
	}
	return names
}
// TestPluginRunner runs an uncached plugin build once and captures host
// function activity (HTTP, KV, log calls) as human-readable log lines.
type TestPluginRunner struct {
	db       *sql.DB
	tenantID string
	secrets  map[string]string // passed to the plugin as its Extism config
	logs     []string          // lines captured by the test host functions
}

// TestResult is the JSON-serializable outcome of a plugin test run.
type TestResult struct {
	Success  bool     `json:"success"`
	Output   string   `json:"output,omitempty"`
	Logs     []string `json:"logs"`
	Error    string   `json:"error,omitempty"`
	Duration int64    `json:"duration_ms"`
}

// NewTestPluginRunner prepares a test runner for tenantID. The secrets map is
// stored as-is (not copied), and the log buffer starts empty but non-nil.
func NewTestPluginRunner(db *sql.DB, tenantID string, secrets map[string]string) *TestPluginRunner {
	r := new(TestPluginRunner)
	r.db, r.tenantID, r.secrets = db, tenantID, secrets
	r.logs = make([]string, 0)
	return r
}
// RunTest compiles wasm into a throwaway Extism instance and invokes the entry
// point for hook with data as JSON input. It reports the outcome together with
// the host-function activity captured during the call.
//
// Unlike PluginRunner, nothing is cached: the instance is closed before
// returning. The log buffer is reset at the start of every run, so a reused
// TestPluginRunner does not carry entries from an earlier test into this one.
// ctx is handed to the plugin call, so cancelling it also cancels the
// plugin's outbound HTTP requests.
//
// NOTE(review): in test mode kv_set still writes to the tenant's real KV
// store via SetSecret.
func (r *TestPluginRunner) RunTest(ctx context.Context, wasm []byte, hook string, data map[string]any) TestResult {
	start := time.Now()
	result := TestResult{Logs: []string{}}
	// Use a fresh slice rather than r.logs[:0]. A previously returned
	// TestResult.Logs still shares the old backing array.
	r.logs = []string{}

	manifest := extism.Manifest{
		Wasm: []extism.Wasm{
			extism.WasmData{Data: wasm},
		},
		AllowedHosts: []string{"*"},
		Config:       r.secrets,
	}
	config := extism.PluginConfig{
		EnableWasi: true,
	}
	plugin, err := extism.NewPlugin(ctx, manifest, config, r.testHostFunctions())
	if err != nil {
		result.Error = fmt.Sprintf("Failed to create plugin: %v", err)
		result.Duration = time.Since(start).Milliseconds()
		return result
	}
	defer plugin.Close(ctx)

	input, err := json.Marshal(data)
	if err != nil {
		// Don't invoke the plugin with a nil input it never asked for.
		result.Error = fmt.Sprintf("encode input: %v", err)
	} else {
		_, output, callErr := plugin.CallWithContext(ctx, hookToFunction(hook), input)
		if callErr != nil {
			result.Error = callErr.Error()
		} else {
			result.Success = true
			result.Output = string(output)
		}
	}

	result.Logs = r.logs
	result.Duration = time.Since(start).Milliseconds()
	return result
}
// testHostFunctions returns the logging variants of the host imports: HTTP,
// key/value get and set, and log, in that order.
func (r *TestPluginRunner) testHostFunctions() []extism.HostFunction {
	fns := make([]extism.HostFunction, 4)
	fns[0] = r.testHttpRequestHost()
	fns[1] = r.testKvGetHost()
	fns[2] = r.testKvSetHost()
	fns[3] = r.testLogHost()
	return fns
}
// testHttpRequestHost is the test-mode "http_request" host function. It uses
// the same wire format as the production one and appends "[HTTP] ..." lines
// to the runner's log buffer.
//
// The argument is the offset of a JSON request with keys "url", "method",
// "headers" and "body"; method defaults to GET. The result is the offset of a
// JSON reply. On success the reply has "status", "headers" and "body". On any
// failure it is {"error": "..."}, and the failure is also logged.
//
// Every exit path writes a reply. Leaving stack[0] untouched would hand the
// plugin its own request offset back as if it were the response.
func (r *TestPluginRunner) testHttpRequestHost() extism.HostFunction {
	// One client per test run, reused across the plugin's calls.
	client := &http.Client{Timeout: 30 * time.Second}
	return extism.NewHostFunctionWithStack(
		"http_request",
		func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
			// reply JSON-encodes v into plugin memory and returns its offset.
			reply := func(v map[string]any) {
				encoded, _ := json.Marshal(v) // string/int/map values cannot fail
				offset, _ := p.WriteBytes(encoded)
				stack[0] = offset
			}
			fail := func(err error) {
				r.logs = append(r.logs, fmt.Sprintf("[HTTP] Error: %v", err))
				reply(map[string]any{"error": err.Error()})
			}

			input, err := p.ReadBytes(stack[0])
			if err != nil {
				fail(fmt.Errorf("read request: %w", err))
				return
			}
			var req struct {
				URL     string            `json:"url"`
				Method  string            `json:"method"`
				Headers map[string]string `json:"headers"`
				Body    string            `json:"body"`
			}
			if err := json.Unmarshal(input, &req); err != nil {
				fail(fmt.Errorf("decode request: %w", err))
				return
			}
			if req.Method == "" {
				req.Method = "GET"
			}
			r.logs = append(r.logs, fmt.Sprintf("[HTTP] %s %s", req.Method, req.URL))

			httpReq, err := http.NewRequestWithContext(ctx, req.Method, req.URL, bytes.NewBufferString(req.Body))
			if err != nil {
				fail(err)
				return
			}
			for k, v := range req.Headers {
				httpReq.Header.Set(k, v)
			}

			resp, err := client.Do(httpReq)
			if err != nil {
				fail(err)
				return
			}
			defer resp.Body.Close()

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				fail(fmt.Errorf("read response body: %w", err))
				return
			}
			r.logs = append(r.logs, fmt.Sprintf("[HTTP] Response: %d (%d bytes)", resp.StatusCode, len(body)))

			headers := make(map[string]string, len(resp.Header))
			for k := range resp.Header {
				headers[k] = resp.Header.Get(k)
			}
			reply(map[string]any{
				"status":  resp.StatusCode,
				"headers": headers,
				"body":    string(body),
			})
		},
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// testKvGetHost is the test-mode "kv_get" host function.
//
// It reads from the tenant's real store under "kv:"+key and logs
// "[KV] GET <key>". The result is the offset of the value string. If the key
// cannot be read, the failure is logged and "" is returned rather than leaving
// the key's own offset in stack[0].
func (r *TestPluginRunner) testKvGetHost() extism.HostFunction {
	return extism.NewHostFunctionWithStack(
		"kv_get",
		func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
			var value string
			key, err := p.ReadString(stack[0])
			if err != nil {
				r.logs = append(r.logs, fmt.Sprintf("[KV] GET error: %v", err))
			} else {
				value, _ = GetSecret(r.db, r.tenantID, "kv:"+key)
				r.logs = append(r.logs, fmt.Sprintf("[KV] GET %s", key))
			}
			offset, _ := p.WriteString(value)
			stack[0] = offset
		},
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// testKvSetHost is the test-mode "kv_set" host function.
//
// The argument is the offset of a JSON object with keys "key" and "value". It
// logs "[KV] SET <key> = <value>" and writes to the tenant's real store under
// "kv:"+key. The result is 0 on success and 1 on any failure (unreadable
// input, malformed JSON, or a failed write); each failure is logged.
func (r *TestPluginRunner) testKvSetHost() extism.HostFunction {
	return extism.NewHostFunctionWithStack(
		"kv_set",
		func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
			input, err := p.ReadBytes(stack[0])
			// Report failure unless the write below succeeds. Leaving stack[0]
			// alone would echo the input offset back as a meaningless status.
			stack[0] = 1
			if err != nil {
				r.logs = append(r.logs, fmt.Sprintf("[KV] SET error: %v", err))
				return
			}
			var kv struct {
				Key   string `json:"key"`
				Value string `json:"value"`
			}
			if err := json.Unmarshal(input, &kv); err != nil {
				r.logs = append(r.logs, fmt.Sprintf("[KV] SET error: %v", err))
				return
			}
			r.logs = append(r.logs, fmt.Sprintf("[KV] SET %s = %s", kv.Key, kv.Value))
			if err := SetSecret(r.db, r.tenantID, "kv:"+kv.Key, kv.Value); err != nil {
				r.logs = append(r.logs, fmt.Sprintf("[KV] SET error: %v", err))
				return
			}
			stack[0] = 0
		},
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// testLogHost is the test-mode "log" host function. It captures the message
// as "[LOG] <msg>" in the runner's log buffer instead of printing it. The
// result is always 0.
func (r *TestPluginRunner) testLogHost() extism.HostFunction {
	capture := func(ctx context.Context, p *extism.CurrentPlugin, stack []uint64) {
		line, _ := p.ReadString(stack[0])
		r.logs = append(r.logs, fmt.Sprintf("[LOG] %s", line))
		stack[0] = 0
	}
	return extism.NewHostFunctionWithStack(
		"log",
		capture,
		[]extism.ValueType{extism.ValueTypeI64},
		[]extism.ValueType{extism.ValueTypeI64},
	)
}
// hookEntryPoints maps each supported hook name to the Wasm export a plugin
// must provide to handle it.
var hookEntryPoints = map[string]string{
	"post.published":    "on_post_published",
	"post.updated":      "on_post_updated",
	"content.render":    "render_content",
	"comment.validate":  "validate_comment",
	"comment.created":   "on_comment_created",
	"member.subscribed": "on_member_subscribed",
	"asset.uploaded":    "on_asset_uploaded",
	"analytics.sync":    "on_analytics_sync",
}

// hookToFunction returns the Wasm export invoked for hook. Hooks without a
// dedicated entry point fall back to "run".
func hookToFunction(hook string) string {
	if fn, ok := hookEntryPoints[hook]; ok {
		return fn
	}
	return "run"
}