2025-11-15 17:32:37 +05:30
|
|
|
"""Database client for managing embed tokens and sessions."""
|
|
|
|
|
|
|
|
|
|
import secrets
|
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
from sqlalchemy import and_, select, update
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
from api.db.base_client import BaseDBClient
|
|
|
|
|
from api.db.models import EmbedSessionModel, EmbedTokenModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbedTokenClient(BaseDBClient):
|
|
|
|
|
"""Client for managing embed tokens and sessions."""
|
|
|
|
|
|
|
|
|
|
async def create_embed_token(
|
|
|
|
|
self,
|
|
|
|
|
workflow_id: int,
|
|
|
|
|
organization_id: int,
|
|
|
|
|
created_by: int,
|
|
|
|
|
allowed_domains: Optional[List[str]] = None,
|
|
|
|
|
settings: Optional[dict] = None,
|
|
|
|
|
usage_limit: Optional[int] = None,
|
|
|
|
|
expires_at: Optional[datetime] = None,
|
|
|
|
|
) -> EmbedTokenModel:
|
|
|
|
|
"""Create a new embed token for a workflow.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
workflow_id: ID of the workflow to embed
|
|
|
|
|
organization_id: ID of the organization
|
|
|
|
|
created_by: ID of the user creating the token
|
|
|
|
|
allowed_domains: List of domains allowed to use this token
|
|
|
|
|
settings: Widget customization settings
|
|
|
|
|
usage_limit: Optional limit on number of uses
|
|
|
|
|
expires_at: Optional expiration date
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Created EmbedTokenModel
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
# Generate a unique token
|
|
|
|
|
token = f"emb_{secrets.token_urlsafe(32)}"
|
|
|
|
|
|
|
|
|
|
# Ensure uniqueness
|
|
|
|
|
while await self._token_exists(session, token):
|
|
|
|
|
token = f"emb_{secrets.token_urlsafe(32)}"
|
|
|
|
|
|
|
|
|
|
embed_token = EmbedTokenModel(
|
|
|
|
|
token=token,
|
|
|
|
|
workflow_id=workflow_id,
|
|
|
|
|
organization_id=organization_id,
|
|
|
|
|
created_by=created_by,
|
|
|
|
|
allowed_domains=allowed_domains,
|
|
|
|
|
settings=settings or {},
|
|
|
|
|
usage_limit=usage_limit,
|
|
|
|
|
expires_at=expires_at,
|
|
|
|
|
is_active=True,
|
|
|
|
|
usage_count=0,
|
|
|
|
|
created_at=datetime.now(UTC),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
session.add(embed_token)
|
|
|
|
|
await session.commit()
|
|
|
|
|
await session.refresh(embed_token)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Created embed token {token} for workflow {workflow_id}")
|
|
|
|
|
return embed_token
|
|
|
|
|
|
|
|
|
|
async def _token_exists(self, session: AsyncSession, token: str) -> bool:
|
|
|
|
|
"""Check if a token already exists."""
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
select(EmbedTokenModel).where(EmbedTokenModel.token == token)
|
|
|
|
|
)
|
|
|
|
|
return result.scalar_one_or_none() is not None
|
|
|
|
|
|
|
|
|
|
async def get_embed_token_by_token(self, token: str) -> Optional[EmbedTokenModel]:
|
|
|
|
|
"""Get an embed token by its token string.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
token: The token string
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
EmbedTokenModel if found, None otherwise
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
select(EmbedTokenModel).where(EmbedTokenModel.token == token)
|
|
|
|
|
)
|
|
|
|
|
return result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
async def get_embed_tokens_by_workflow(
|
|
|
|
|
self, workflow_id: int, organization_id: int, active_only: bool = True
|
|
|
|
|
) -> List[EmbedTokenModel]:
|
|
|
|
|
"""Get all embed tokens for a workflow.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
workflow_id: ID of the workflow
|
|
|
|
|
organization_id: ID of the organization
|
|
|
|
|
active_only: If True, only return active tokens
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List of EmbedTokenModel instances
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
query = select(EmbedTokenModel).where(
|
|
|
|
|
and_(
|
|
|
|
|
EmbedTokenModel.workflow_id == workflow_id,
|
|
|
|
|
EmbedTokenModel.organization_id == organization_id,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if active_only:
|
|
|
|
|
query = query.where(EmbedTokenModel.is_active == True)
|
|
|
|
|
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
query.order_by(EmbedTokenModel.created_at.desc())
|
|
|
|
|
)
|
|
|
|
|
return result.scalars().all()
|
|
|
|
|
|
|
|
|
|
async def update_embed_token(
|
|
|
|
|
self, token_id: int, organization_id: int, **kwargs
|
|
|
|
|
) -> Optional[EmbedTokenModel]:
|
|
|
|
|
"""Update an embed token.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
token_id: ID of the token to update
|
|
|
|
|
organization_id: ID of the organization (for access control)
|
|
|
|
|
**kwargs: Fields to update (allowed_domains, settings, is_active, etc.)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Updated EmbedTokenModel if found, None otherwise
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
# First get the token to verify organization
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
select(EmbedTokenModel).where(
|
|
|
|
|
and_(
|
|
|
|
|
EmbedTokenModel.id == token_id,
|
|
|
|
|
EmbedTokenModel.organization_id == organization_id,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
embed_token = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not embed_token:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# Update allowed fields
|
|
|
|
|
allowed_fields = {
|
|
|
|
|
"allowed_domains",
|
|
|
|
|
"settings",
|
|
|
|
|
"is_active",
|
|
|
|
|
"usage_limit",
|
|
|
|
|
"expires_at",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for field, value in kwargs.items():
|
|
|
|
|
if field in allowed_fields:
|
|
|
|
|
setattr(embed_token, field, value)
|
|
|
|
|
|
|
|
|
|
embed_token.updated_at = datetime.now(UTC)
|
|
|
|
|
|
|
|
|
|
await session.commit()
|
|
|
|
|
await session.refresh(embed_token)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Updated embed token {token_id}")
|
|
|
|
|
return embed_token
|
|
|
|
|
|
|
|
|
|
async def deactivate_embed_token(self, token_id: int, organization_id: int) -> bool:
|
|
|
|
|
"""Deactivate an embed token.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
token_id: ID of the token to deactivate
|
|
|
|
|
organization_id: ID of the organization
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
True if token was deactivated, False if not found
|
|
|
|
|
"""
|
|
|
|
|
token = await self.update_embed_token(
|
|
|
|
|
token_id, organization_id, is_active=False
|
|
|
|
|
)
|
|
|
|
|
return token is not None
|
|
|
|
|
|
|
|
|
|
async def increment_embed_token_usage(self, token_id: int) -> None:
|
|
|
|
|
"""Increment the usage count for an embed token.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
token_id: ID of the token
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
await session.execute(
|
|
|
|
|
update(EmbedTokenModel)
|
|
|
|
|
.where(EmbedTokenModel.id == token_id)
|
|
|
|
|
.values(usage_count=EmbedTokenModel.usage_count + 1)
|
|
|
|
|
)
|
|
|
|
|
await session.commit()
|
|
|
|
|
|
|
|
|
|
async def create_embed_session(
|
|
|
|
|
self,
|
|
|
|
|
session_token: str,
|
|
|
|
|
embed_token_id: int,
|
|
|
|
|
workflow_run_id: int,
|
|
|
|
|
client_ip: Optional[str] = None,
|
|
|
|
|
user_agent: Optional[str] = None,
|
|
|
|
|
origin: Optional[str] = None,
|
|
|
|
|
expires_at: Optional[datetime] = None,
|
|
|
|
|
) -> EmbedSessionModel:
|
|
|
|
|
"""Create a new embed session.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
session_token: Unique session token
|
|
|
|
|
embed_token_id: ID of the embed token
|
|
|
|
|
workflow_run_id: ID of the workflow run
|
|
|
|
|
client_ip: Client IP address
|
|
|
|
|
user_agent: User agent string
|
|
|
|
|
origin: Origin header
|
|
|
|
|
expires_at: Session expiration time
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Created EmbedSessionModel
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
if expires_at is None:
|
|
|
|
|
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
|
|
|
|
|
|
|
|
|
embed_session = EmbedSessionModel(
|
|
|
|
|
session_token=session_token,
|
|
|
|
|
embed_token_id=embed_token_id,
|
|
|
|
|
workflow_run_id=workflow_run_id,
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
user_agent=user_agent,
|
|
|
|
|
origin=origin,
|
|
|
|
|
created_at=datetime.now(UTC),
|
|
|
|
|
expires_at=expires_at,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
session.add(embed_session)
|
|
|
|
|
await session.commit()
|
|
|
|
|
await session.refresh(embed_session)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Created embed session {session_token}")
|
|
|
|
|
return embed_session
|
|
|
|
|
|
|
|
|
|
async def get_embed_session_by_token(
|
|
|
|
|
self, session_token: str
|
|
|
|
|
) -> Optional[EmbedSessionModel]:
|
|
|
|
|
"""Get an embed session by token (alias for get_embed_session).
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
session_token: The session token
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
EmbedSessionModel if found, None otherwise (doesn't check expiry)
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
select(EmbedSessionModel).where(
|
|
|
|
|
EmbedSessionModel.session_token == session_token
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
return result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
async def get_embed_token_by_id(self, token_id: int) -> Optional[EmbedTokenModel]:
|
|
|
|
|
"""Get an embed token by ID.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
token_id: ID of the token
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
EmbedTokenModel if found, None otherwise
|
|
|
|
|
"""
|
|
|
|
|
async with self.async_session() as session:
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
select(EmbedTokenModel).where(EmbedTokenModel.id == token_id)
|
|
|
|
|
)
|
|
|
|
|
return result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
async def get_embed_token_stats(self, token_id: int, organization_id: int) -> dict:
|
|
|
|
|
"""Get usage statistics for an embed token.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
token_id: ID of the token
|
|
|
|
|
organization_id: ID of the organization
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Dictionary with usage statistics
|
|
|
|
|
"""
|
2026-01-29 20:43:53 +05:30
|
|
|
from sqlalchemy import func
|
|
|
|
|
|
2025-11-15 17:32:37 +05:30
|
|
|
async with self.async_session() as session:
|
|
|
|
|
# Get the token
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
select(EmbedTokenModel).where(
|
|
|
|
|
and_(
|
|
|
|
|
EmbedTokenModel.id == token_id,
|
|
|
|
|
EmbedTokenModel.organization_id == organization_id,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
token = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not token:
|
|
|
|
|
return {}
|
|
|
|
|
|
2026-01-29 20:43:53 +05:30
|
|
|
# Count active sessions using SQL COUNT
|
2025-11-15 17:32:37 +05:30
|
|
|
active_sessions_result = await session.execute(
|
2026-01-29 20:43:53 +05:30
|
|
|
select(func.count(EmbedSessionModel.id)).where(
|
2025-11-15 17:32:37 +05:30
|
|
|
and_(
|
|
|
|
|
EmbedSessionModel.embed_token_id == token_id,
|
|
|
|
|
EmbedSessionModel.expires_at > datetime.now(UTC),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
)
|
2026-01-29 20:43:53 +05:30
|
|
|
active_sessions = active_sessions_result.scalar() or 0
|
2025-11-15 17:32:37 +05:30
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"token_id": token_id,
|
|
|
|
|
"usage_count": token.usage_count,
|
|
|
|
|
"usage_limit": token.usage_limit,
|
|
|
|
|
"active_sessions": active_sessions,
|
|
|
|
|
"is_active": token.is_active,
|
|
|
|
|
"created_at": token.created_at.isoformat()
|
|
|
|
|
if token.created_at
|
|
|
|
|
else None,
|
|
|
|
|
"expires_at": token.expires_at.isoformat()
|
|
|
|
|
if token.expires_at
|
|
|
|
|
else None,
|
|
|
|
|
"allowed_domains": token.allowed_domains,
|
|
|
|
|
}
|