mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
115 lines
4.7 KiB
Python
115 lines
4.7 KiB
Python
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from sqlalchemy.dialects.postgresql import insert
|
||
|
|
from sqlalchemy.future import select
|
||
|
|
|
||
|
|
from api.db.base_client import BaseDBClient
|
||
|
|
from api.db.models import (
|
||
|
|
APIKeyModel,
|
||
|
|
OrganizationModel,
|
||
|
|
organization_users_association,
|
||
|
|
)
|
||
|
|
from api.utils.api_key import generate_api_key
|
||
|
|
|
||
|
|
|
||
|
|
class OrganizationClient(BaseDBClient):
    """Database operations for organizations and their user memberships."""

    async def get_organization_by_id(
        self, organization_id: int
    ) -> Optional[OrganizationModel]:
        """Get an organization by its ID.

        Args:
            organization_id: Value matched against ``OrganizationModel.id``.

        Returns:
            The first matching organization, or ``None`` if there is none.
        """
        query = select(OrganizationModel).where(
            OrganizationModel.id == organization_id
        )
        async with self.async_session() as session:
            result = await session.execute(query)
            return result.scalars().first()

    async def get_or_create_organization_by_provider_id(
        self, org_provider_id: str, user_id: int
    ) -> tuple[OrganizationModel, bool]:
        """Get an existing organization by provider_id or create a new one.

        When the organization is created by this call, a default API key is
        created for it in the same transaction, so the organization and its
        key are persisted together or not at all.

        Args:
            org_provider_id: Provider-side identifier of the organization.
            user_id: Recorded as ``created_by`` on the default API key.

        Returns:
            A tuple of (organization, was_created) where was_created is True
            if the organization was created in this call, False if it already
            existed (including when a concurrent request created it first).

        Raises:
            ValueError: If the organization can be neither inserted nor
                fetched afterwards.
        """
        lookup = select(OrganizationModel).where(
            OrganizationModel.provider_id == org_provider_id
        )

        async with self.async_session() as session:
            # Fast path: the organization already exists.
            organization = (await session.execute(lookup)).scalars().first()
            if organization is not None:
                return organization, False

            # INSERT ... ON CONFLICT DO NOTHING lets the database arbitrate
            # between concurrent creators; the loser's insert is a no-op.
            # NOTE(review): relies on a unique index/constraint on
            # provider_id -- confirm against the model definition.
            insert_stmt = (
                insert(OrganizationModel.__table__)
                .values(
                    provider_id=org_provider_id,
                    created_at=datetime.now(timezone.utc),
                )
                .on_conflict_do_nothing(index_elements=["provider_id"])
            )
            insert_result = await session.execute(insert_stmt)

            # rowcount > 0 means this call inserted the row; 0 means conflict.
            was_created = insert_result.rowcount > 0

            # Fetch the row (ours, or the one that won the race).
            organization = (await session.execute(lookup)).scalars().first()
            if organization is None:
                # This should never happen, but handle it just in case
                error_msg = f"Failed to create or fetch organization with provider_id {org_provider_id}"
                raise ValueError(error_msg)

            # Only the call that actually created the organization adds the
            # default API key. The raw key from generate_api_key() is
            # discarded; only its hash and prefix are stored.
            if was_created:
                _, key_hash, key_prefix = generate_api_key()
                session.add(
                    APIKeyModel(
                        organization_id=organization.id,
                        name="Default API Key",
                        key_hash=key_hash,
                        key_prefix=key_prefix,
                        is_active=True,
                        created_by=user_id,
                    )
                )

            # Single commit: the organization row and its default key become
            # visible together. Committing the insert separately could leave
            # an organization without a default key if key creation failed,
            # and later calls would never repair it (was_created is False).
            await session.commit()

            await session.refresh(organization)
            return organization, was_created

    async def add_user_to_organization(
        self, user_id: int, organization_id: int
    ) -> None:
        """Ensure that a user is linked to an organization (many-to-many).

        The association is created only if it does not already exist.
        Uses INSERT ... ON CONFLICT DO NOTHING to handle race conditions.

        Args:
            user_id: ID of the user to link.
            organization_id: ID of the organization to link the user to.
        """
        # With no conflict target given, any uniqueness violation on the
        # association table turns the insert into a no-op.
        # NOTE(review): presumably the (user_id, organization_id) primary key
        # is what triggers the conflict -- confirm against the table.
        link_stmt = (
            insert(organization_users_association)
            .values(user_id=user_id, organization_id=organization_id)
            .on_conflict_do_nothing()
        )

        async with self.async_session() as session:
            await session.execute(link_stmt)
            await session.commit()
|