dograh/api/services/auth/depends.py
2026-06-19 20:37:06 +05:30

511 lines
19 KiB
Python

from typing import Annotated, Optional
import httpx
from fastapi import Depends, Header, HTTPException, Query, WebSocket
from loguru import logger
from pydantic import ValidationError
from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.posthog_client import (
capture_event,
group_identify,
set_person_properties,
)
from api.utils.auth import decode_jwt_token
# Group type passed to the PostHog helpers (group_identify / capture_event
# "groups" mapping) for organization-level analytics in this module.
POSTHOG_ORGANIZATION_GROUP_TYPE = "organization"
# Group property key under which this module records whether an organization
# uses MPS billing v2.
POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY = "uses_mps_billing_v2"
async def get_user(
    authorization: Annotated[str | None, Header()] = None,
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
) -> UserModel:
    """Resolve the authenticated user for an HTTP request.

    Resolution order:
      1. ``X-API-Key`` header, when present (takes precedence).
      2. Local JWT auth when ``AUTH_PROVIDER == "local"``.
      3. Stack Auth: validate the caller, require a selected team, then
         get-or-create the local user and organization rows and keep the
         user's selected organization in sync.

    When the organization was just created, the hosted MPS billing account is
    initialised (best effort) and a default model configuration is stored if
    the user has none.

    Args:
        authorization: Raw ``Authorization`` header value, if any.
        x_api_key: Raw ``X-API-Key`` header value, if any.

    Returns:
        The local ``UserModel`` for the caller.

    Raises:
        HTTPException: 401 when Stack Auth returns no user, 400 when no team
            is selected, 500 when persisting the user or the organization
            mapping fails. Errors raised by the API-key / local-auth helpers
            propagate unchanged.
    """
    # ------------------------------------------------------------------
    # Check if API key is provided (takes precedence)
    # ------------------------------------------------------------------
    if x_api_key:
        return await _handle_api_key_auth(x_api_key)

    # ------------------------------------------------------------------
    # Check if we're using local (email/password) auth
    # ------------------------------------------------------------------
    if AUTH_PROVIDER == "local":
        return await _handle_oss_auth(authorization)

    # ------------------------------------------------------------------
    # 1. Validate and fetch the authenticated Stack user
    # ------------------------------------------------------------------
    stack_user = await stackauth.get_user(authorization)
    if stack_user is None:
        raise HTTPException(status_code=401, detail="Unauthorized")

    # ------------------------------------------------------------------
    # 2. Ensure the user has a team (Stack "selected_team_id")
    # ------------------------------------------------------------------
    selected_team_id: str | None = stack_user.get("selected_team_id")
    if not selected_team_id and stack_user.get("selected_team"):
        selected_team_id = stack_user["selected_team"].get("id")
    if not selected_team_id:
        raise HTTPException(status_code=400, detail="No team selected")

    # ------------------------------------------------------------------
    # 3. Persist/Fetch the local User model
    # ------------------------------------------------------------------
    try:
        (
            user_model,
            user_was_created,
        ) = await db_client.get_or_create_user_by_provider_id(stack_user["id"])

        # Sync email from Stack Auth if available and not already set.
        # Only a verified primary email is copied over.
        stack_email = stack_user.get("primary_email_verified") and stack_user.get(
            "primary_email"
        )
        if stack_email and user_model.email != stack_email:
            await db_client.update_user_email(user_model.id, stack_email)
            user_model.email = stack_email

        if user_was_created:
            capture_event(
                distinct_id=str(stack_user["id"]),
                event=PostHogEvent.SIGNED_UP,
                properties={
                    "auth_provider": "stack",
                },
            )
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(
            status_code=500, detail=f"Error while creating user from database {e}"
        ) from e

    # ------------------------------------------------------------------
    # 4. Persist Organization (team) and mapping in local database
    # ------------------------------------------------------------------
    try:
        (
            organization,
            org_was_created,
        ) = await db_client.get_or_create_organization_by_provider_id(
            org_provider_id=selected_team_id, user_id=user_model.id
        )

        if org_was_created:
            _sync_created_organization_to_posthog(
                organization=organization,
                stack_user=stack_user,
            )

        # Check if user's selected organization differs from the current organization
        if user_model.selected_organization_id != organization.id:
            await db_client.add_user_to_organization(user_model.id, organization.id)
            # Update user's selected organization
            await db_client.update_user_selected_organization(
                user_model.id, organization.id
            )
            # Update the user_model object to reflect the change
            user_model.selected_organization_id = organization.id
            _associate_user_with_posthog_organization(
                user=user_model,
                organization=organization,
                stack_user=stack_user,
                org_was_created=org_was_created,
            )

        # Only create default configuration if organization was just created
        # This prevents race conditions where multiple concurrent requests
        # might try to create configurations
        if org_was_created:
            try:
                await ensure_hosted_mps_billing_account_v2(
                    organization.id,
                    created_by=str(stack_user["id"]),
                )
            except Exception:
                # Best effort: a billing failure must not block sign-in.
                # loguru does not honour the stdlib ``exc_info`` keyword, so
                # the traceback is requested through ``opt(exception=True)``.
                logger.opt(exception=True).warning(
                    "Failed to initialize hosted MPS billing account for "
                    "organization {}",
                    organization.id,
                )

            existing_cfg = await db_client.get_user_configurations(user_model.id)
            if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
                mps_config = await create_user_configuration_with_mps_key(
                    user_model.id, organization.id, stack_user["id"]
                )
                if mps_config:
                    await db_client.update_user_configuration(
                        user_model.id, mps_config
                    )

                    # Imported locally, as in the original code.
                    from api.enums import OrganizationConfigurationKey
                    from api.services.configuration.ai_model_configuration import (
                        convert_legacy_ai_model_configuration_to_v2,
                    )

                    model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
                        mps_config
                    )
                    await db_client.upsert_configuration(
                        organization.id,
                        OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
                        model_config_v2.model_dump(mode="json", exclude_none=True),
                    )
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to map user to organization: {exc}",
        ) from exc

    return user_model
def _sync_created_organization_to_posthog(
    *,
    organization,
    stack_user: dict | None = None,
    created_by_provider_id: str | None = None,
    uses_mps_billing_v2: bool | None = None,
) -> None:
    """Report a newly-created organization to PostHog.

    Calls ``group_identify`` for the organization group and, when a creator
    id is known, also captures an ``ORGANIZATION_CREATED`` event for it.
    The creator id is ``created_by_provider_id`` if given, otherwise the
    ``"id"`` entry of ``stack_user`` (stringified) when present.

    Any exception is logged and swallowed so analytics never break the caller.
    """
    try:
        org_id = int(organization.id)
        org_provider_id = getattr(organization, "provider_id", None)

        # Resolve who created the organization: explicit argument first,
        # then fall back to the Stack user payload.
        creator = created_by_provider_id
        if creator is None and stack_user:
            stack_id = stack_user.get("id")
            if stack_id:
                creator = str(stack_id)

        group_properties = {
            "organization_id": org_id,
            "organization_provider_id": org_provider_id,
            "auth_provider": "stack",
        }
        if creator:
            group_properties["created_by_provider_id"] = creator
        if uses_mps_billing_v2 is not None:
            group_properties[POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY] = (
                uses_mps_billing_v2
            )

        group_key = str(org_id)
        group_identify(
            POSTHOG_ORGANIZATION_GROUP_TYPE,
            group_key,
            group_properties,
            distinct_id=creator,
        )

        # The creation event needs a distinct id; skip it when none is known.
        if not creator:
            return
        capture_event(
            distinct_id=creator,
            event=PostHogEvent.ORGANIZATION_CREATED,
            properties=group_properties,
            groups={POSTHOG_ORGANIZATION_GROUP_TYPE: str(org_id)},
        )
    except Exception:
        logger.exception("Failed to sync created organization to PostHog")
def _sync_posthog_organization_group_properties(
    *,
    organization,
    uses_mps_billing_v2: bool | None = None,
) -> None:
    """Push organization group properties to PostHog via ``group_identify``.

    No ``distinct_id`` is passed and no event is captured here. The billing
    flag is only included when ``uses_mps_billing_v2`` is not ``None``.
    Any exception is logged and swallowed.
    """
    try:
        org_id = int(organization.id)
        group_properties = {
            "organization_id": org_id,
            "organization_provider_id": getattr(organization, "provider_id", None),
            "auth_provider": "stack",
        }
        include_billing_flag = uses_mps_billing_v2 is not None
        if include_billing_flag:
            group_properties[POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY] = (
                uses_mps_billing_v2
            )

        group_key = str(org_id)
        group_identify(POSTHOG_ORGANIZATION_GROUP_TYPE, group_key, group_properties)
    except Exception:
        logger.exception("Failed to sync organization group properties to PostHog")
def _sync_posthog_organization_mps_billing_v2_status(
    organization_id: int,
    *,
    uses_mps_billing_v2: bool,
) -> None:
    """Send the organization's MPS billing v2 flag to PostHog.

    Calls ``group_identify`` with a single property. ``organization_id`` is
    coerced with ``int()`` before being used as the group key. Any exception
    (including a failed coercion) is logged and swallowed.
    """
    try:
        normalized_id = int(organization_id)
        billing_properties = {
            POSTHOG_ORGANIZATION_USES_MPS_BILLING_V2_PROPERTY: uses_mps_billing_v2
        }
        group_identify(
            POSTHOG_ORGANIZATION_GROUP_TYPE,
            str(normalized_id),
            billing_properties,
        )
    except Exception:
        logger.exception("Failed to sync organization billing status to PostHog")
def _associate_user_with_posthog_organization(
    *,
    user: UserModel,
    organization,
    stack_user: dict | None = None,
    user_distinct_id: str | None = None,
    org_was_created: bool,
    organization_ids: list[int] | None = None,
    selected_organization_id: int | None = None,
    selected_organization_provider_id: str | None = None,
) -> None:
    """Record in PostHog that a user belongs to an organization.

    Sets person properties for the user and captures an
    ``ORGANIZATION_USER_ASSOCIATED`` event grouped under the organization.

    The distinct id is ``user_distinct_id`` when given, otherwise the
    ``"id"`` entry of ``stack_user``, otherwise ``user.provider_id``.
    The "selected organization" values fall back to ``organization`` when the
    explicit overrides are falsy. Any exception is logged and swallowed.
    """
    try:
        org_id = int(organization.id)
        org_provider_id = getattr(organization, "provider_id", None)

        distinct_id = user_distinct_id
        if distinct_id is None:
            has_stack_id = bool(stack_user and stack_user.get("id"))
            distinct_id = (
                str(stack_user["id"]) if has_stack_id else str(user.provider_id)
            )

        person_properties = {
            "user_id": user.id,
            "user_provider_id": distinct_id,
            "selected_organization_id": selected_organization_id or org_id,
            "selected_organization_provider_id": (
                selected_organization_provider_id or org_provider_id
            ),
        }
        if organization_ids is not None:
            person_properties["organization_ids"] = organization_ids
        if user.email:
            person_properties["email"] = user.email
        set_person_properties(distinct_id, person_properties)

        capture_event(
            distinct_id=distinct_id,
            event=PostHogEvent.ORGANIZATION_USER_ASSOCIATED,
            properties={
                "user_id": user.id,
                "organization_id": org_id,
                "organization_provider_id": org_provider_id,
                "auth_provider": "stack",
                "organization_was_created": org_was_created,
            },
            groups={POSTHOG_ORGANIZATION_GROUP_TYPE: str(org_id)},
        )
    except Exception:
        logger.exception("Failed to associate user with PostHog organization")
async def get_user_with_selected_organization(
    user: Annotated[UserModel, Depends(get_user)],
) -> UserModel:
    """Return the authenticated user, requiring a selected organization.

    Raises:
        HTTPException: 400 when ``user.selected_organization_id`` is falsy.
    """
    if user.selected_organization_id:
        return user
    raise HTTPException(status_code=400, detail="No organization selected")
async def _handle_oss_auth(authorization: str | None) -> UserModel:
    """Authenticate a request in OSS (local email/password) mode.

    Decodes the JWT carried in the ``Authorization`` header and loads the
    user whose id is stored in the token's ``sub`` claim.

    Args:
        authorization: Raw header value, with or without a leading
            ``"Bearer "`` scheme.

    Returns:
        The matching ``UserModel``.

    Raises:
        HTTPException: 401 when the header is missing or empty, the token
            cannot be decoded, the lookup fails, or no user matches.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header required")

    # Strip only a *leading* "Bearer " scheme. ``str.replace`` would also
    # delete any later occurrence of that text inside the credential.
    bearer_prefix = "Bearer "
    token = (
        authorization[len(bearer_prefix):]
        if authorization.startswith(bearer_prefix)
        else authorization
    )
    if not token:
        raise HTTPException(status_code=401, detail="Invalid authorization token")

    try:
        payload = decode_jwt_token(token)
        user = await db_client.get_user_by_id(int(payload["sub"]))
    except HTTPException:
        # Let deliberate HTTP errors from the helpers through unchanged.
        raise
    except Exception as exc:
        # Any decoding / claim / lookup failure is reported as a bad token;
        # the cause is chained so it stays visible in tracebacks.
        raise HTTPException(
            status_code=401, detail="Invalid or expired token"
        ) from exc

    if user:
        return user
    raise HTTPException(status_code=401, detail="User not found")
async def _handle_api_key_auth(api_key: str) -> UserModel:
    """Authenticate a request that carries an ``X-API-Key`` header.

    Looks up the key, loads the user recorded as its creator and sets that
    user's ``selected_organization_id`` to the key's organization before
    returning it.

    Raises:
        HTTPException: 401 when the key does not validate, has no creator,
            or its creator cannot be loaded.
    """
    key_record = await db_client.validate_api_key(api_key)
    if not key_record:
        raise HTTPException(status_code=401, detail="Invalid or expired API key")

    creator_id = key_record.created_by
    if not creator_id:
        raise HTTPException(status_code=401, detail="API key has no associated user")

    owner = await db_client.get_user_by_id(creator_id)
    if not owner:
        raise HTTPException(status_code=401, detail="API key owner not found")

    # Requests made with this key act in the key's organization.
    owner.selected_organization_id = key_record.organization_id

    logger.debug(
        f"Authenticated via API key: {key_record.key_prefix}... "
        f"(user_id={owner.id}, org_id={key_record.organization_id})"
    )
    return owner
async def create_user_configuration_with_mps_key(
    user_id: int, organization_id: int, user_provider_id: str
) -> Optional[EffectiveAIModelConfiguration]:
    """Create a default model configuration backed by a new MPS service key.

    Requests a service key from ``{MPS_API_URL}/api/v1/service-keys/`` and
    builds a configuration that uses it for the LLM, TTS and STT services.

    Args:
        user_id: The user's ID (not used by this function's body).
        organization_id: The organization's ID; sent to MPS in
            authenticated (non-local) mode.
        user_provider_id: The user's provider ID (sent as ``created_by``).

    Returns:
        An ``EffectiveAIModelConfiguration`` using the new key, or ``None``
        when MPS does not answer with status 200 or the response has no
        ``service_key``.

    Raises:
        ValueError: In non-local mode when ``DOGRAH_MPS_SECRET_KEY`` is unset.
        Exception: Errors from the HTTP request or from parsing the response
            are not caught here and propagate to the caller.
    """
    url = f"{MPS_API_URL}/api/v1/service-keys/"

    if AUTH_PROVIDER == "local":
        # For local auth mode, create a temporary service key without authentication
        payload = {
            "name": "Default Dograh Model Service Key",
            "description": "Auto-generated key for OSS user",
            "expires_in_days": 7,  # Short-lived for OSS
            "created_by": user_provider_id,
        }
        headers = None
    else:
        # For authenticated mode, use the secret key and organization ID
        if not DOGRAH_MPS_SECRET_KEY:
            logger.warning(
                "Warning: DOGRAH_MPS_SECRET_KEY not set for authenticated mode"
            )
            # pydantic's ValidationError cannot be built from a plain message
            # (doing so raises TypeError), so a ValueError is raised instead.
            raise ValueError("Missing DOGRAH_MPS_SECRET_KEY in non oss mode")
        payload = {
            "name": "Default Dograh Model Service Key",
            "description": f"Auto-generated key for organization {organization_id}",
            "organization_id": organization_id,
            "expires_in_days": 90,  # Longer-lived for authenticated users
            "created_by": user_provider_id,
        }
        headers = {"X-Secret-Key": DOGRAH_MPS_SECRET_KEY}

    async with httpx.AsyncClient() as client:
        response = await client.post(
            url,
            json=payload,
            headers=headers,
            timeout=10.0,
        )

    if response.status_code != 200:
        logger.warning(
            f"Failed to get MPS service key: {response.status_code} - {response.text}"
        )
        return None

    service_key = response.json().get("service_key")
    if not service_key:
        # Previously this case returned None without leaving any trace.
        logger.warning("MPS service key response did not contain a service_key")
        return None

    # Create configuration JSON for storage in database
    # The service_factory will use this to instantiate actual services
    configuration = {
        "llm": {
            "provider": ServiceProviders.DOGRAH.value,
            "api_key": [service_key],
            "model": "default",
        },
        "tts": {
            "provider": ServiceProviders.DOGRAH.value,
            "api_key": [service_key],
            "model": "default",
            "voice": "default",
        },
        "stt": {
            "provider": ServiceProviders.DOGRAH.value,
            "api_key": [service_key],
            "model": "default",
        },
    }
    return EffectiveAIModelConfiguration(**configuration)
async def get_superuser(
    authorization: Annotated[str | None, Header()] = None,
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
) -> UserModel:
    """Authenticate the caller and require superuser privileges.

    Delegates authentication to ``get_user`` with the same header values.

    Raises:
        HTTPException: 403 when the authenticated user is not a superuser;
            whatever ``get_user`` raises when authentication fails.
    """
    current_user = await get_user(authorization, x_api_key)
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=403, detail="Access denied. Superuser privileges required."
    )
def _ws_close_reason(detail: object, limit: int = 123) -> str:
    """Return ``detail`` as text that fits in a WebSocket close frame.

    A close frame's payload is capped at 125 bytes, 2 of which hold the
    status code, leaving 123 bytes of UTF-8 for the reason. Longer text is
    cut at the byte limit; a partially cut trailing character is dropped.
    """
    text = str(detail)
    encoded = text.encode("utf-8")
    if len(encoded) <= limit:
        return text
    return encoded[:limit].decode("utf-8", errors="ignore")


async def get_user_ws(
    websocket: WebSocket,
    token: str | None = Query(None),
    api_key: str | None = Query(None, alias="api_key"),
) -> UserModel:
    """WebSocket authentication dependency.

    Authenticates with ``api_key`` (takes precedence) or ``token`` taken from
    the query string, reusing ``get_user``. A token is passed on as a
    ``Bearer`` authorization value.

    On failure the socket is closed with code 1008 and the error is re-raised.

    Raises:
        HTTPException: 401 when neither credential is supplied; otherwise
            whatever ``get_user`` raises.
    """
    if not token and not api_key:
        await websocket.close(code=1008, reason="Missing authentication token")
        raise HTTPException(status_code=401, detail="Missing authentication token")

    try:
        # API key takes precedence
        if api_key:
            return await get_user(None, api_key)
        # Use the same logic as get_user but with token from query
        return await get_user(f"Bearer {token}", None)
    except HTTPException as e:
        # ``detail`` may be long (some 500 details embed exception text) or
        # not a string at all, so it is coerced and truncated to a size a
        # close frame can carry.
        await websocket.close(code=1008, reason=_ws_close_reason(e.detail))
        raise