# dograh/api/tests/test_dograh_embedding_service.py
# 2026-06-25 22:21:11 +05:30
#
# 162 lines
# 5.3 KiB
# Python

"""Tests for the Dograh-managed embedding service and its correlation resolver."""
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.gen_ai.embedding.dograh_service import DograhEmbeddingService
from api.services.gen_ai.embedding.factory import resolve_embedding_correlation_id
def _service_with_fake_client(correlation_id):
    """Build a ``DograhEmbeddingService`` whose client is replaced by a stub.

    The service is constructed with fixed test settings and the given
    ``correlation_id``; afterwards its ``client`` attribute is overwritten so
    that ``client.embeddings.create`` is an ``AsyncMock``.

    Args:
        correlation_id: Passed straight through to the service constructor.
            The tests in this module use either a string or ``None``.

    Returns:
        A ``(service, create)`` tuple, where ``create`` is the ``AsyncMock``
        installed as ``service.client.embeddings.create``. Awaiting it yields
        an object whose ``data`` list holds one item with
        ``embedding == [0.1, 0.2]``.
    """
    service = DograhEmbeddingService(
        db_client=None,
        api_key="sk-test",
        model_id="text-embedding-3-small",
        base_url=None,
        correlation_id=correlation_id,
    )
    create = AsyncMock(
        return_value=SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2])])
    )
    # Replace the client after construction; the tests assert on the calls
    # that reach ``create``.
    service.client = SimpleNamespace(embeddings=SimpleNamespace(create=create))
    return service, create
@pytest.mark.asyncio
async def test_dograh_embedding_forwards_v2_protocol_when_correlation_present():
    """With a correlation id, ``embed_texts`` sends billing metadata.

    Checks that the stubbed ``embeddings.create`` is awaited exactly once with
    the input texts, the configured model, and an ``extra_body`` carrying the
    correlation id together with ``mps_billing_version`` ``"2"``.
    """
    service, create = _service_with_fake_client("corr-123")

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    kwargs = create.await_args.kwargs
    assert kwargs["input"] == ["hello"]
    assert kwargs["model"] == "text-embedding-3-small"
    assert kwargs["extra_body"] == {
        "metadata": {
            "correlation_id": "corr-123",
            "mps_billing_version": "2",
        }
    }
@pytest.mark.asyncio
async def test_dograh_embedding_sends_plain_without_correlation():
    """Without a correlation id, ``embed_texts`` omits ``extra_body``.

    The stubbed ``embeddings.create`` must still be awaited exactly once, but
    its keyword arguments must not contain an ``extra_body`` key.
    """
    service, create = _service_with_fake_client(None)

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    # No correlation id (e.g. a v1 org) → no MPS metadata; MPS accepts plain calls.
    assert "extra_body" not in create.await_args.kwargs
def _fake_mps_client(*, status_return=None, minted="minted"):
    """Return a stand-in for the MPS service-key client.

    Args:
        status_return: Value produced when ``get_billing_account_status`` is
            awaited. Defaults to ``None``.
        minted: Correlation id placed in the dict produced when
            ``create_correlation_id`` is awaited. Defaults to ``"minted"``.

    Returns:
        A ``SimpleNamespace`` exposing two ``AsyncMock`` attributes:
        ``get_billing_account_status`` (resolves to ``status_return``) and
        ``create_correlation_id`` (resolves to
        ``{"correlation_id": minted}``).
    """
    return SimpleNamespace(
        get_billing_account_status=AsyncMock(return_value=status_return),
        create_correlation_id=AsyncMock(return_value={"correlation_id": minted}),
    )
@pytest.mark.asyncio
async def test_resolve_correlation_oss_mints_directly(monkeypatch):
    """In ``oss`` deployment mode a correlation id is minted right away.

    With ``DEPLOYMENT_MODE`` patched to ``"oss"`` and no organization id, the
    resolver must return the minted id, call ``create_correlation_id`` once
    with the service key, and never query the billing account status.
    """
    fake = _fake_mps_client()
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "oss")

    result = await resolve_embedding_correlation_id(
        organization_id=None, service_key="sk-mps"
    )

    assert result == "minted"
    fake.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
    fake.get_billing_account_status.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v2_mints(monkeypatch):
    """In ``hosted`` mode a ``v2`` billing account gets a minted id.

    The billing status lookup must be awaited once with the organization id
    and the ``created_by`` value, and ``create_correlation_id`` once with the
    service key; the minted id is returned.
    """
    fake = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps", created_by="user-1"
    )

    assert result == "minted"
    fake.get_billing_account_status.assert_awaited_once_with(42, created_by="user-1")
    fake.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v1_returns_none_without_minting(monkeypatch):
    """In ``hosted`` mode a ``v1`` billing account yields no correlation id.

    The resolver must return ``None`` and must not await
    ``create_correlation_id``.
    """
    fake = _fake_mps_client(status_return={"billing_mode": "v1"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    fake.create_correlation_id.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_no_account_returns_none(monkeypatch):
    """In ``hosted`` mode a ``None`` billing status yields no correlation id.

    The resolver must return ``None`` and must not await
    ``create_correlation_id``.
    """
    fake = _fake_mps_client(status_return=None)
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    fake.create_correlation_id.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_no_service_key_returns_none(monkeypatch):
    """Without a service key the resolver returns ``None`` and calls nothing.

    Even though the fake reports a ``v2`` billing account, neither the status
    lookup nor ``create_correlation_id`` may be awaited when ``service_key``
    is ``None``.
    """
    fake = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key=None
    )

    assert result is None
    fake.get_billing_account_status.assert_not_awaited()
    fake.create_correlation_id.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_swallows_errors(monkeypatch):
    """An exception from the billing status lookup results in ``None``.

    ``get_billing_account_status`` is stubbed to raise ``RuntimeError``; the
    resolver must not propagate it and must return ``None`` instead.
    """
    fake = SimpleNamespace(
        get_billing_account_status=AsyncMock(side_effect=RuntimeError("mps down")),
        create_correlation_id=AsyncMock(),
    )
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    # A transient MPS failure must not break embeddings — fall back to no protocol.
    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    # Guard against a vacuous pass: the raising call must actually have been
    # reached, otherwise ``None`` could come from an unrelated early return.
    fake.get_billing_account_status.assert_awaited()