mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
chore: update documentation
This commit is contained in:
parent
7cc0467cfb
commit
da4a8a005a
21 changed files with 314 additions and 179 deletions
|
|
@ -37,7 +37,7 @@ RUN --mount=type=bind,source=api/requirements.txt,target=/tmp/req.txt \
|
|||
# sys.prefix/nltk_data, so it travels with the venv on COPY.
|
||||
RUN --mount=type=bind,source=pipecat,target=/tmp/pipecat,rw \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp,inworld]' \
|
||||
uv pip install '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp,inworld,smallest]' \
|
||||
&& uv pip uninstall opencv-python \
|
||||
&& uv pip install opencv-python-headless \
|
||||
&& python -c "import nltk; nltk.download('punkt_tab', download_dir='/opt/venv/nltk_data', quiet=True)"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from datetime import datetime, timezone
|
|||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
|
|
@ -29,8 +30,6 @@ class UserClient(BaseDBClient):
|
|||
|
||||
# Use PostgreSQL's INSERT ... ON CONFLICT DO NOTHING
|
||||
# This is atomic and handles race conditions at the database level
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
stmt = insert(UserModel.__table__).values(
|
||||
provider_id=provider_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
|
|
@ -88,21 +87,22 @@ class UserClient(BaseDBClient):
|
|||
) -> dict:
|
||||
"""Create or update the JSON value stored for a user under `key`."""
|
||||
async with self.async_session() as session:
|
||||
row = await self._get_user_configuration_row(session, user_id, key)
|
||||
if row:
|
||||
row.configuration = value
|
||||
else:
|
||||
row = UserConfigurationModel(
|
||||
user_id=user_id, key=key, configuration=value
|
||||
)
|
||||
session.add(row)
|
||||
stmt = insert(UserConfigurationModel.__table__).values(
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
configuration=value,
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
constraint="_user_configuration_key_uc",
|
||||
set_={"configuration": stmt.excluded.configuration},
|
||||
).returning(UserConfigurationModel.configuration)
|
||||
try:
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(row)
|
||||
return row.configuration
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_user_configurations(
|
||||
self, user_id: int
|
||||
|
|
|
|||
|
|
@ -512,7 +512,7 @@ class MPSServiceKeyClient:
|
|||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
logger.warning(
|
||||
"Failed to authorize MPS workflow run start: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -37,6 +37,12 @@ BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
|
|||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
SERVICE_TOKEN_ORG_MISMATCH_MESSAGE = (
|
||||
"The Dograh service token being used is created from another account. "
|
||||
"Please create a new service token from the Developers tab and use it in "
|
||||
"your model configuration."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
|
|
@ -98,6 +104,26 @@ def _dograh_api_keys(user_config: Any) -> set[str]:
|
|||
return api_keys
|
||||
|
||||
|
||||
def _is_service_key_org_mismatch_error(error: Exception) -> bool:
|
||||
response = getattr(error, "response", None)
|
||||
if getattr(response, "status_code", None) != 403:
|
||||
return False
|
||||
|
||||
detail: Any = None
|
||||
try:
|
||||
payload = response.json()
|
||||
if isinstance(payload, dict):
|
||||
detail = payload.get("detail")
|
||||
except Exception:
|
||||
detail = None
|
||||
|
||||
if isinstance(detail, str):
|
||||
return detail.lower() == "service key organization mismatch"
|
||||
|
||||
response_text = getattr(response, "text", "")
|
||||
return "Service key organization mismatch" in response_text
|
||||
|
||||
|
||||
async def _store_run_correlation_id(
|
||||
workflow_run_id: int | None,
|
||||
correlation_id: str | None,
|
||||
|
|
@ -173,11 +199,20 @@ async def _authorize_hosted_workflow_run_start(
|
|||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
logger.warning(
|
||||
"Failed to authorize workflow start with MPS for org {}: {}",
|
||||
organization_id,
|
||||
e,
|
||||
)
|
||||
if _is_service_key_org_mismatch_error(e):
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="service_key_org_mismatch",
|
||||
error_message=SERVICE_TOKEN_ORG_MISMATCH_MESSAGE,
|
||||
),
|
||||
True,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from api.services import quota_service
|
||||
|
|
@ -284,6 +285,69 @@ async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_service_token_from_wrong_org_prompts_new_token(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
get_config = AsyncMock(
|
||||
return_value=_dograh_config(api_key, managed_service_version=2)
|
||||
)
|
||||
request = httpx.Request(
|
||||
"POST",
|
||||
"http://localhost:8004/api/v1/billing/accounts/42/run-authorization",
|
||||
)
|
||||
response = httpx.Response(
|
||||
403,
|
||||
json={"detail": "Service key organization mismatch"},
|
||||
request=request,
|
||||
)
|
||||
authorize = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Failed to authorize MPS workflow run start",
|
||||
request=request,
|
||||
response=response,
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "service_key_org_mismatch"
|
||||
assert result.error_message == quota_service.SERVICE_TOKEN_ORG_MISMATCH_MESSAGE
|
||||
assert "new service token from the Developers tab" in result.error_message
|
||||
authorize.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
workflow_run_id=88,
|
||||
service_key=api_key,
|
||||
require_correlation_id=True,
|
||||
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by="provider-123",
|
||||
metadata={"dograh_user_id": "123", "workflow_id": 7},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
|
||||
monkeypatch,
|
||||
|
|
|
|||
98
api/tests/test_user_configuration_upsert.py
Normal file
98
api/tests/test_user_configuration_upsert.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from api.db.models import UserConfigurationModel
|
||||
from api.db.user_client import UserClient
|
||||
from api.enums import UserConfigurationKey
|
||||
|
||||
|
||||
class _FakeResult:
|
||||
def __init__(self, value: dict):
|
||||
self._value = value
|
||||
|
||||
def scalar_one(self) -> dict:
|
||||
return self._value
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, result_value: dict):
|
||||
self.result_value = result_value
|
||||
self.statements = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
async def execute(self, stmt):
|
||||
self.statements.append(stmt)
|
||||
return _FakeResult(self.result_value)
|
||||
|
||||
async def commit(self):
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self):
|
||||
self.rolled_back = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_user_configuration_value_uses_atomic_conflict_update():
|
||||
result_value = {"completed_actions": ["web_call_started"]}
|
||||
fake_session = _FakeSession(result_value)
|
||||
client = UserClient.__new__(UserClient)
|
||||
client.async_session = lambda: fake_session
|
||||
|
||||
value = await client.upsert_user_configuration_value(
|
||||
86,
|
||||
UserConfigurationKey.ONBOARDING.value,
|
||||
result_value,
|
||||
)
|
||||
|
||||
assert value == result_value
|
||||
assert fake_session.committed is True
|
||||
assert fake_session.rolled_back is False
|
||||
assert len(fake_session.statements) == 1
|
||||
|
||||
compiled = str(fake_session.statements[0].compile(dialect=postgresql.dialect()))
|
||||
assert "ON CONFLICT ON CONSTRAINT _user_configuration_key_uc DO UPDATE" in compiled
|
||||
assert "configuration = excluded.configuration" in compiled
|
||||
assert "last_validated_at" not in compiled
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_user_configuration_value_updates_existing_row(
|
||||
db_session,
|
||||
async_session,
|
||||
):
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id(
|
||||
"user-config-upsert-test"
|
||||
)
|
||||
|
||||
first = await db_session.upsert_user_configuration_value(
|
||||
user.id,
|
||||
UserConfigurationKey.ONBOARDING.value,
|
||||
{"skipped": False},
|
||||
)
|
||||
second = await db_session.upsert_user_configuration_value(
|
||||
user.id,
|
||||
UserConfigurationKey.ONBOARDING.value,
|
||||
{"skipped": True},
|
||||
)
|
||||
|
||||
assert first == {"skipped": False}
|
||||
assert second == {"skipped": True}
|
||||
|
||||
result = await async_session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user.id,
|
||||
UserConfigurationModel.key == UserConfigurationKey.ONBOARDING.value,
|
||||
)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0].configuration == {"skipped": True}
|
||||
Loading…
Add table
Add a link
Reference in a new issue