mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
162 lines
5.3 KiB
Python
162 lines
5.3 KiB
Python
"""Tests for the Dograh-managed embedding service and its correlation resolver."""
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
from api.services.gen_ai.embedding.dograh_service import DograhEmbeddingService
|
|
from api.services.gen_ai.embedding.factory import resolve_embedding_correlation_id
|
|
|
|
|
|
def _service_with_fake_client(correlation_id):
    """Build a ``DograhEmbeddingService`` whose embeddings client is stubbed out.

    The service's ``client`` attribute is replaced by a stub exposing
    ``embeddings.create`` as an ``AsyncMock`` that resolves to a response with a
    single embedding ``[0.1, 0.2]``.

    Args:
        correlation_id: Forwarded to the service's ``correlation_id`` constructor
            argument (``None`` simulates a call without a correlation id).

    Returns:
        A ``(service, create_mock)`` tuple so tests can inspect the awaited call.
    """
    fake_response = SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2])])
    fake_create = AsyncMock(return_value=fake_response)

    service = DograhEmbeddingService(
        db_client=None,
        api_key="sk-test",
        model_id="text-embedding-3-small",
        base_url=None,
        correlation_id=correlation_id,
    )
    # Swap the client after construction so embed calls hit the mock instead.
    service.client = SimpleNamespace(embeddings=SimpleNamespace(create=fake_create))
    return service, fake_create
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dograh_embedding_forwards_v2_protocol_when_correlation_present():
    """A correlation id makes the request carry MPS v2 billing metadata."""
    service, fake_create = _service_with_fake_client("corr-123")

    await service.embed_texts(["hello"])

    fake_create.assert_awaited_once()
    sent = fake_create.await_args.kwargs
    expected_metadata = {
        "correlation_id": "corr-123",
        "mps_billing_version": "2",
    }
    assert sent["input"] == ["hello"]
    assert sent["model"] == "text-embedding-3-small"
    assert sent["extra_body"] == {"metadata": expected_metadata}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dograh_embedding_sends_plain_without_correlation():
    """Without a correlation id, no ``extra_body`` is attached to the request."""
    service, fake_create = _service_with_fake_client(None)

    await service.embed_texts(["hello"])

    fake_create.assert_awaited_once()
    # No correlation id (e.g. a v1 org) → no MPS metadata; MPS accepts plain calls.
    sent = fake_create.await_args.kwargs
    assert "extra_body" not in sent
|
|
|
|
|
|
def _fake_mps_client(*, status_return=None, minted="minted"):
|
|
return SimpleNamespace(
|
|
get_billing_account_status=AsyncMock(return_value=status_return),
|
|
create_correlation_id=AsyncMock(return_value={"correlation_id": minted}),
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_correlation_oss_mints_directly(monkeypatch):
    """In OSS mode a correlation id is minted without any billing-status lookup."""
    mps = _fake_mps_client()
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "oss")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=None, service_key="sk-mps"
    )

    assert correlation_id == "minted"
    mps.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
    mps.get_billing_account_status.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v2_mints(monkeypatch):
    """Hosted orgs on billing mode v2 get a freshly minted correlation id."""
    mps = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps", created_by="user-1"
    )

    assert correlation_id == "minted"
    # The billing lookup receives the org id and the acting user.
    mps.get_billing_account_status.assert_awaited_once_with(42, created_by="user-1")
    mps.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v1_returns_none_without_minting(monkeypatch):
    """Hosted orgs on billing mode v1 get no correlation id and nothing is minted.

    The test also asserts that the billing status was actually consulted. The
    ``None`` result is therefore attributable to the v1 billing mode, not to an
    early return before MPS is reached.
    """
    fake = _fake_mps_client(status_return={"billing_mode": "v1"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    # The decision must come from the billing lookup, not a short-circuit.
    fake.get_billing_account_status.assert_awaited_once()
    fake.create_correlation_id.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_no_account_returns_none(monkeypatch):
    """A hosted org whose billing status lookup yields ``None`` gets no correlation id.

    The test also asserts that the billing status was actually consulted. The
    ``None`` result must come from the missing account, not from an early exit.
    """
    fake = _fake_mps_client(status_return=None)
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    # The lookup ran and returned nothing; no id should be minted as a fallback.
    fake.get_billing_account_status.assert_awaited_once()
    fake.create_correlation_id.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_correlation_no_service_key_returns_none(monkeypatch):
    """Without a service key the resolver returns ``None`` and never calls MPS."""
    mps = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", mps
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=42, service_key=None
    )

    assert correlation_id is None
    # Even a v2 account must not be queried or minted for when the key is missing.
    for untouched in (mps.get_billing_account_status, mps.create_correlation_id):
        untouched.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_correlation_swallows_errors(monkeypatch):
    """An MPS failure during the billing lookup degrades to ``None`` instead of raising.

    The test checks that the failing call was actually awaited, so the error path
    is exercised rather than skipped. It also checks that nothing was minted
    afterwards.
    """
    fake = SimpleNamespace(
        get_billing_account_status=AsyncMock(side_effect=RuntimeError("mps down")),
        create_correlation_id=AsyncMock(),
    )
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    # A transient MPS failure must not break embeddings — fall back to no protocol.
    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    # Without this, an early return before MPS would make the test pass vacuously.
    fake.get_billing_account_status.assert_awaited()
    fake.create_correlation_id.assert_not_awaited()
|