From a67c984e1a20b1e8e3fbda94d65ce3a0272e66a2 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Fri, 19 Jun 2026 20:37:06 +0530 Subject: [PATCH] feat: sync groups in posthog --- api/enums.py | 2 + api/routes/organization.py | 17 +- api/services/auth/depends.py | 161 ++++++++++++- api/services/posthog_client.py | 75 +++++- api/tests/test_ai_model_configuration_v2.py | 9 +- api/tests/test_auth_depends.py | 243 +++++++++++++++++++- 6 files changed, 496 insertions(+), 11 deletions(-) diff --git a/api/enums.py b/api/enums.py index 23f5852d..2b8ac637 100644 --- a/api/enums.py +++ b/api/enums.py @@ -174,3 +174,5 @@ class PostHogEvent(str, Enum): AGENT_EMBEDDED = "agent_embedded" SIGNED_UP = "signed_up" SIGNED_IN = "signed_in" + ORGANIZATION_CREATED = "organization_created" + ORGANIZATION_USER_ASSOCIATED = "organization_user_associated" diff --git a/api/routes/organization.py b/api/routes/organization.py index f64d3cfd..8abfa777 100644 --- a/api/routes/organization.py +++ b/api/routes/organization.py @@ -38,7 +38,11 @@ from api.schemas.telephony_phone_number import ( PhoneNumberUpdateRequest, ProviderSyncStatus, ) -from api.services.auth.depends import get_user, get_user_with_selected_organization +from api.services.auth.depends import ( + _sync_posthog_organization_mps_billing_v2_status, + get_user, + get_user_with_selected_organization, +) from api.services.configuration.ai_model_configuration import ( check_for_masked_keys_in_ai_model_configuration_v2, compile_ai_model_configuration_v2, @@ -373,9 +377,10 @@ async def migrate_model_configuration_v2( except ValueError as exc: raise HTTPException(status_code=422, detail=exc.args[0]) + billing_account_status = None if DEPLOYMENT_MODE != "oss": try: - await ensure_hosted_mps_billing_account_v2( + billing_account_status = await ensure_hosted_mps_billing_account_v2( organization_id, created_by=str(user.provider_id), ) @@ -398,6 +403,14 @@ async def migrate_model_configuration_v2( organization_id=organization_id, fallback_user_config=legacy, ) + if DEPLOYMENT_MODE != "oss": + _sync_posthog_organization_mps_billing_v2_status( + organization_id, + uses_mps_billing_v2=bool( + billing_account_status + and billing_account_status.get("billing_mode") == "v2" + ), + ) return await _model_configuration_v2_response( user=user, configuration=configuration, diff --git a/api/services/auth/depends.py b/api/services/auth/depends.py index 019dbc2f..94efae79 100644 --- a/api/services/auth/depends.py +++ b/api/services/auth/depends.py @@ -13,9 +13,16 @@ from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.auth.stack_auth import stackauth from api.services.configuration.registry import ServiceProviders from api.services.mps_billing import ensure_hosted_mps_billing_account_v2 -from api.services.posthog_client import capture_event +from api.services.posthog_client import ( + capture_event, + group_identify, + set_person_properties, +) from api.utils.auth import decode_jwt_token +POSTHOG_ORGANIZATION_GROUP_TYPE = "organization" +POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY = "uses_mps_billing_v2" + async def get_user( authorization: Annotated[str | None, Header()] = None, @@ -94,6 +101,11 @@ async def get_user( ) = await db_client.get_or_create_organization_by_provider_id( org_provider_id=selected_team_id, user_id=user_model.id ) + if org_was_created: + _sync_created_organization_to_posthog( + organization=organization, + stack_user=stack_user, + ) # Check if user's selected organization differs from the current organization if user_model.selected_organization_id != organization.id: @@ -107,6 +119,13 @@ async def get_user( # Update the user_model object to reflect the change user_model.selected_organization_id = organization.id + _associate_user_with_posthog_organization( + user=user_model, + organization=organization, + stack_user=stack_user, + org_was_created=org_was_created, + ) + # Only create default configuration if organization was just created # This prevents race conditions where multiple concurrent requests # might try to create configurations @@ -156,6 +175,146 @@ async def get_user( return user_model +def _sync_created_organization_to_posthog( + *, + organization, + stack_user: dict | None = None, + created_by_provider_id: str | None = None, + uses_mps_billing_v2: bool | None = None, +) -> None: + """Create/update the PostHog organization group for a newly-created org.""" + try: + organization_id = int(organization.id) + organization_provider_id = getattr(organization, "provider_id", None) + created_by = created_by_provider_id + if created_by is None and stack_user and stack_user.get("id"): + created_by = str(stack_user["id"]) + properties = { + "organization_id": organization_id, + "organization_provider_id": organization_provider_id, + "auth_provider": "stack", + } + if created_by: + properties["created_by_provider_id"] = created_by + if uses_mps_billing_v2 is not None: + properties[POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY] = ( + uses_mps_billing_v2 + ) + + group_identify( + POSTHOG_ORGANIZATION_GROUP_TYPE, + str(organization_id), + properties, + distinct_id=created_by, + ) + if created_by: + capture_event( + distinct_id=created_by, + event=PostHogEvent.ORGANIZATION_CREATED, + properties=properties, + groups={POSTHOG_ORGANIZATION_GROUP_TYPE: str(organization_id)}, + ) + except Exception: + logger.exception("Failed to sync created organization to PostHog") + + +def _sync_posthog_organization_group_properties( + *, + organization, + uses_mps_billing_v2: bool | None = None, +) -> None: + """Update PostHog organization group properties without creating a person.""" + try: + organization_id = int(organization.id) + properties = { + "organization_id": organization_id, + "organization_provider_id": getattr(organization, "provider_id", None), + "auth_provider": "stack", + } + if uses_mps_billing_v2 is not None: + properties[POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY] = ( + uses_mps_billing_v2 + ) + + group_identify( + POSTHOG_ORGANIZATION_GROUP_TYPE, + str(organization_id), + properties, + ) + except Exception: + logger.exception("Failed to sync organization group properties to PostHog") + + +def _sync_posthog_organization_mps_billing_v2_status( + organization_id: int, + *, + uses_mps_billing_v2: bool, +) -> None: + """Update the PostHog organization group with current MPS billing status.""" + try: + organization_id = int(organization_id) + group_identify( + POSTHOG_ORGANIZATION_GROUP_TYPE, + str(organization_id), + {POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY: uses_mps_billing_v2}, + ) + except Exception: + logger.exception("Failed to sync organization billing status to PostHog") + + +def _associate_user_with_posthog_organization( + *, + user: UserModel, + organization, + stack_user: dict | None = None, + user_distinct_id: str | None = None, + org_was_created: bool, + organization_ids: list[int] | None = None, + selected_organization_id: int | None = None, + selected_organization_provider_id: str | None = None, +) -> None: + """Attach the Stack user to the PostHog organization group.""" + try: + organization_id = int(organization.id) + organization_provider_id = getattr(organization, "provider_id", None) + if user_distinct_id is None: + if stack_user and stack_user.get("id"): + user_distinct_id = str(stack_user["id"]) + else: + user_distinct_id = str(user.provider_id) + selected_org_id = selected_organization_id or organization_id + selected_org_provider_id = ( + selected_organization_provider_id or organization_provider_id + ) + person_properties = { + "user_id": user.id, + "user_provider_id": user_distinct_id, + "selected_organization_id": selected_org_id, + "selected_organization_provider_id": selected_org_provider_id, + } + if organization_ids is not None: + person_properties["organization_ids"] = organization_ids + if user.email: + person_properties["email"] = user.email + set_person_properties(user_distinct_id, person_properties) + event_properties = { + "user_id": user.id, + "organization_id": organization_id, + "organization_provider_id": organization_provider_id, + "auth_provider": "stack", + "organization_was_created": org_was_created, + } + + capture_event( + distinct_id=user_distinct_id, + event=PostHogEvent.ORGANIZATION_USER_ASSOCIATED, + properties=event_properties, + groups={POSTHOG_ORGANIZATION_GROUP_TYPE: str(organization_id)}, + ) + except Exception: + logger.exception("Failed to associate user with PostHog organization") + + async def get_user_with_selected_organization( user: Annotated[UserModel, Depends(get_user)], ) -> UserModel: diff --git a/api/services/posthog_client.py b/api/services/posthog_client.py index 1b4d5e03..96e10280 100644 --- a/api/services/posthog_client.py +++ b/api/services/posthog_client.py @@ -1,7 +1,9 @@ +from typing import Any, Optional + from loguru import logger from posthog import Posthog -from api.constants import ENABLE_TELEMETRY, POSTHOG_API_KEY, POSTHOG_HOST +from api.constants import POSTHOG_API_KEY, POSTHOG_HOST _posthog_client: Posthog | None = None @@ -9,23 +11,84 @@ _posthog_client: Posthog | None = None def get_posthog() -> Posthog | None: """Return the lazily-initialised PostHog client, or None if not configured.""" global _posthog_client - if _posthog_client is None and POSTHOG_API_KEY and ENABLE_TELEMETRY: + if _posthog_client is None and POSTHOG_API_KEY: _posthog_client = Posthog(POSTHOG_API_KEY, host=POSTHOG_HOST) return _posthog_client +def shutdown_posthog() -> None: + """Flush queued PostHog messages before a short-lived process exits.""" + client = get_posthog() + if not client: + return + try: + client.shutdown() + except Exception: + logger.exception("Failed to shut down PostHog client") + + +def flush_posthog() -> None: + """Flush queued PostHog messages without shutting down the client.""" + client = get_posthog() + if not client: + return + try: + client.flush() + except Exception: + logger.exception("Failed to flush PostHog client") + + def capture_event( distinct_id: str, event: str, - properties: dict | None = None, + properties: dict[str, Any] | None = None, + groups: Optional[dict[str, str]] = None, ) -> None: """Fire a PostHog event. Silently no-ops if PostHog is not configured.""" client = get_posthog() if not client: return try: - client.capture( - distinct_id=distinct_id, event=event, properties=properties or {} - ) + kwargs: dict[str, Any] = { + "distinct_id": distinct_id, + "event": event, + "properties": properties or {}, + } + if groups: + kwargs["groups"] = groups + client.capture(**kwargs) except Exception: logger.exception(f"Failed to send PostHog event '{event}'") + + +def group_identify( + group_type: str, + group_key: str, + properties: dict[str, Any], + *, + distinct_id: Optional[str] = None, +) -> None: + """Set PostHog group properties. Silently no-ops if PostHog is not configured.""" + client = get_posthog() + if not client: + return + try: + client.group_identify( + group_type, + group_key, + properties, + distinct_id=distinct_id, + ) + except Exception: + logger.exception("Failed to identify PostHog group") + + +def set_person_properties(distinct_id: str, properties: dict[str, Any]) -> None: + """Set PostHog person properties. Silently no-ops if PostHog is not configured.""" + client = get_posthog() + if not client: + return + try: + client.set(distinct_id=distinct_id, properties=properties) + except Exception: + logger.exception("Failed to set PostHog person properties") diff --git a/api/tests/test_ai_model_configuration_v2.py b/api/tests/test_ai_model_configuration_v2.py index 57f7cf83..e39f85e1 100644 --- a/api/tests/test_ai_model_configuration_v2.py +++ b/api/tests/test_ai_model_configuration_v2.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pytest from pydantic import ValidationError @@ -401,6 +401,7 @@ async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing( ensure_billing = AsyncMock(return_value={"billing_mode": "v2"}) upsert = AsyncMock() migrate_workflows = AsyncMock() + sync_posthog_billing = Mock() monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas") monkeypatch.setattr( @@ -438,6 +439,11 @@ async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing( "_model_configuration_v2_response", AsyncMock(return_value=expected_response), ) + monkeypatch.setattr( + organization_routes, + "_sync_posthog_organization_mps_billing_v2_status", + sync_posthog_billing, + ) user = SimpleNamespace( id=7, @@ -456,4 +462,5 @@ async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing( organization_id=42, fallback_user_config=legacy, ) + sync_posthog_billing.assert_called_once_with(42, uses_mps_billing_v2=True) assert response == expected_response diff --git a/api/tests/test_auth_depends.py b/api/tests/test_auth_depends.py index 2f33ff58..8b82d379 100644 --- a/api/tests/test_auth_depends.py +++ b/api/tests/test_auth_depends.py @@ -19,10 +19,13 @@ async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch): provider_id="stack-user-1", selected_organization_id=None, ) - organization = SimpleNamespace(id=42) + organization = SimpleNamespace(id=42, provider_id="team-1") existing_config = SimpleNamespace(llm=object(), tts=None, stt=None) ensure_billing = AsyncMock(return_value={"billing_mode": "v2"}) + group_calls = [] + capture_calls = [] + person_calls = [] monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack") monkeypatch.setattr( @@ -60,9 +63,247 @@ async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch): "ensure_hosted_mps_billing_account_v2", ensure_billing, ) + monkeypatch.setattr( + auth_depends, + "group_identify", + lambda *args, **kwargs: group_calls.append((args, kwargs)), + ) + monkeypatch.setattr( + auth_depends, + "capture_event", + lambda *args, **kwargs: capture_calls.append((args, kwargs)), + ) + monkeypatch.setattr( + auth_depends, + "set_person_properties", + lambda *args, **kwargs: person_calls.append((args, kwargs)), + ) result = await auth_depends.get_user(authorization="Bearer token") assert result is user assert result.selected_organization_id == 42 ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1") + + assert len(group_calls) == 1 + group_args, group_kwargs = group_calls[0] + assert group_args == ( + "organization", + "42", + { + "organization_id": 42, + "organization_provider_id": "team-1", + "auth_provider": "stack", + "created_by_provider_id": "stack-user-1", + }, + ) + assert group_kwargs == {"distinct_id": "stack-user-1"} + + assert len(person_calls) == 1 + person_args, person_kwargs = person_calls[0] + assert person_args == ( + "stack-user-1", + { + "user_id": 7, + "user_provider_id": "stack-user-1", + "selected_organization_id": 42, + "selected_organization_provider_id": "team-1", + }, + ) + assert person_kwargs == {} + + assert len(capture_calls) == 2 + org_created_args, org_created_kwargs = capture_calls[0] + assert org_created_args == () + assert org_created_kwargs["distinct_id"] == "stack-user-1" + assert org_created_kwargs["event"] == auth_depends.PostHogEvent.ORGANIZATION_CREATED + assert org_created_kwargs["groups"] == {"organization": "42"} + assert org_created_kwargs["properties"] == { + "organization_id": 42, + "organization_provider_id": "team-1", + "auth_provider": "stack", + "created_by_provider_id": "stack-user-1", + } + + association_args, association_kwargs = capture_calls[1] + assert association_args == () + assert association_kwargs["distinct_id"] == "stack-user-1" + assert ( + association_kwargs["event"] + == auth_depends.PostHogEvent.ORGANIZATION_USER_ASSOCIATED + ) + assert association_kwargs["groups"] == {"organization": "42"} + assert association_kwargs["properties"] == { + "user_id": 7, + "organization_id": 42, + "organization_provider_id": "team-1", + "auth_provider": "stack", + "organization_was_created": True, + } + + +def test_associate_user_with_posthog_org_supports_backfill_arguments(monkeypatch): + user = SimpleNamespace( + id=7, + email="user@example.com", + provider_id="stack-user-1", + selected_organization_id=99, + ) + organization = SimpleNamespace(id=42, provider_id="team-1") + person_calls = [] + capture_calls = [] + + monkeypatch.setattr( + auth_depends, + "set_person_properties", + lambda *args, **kwargs: person_calls.append((args, kwargs)), + ) + monkeypatch.setattr( + auth_depends, + "capture_event", + lambda *args, **kwargs: capture_calls.append((args, kwargs)), + ) + + auth_depends._associate_user_with_posthog_organization( + user=user, + organization=organization, + user_distinct_id="stack-user-1", + org_was_created=False, + organization_ids=[42, 99], + selected_organization_id=99, + selected_organization_provider_id="team-99", + ) + + assert person_calls == [ + ( + ( + "stack-user-1", + { + "user_id": 7, + "user_provider_id": "stack-user-1", + "selected_organization_id": 99, + "selected_organization_provider_id": "team-99", + "organization_ids": [42, 99], + "email": "user@example.com", + }, + ), + {}, + ) + ] + + assert len(capture_calls) == 1 + _, capture_kwargs = capture_calls[0] + assert capture_kwargs["distinct_id"] == "stack-user-1" + assert ( + capture_kwargs["event"] + == auth_depends.PostHogEvent.ORGANIZATION_USER_ASSOCIATED + ) + assert capture_kwargs["groups"] == {"organization": "42"} + assert capture_kwargs["properties"] == { + "user_id": 7, + "organization_id": 42, + "organization_provider_id": "team-1", + "auth_provider": "stack", + "organization_was_created": False, + } + assert "backfilled" not in capture_kwargs["properties"] + + +def test_sync_created_organization_to_posthog_supports_billing_flag(monkeypatch): + organization = SimpleNamespace(id=42, provider_id="team-1") + group_calls = [] + capture_calls = [] + + monkeypatch.setattr( + auth_depends, + "group_identify", + lambda *args, **kwargs: group_calls.append((args, kwargs)), + ) + monkeypatch.setattr( + auth_depends, + "capture_event", + lambda *args, **kwargs: capture_calls.append((args, kwargs)), + ) + + auth_depends._sync_created_organization_to_posthog( + organization=organization, + created_by_provider_id="stack-user-1", + uses_mps_billing_v2=True, + ) + + _, group_kwargs = group_calls[0] + group_args, _ = group_calls[0] + assert group_args == ( + "organization", + "42", + { + "organization_id": 42, + "organization_provider_id": "team-1", + "auth_provider": "stack", + "created_by_provider_id": "stack-user-1", + "uses_mps_billing_v2": True, + }, + ) + assert group_kwargs == {"distinct_id": "stack-user-1"} + + _, capture_kwargs = capture_calls[0] + assert capture_kwargs["distinct_id"] == "stack-user-1" + assert capture_kwargs["properties"]["uses_mps_billing_v2"] is True + + +def test_sync_posthog_organization_group_properties_has_no_distinct_id(monkeypatch): + organization = SimpleNamespace(id=42, provider_id="team-1") + group_calls = [] + + monkeypatch.setattr( + auth_depends, + "group_identify", + lambda *args, **kwargs: group_calls.append((args, kwargs)), + ) + + auth_depends._sync_posthog_organization_group_properties( + organization=organization, + uses_mps_billing_v2=True, + ) + + assert group_calls == [ + ( + ( + "organization", + "42", + { + "organization_id": 42, + "organization_provider_id": "team-1", + "auth_provider": "stack", + "uses_mps_billing_v2": True, + }, + ), + {}, + ) + ] + + +def test_sync_posthog_organization_mps_billing_v2_status(monkeypatch): + group_calls = [] + + monkeypatch.setattr( + auth_depends, + "group_identify", + lambda *args, **kwargs: group_calls.append((args, kwargs)), + ) + + auth_depends._sync_posthog_organization_mps_billing_v2_status( + 42, + uses_mps_billing_v2=True, + ) + + assert group_calls == [ + ( + ( + "organization", + "42", + {"uses_mps_billing_v2": True}, + ), + {}, + ) + ]