package billing

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"testing"
)

// TestVerifyWebhook checks LemonClient.VerifyWebhook against hex-encoded
// HMAC-SHA256 signatures computed over the raw payload: a matching signature
// must pass, while a malformed signature, a signature made with a different
// secret, and any signature presented to a client with an empty secret must
// all be rejected.
func TestVerifyWebhook(t *testing.T) {
	secret := "test-webhook-secret"
	client := &LemonClient{webhookSecret: secret}

	// sign returns the hex-encoded HMAC-SHA256 of body keyed with key. It is
	// a local closure so the test file adds no package-level helper name.
	sign := func(key string, body []byte) string {
		mac := hmac.New(sha256.New, []byte(key))
		mac.Write(body) // hash.Hash.Write is documented to never return an error.
		return hex.EncodeToString(mac.Sum(nil))
	}

	t.Run("valid signature", func(t *testing.T) {
		payload := []byte(`{"meta":{"event_name":"subscription_created"}}`)

		if !client.VerifyWebhook(payload, sign(secret, payload)) {
			t.Error("expected valid signature to pass verification")
		}
	})

	t.Run("invalid signature", func(t *testing.T) {
		payload := []byte(`{"meta":{"event_name":"subscription_created"}}`)

		if client.VerifyWebhook(payload, "invalid-signature") {
			t.Error("expected invalid signature to fail verification")
		}
	})

	t.Run("wrong secret", func(t *testing.T) {
		payload := []byte(`{"meta":{"event_name":"subscription_created"}}`)

		if client.VerifyWebhook(payload, sign("different-secret", payload)) {
			t.Error("expected signature with wrong secret to fail")
		}
	})

	t.Run("empty secret", func(t *testing.T) {
		clientNoSecret := &LemonClient{webhookSecret: ""}
		payload := []byte(`{"test": true}`)

		if clientNoSecret.VerifyWebhook(payload, "any-signature") {
			t.Error("expected empty secret to fail verification")
		}

		// An arbitrary string is rejected whatever the secret is, so the
		// check above cannot tell an empty-secret client from a normal one.
		// Presenting the signature that the empty key really produces is
		// what exercises the empty-secret case itself.
		if clientNoSecret.VerifyWebhook(payload, sign("", payload)) {
			t.Error("expected empty secret to fail verification even for a signature computed with the empty key")
		}
	})

	t.Run("empty payload", func(t *testing.T) {
		payload := []byte{}

		if !client.VerifyWebhook(payload, sign(secret, payload)) {
			t.Error("expected empty payload with valid signature to pass")
		}
	})
}

// TestNormalizeStatus checks the status mapping applied by normalizeStatus
// for every entry in the table, including that an unrecognized status is
// returned unchanged.
func TestNormalizeStatus(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		{"on_trial", "active"},
		{"active", "active"},
		{"paused", "past_due"},
		{"past_due", "past_due"},
		{"unpaid", "past_due"},
		{"cancelled", "cancelled"},
		{"expired", "cancelled"},
		{"unknown_status", "unknown_status"},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got := normalizeStatus(tt.input)
			if got != tt.expected {
				t.Errorf("normalizeStatus(%q) = %q, want %q", tt.input, got, tt.expected)
			}
		})
	}
}

// TestParseWebhookEvent checks that ParseWebhookEvent decodes the event name
// and custom data from a well-formed payload and returns an error for a
// payload that is not valid JSON.
func TestParseWebhookEvent(t *testing.T) {
	client := &LemonClient{}

	t.Run("valid event", func(t *testing.T) {
		payload := []byte(`{
			"meta": {
				"event_name": "subscription_created",
				"custom_data": {"tenant_id": "tenant-123", "user_id": "user-456"}
			},
			"data": {"id": "sub-789"}
		}`)

		event, err := client.ParseWebhookEvent(payload)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		if event.Meta.EventName != "subscription_created" {
			t.Errorf("EventName = %q, want %q", event.Meta.EventName, "subscription_created")
		}
		if event.Meta.CustomData["tenant_id"] != "tenant-123" {
			t.Errorf("tenant_id = %q, want %q", event.Meta.CustomData["tenant_id"], "tenant-123")
		}
		// The fixture carries a second custom_data key; assert it too so a
		// decoder that drops entries is caught.
		if event.Meta.CustomData["user_id"] != "user-456" {
			t.Errorf("user_id = %q, want %q", event.Meta.CustomData["user_id"], "user-456")
		}
	})

	t.Run("invalid JSON", func(t *testing.T) {
		payload := []byte(`{invalid json}`)

		_, err := client.ParseWebhookEvent(payload)
		if err == nil {
			t.Error("expected error for invalid JSON")
		}
	})
}

// TestGetSubscriptionData checks that GetSubscriptionData decodes the ID,
// customer ID and status from a subscription object held in the event's Data.
func TestGetSubscriptionData(t *testing.T) {
	event := &WebhookEvent{
		Data: []byte(`{
			"id": "sub-123",
			"attributes": {
				"customer_id": 456,
				"variant_name": "Pro Monthly",
				"user_email": "test@example.com",
				"status": "active"
			}
		}`),
	}

	data, err := event.GetSubscriptionData()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if data.ID != "sub-123" {
		t.Errorf("ID = %q, want %q", data.ID, "sub-123")
	}
	if data.Attributes.CustomerID != 456 {
		t.Errorf("CustomerID = %d, want %d", data.Attributes.CustomerID, 456)
	}
	if data.Attributes.Status != "active" {
		t.Errorf("Status = %q, want %q", data.Attributes.Status, "active")
	}
}

// TestGetOrderData checks that GetOrderData decodes the ID and USD total from
// an order object held in the event's Data.
func TestGetOrderData(t *testing.T) {
	event := &WebhookEvent{
		Data: []byte(`{
			"id": "order-123",
			"attributes": {
				"user_name": "John Doe",
				"user_email": "john@example.com",
				"total_usd": 500
			}
		}`),
	}

	data, err := event.GetOrderData()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if data.ID != "order-123" {
		t.Errorf("ID = %q, want %q", data.ID, "order-123")
	}
	if data.Attributes.TotalUsd != 500 {
		t.Errorf("TotalUsd = %d, want %d", data.Attributes.TotalUsd, 500)
	}
}