fix: enable knowledge base with Dograh config v2

This commit is contained in:
Abhishek Kumar 2026-06-25 22:21:11 +05:30
parent d675fd1fda
commit efb25a0cc5
19 changed files with 557 additions and 113 deletions

View file

@ -55,7 +55,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
assert effective.stt.provider == "dograh"
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
- assert effective.embeddings.model == "default"
+ assert effective.embeddings.model == "dograh_embedding_v1"
assert effective.managed_service_version == 2

View file

@ -0,0 +1,162 @@
"""Tests for the Dograh-managed embedding service and its correlation resolver."""
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.gen_ai.embedding.dograh_service import DograhEmbeddingService
from api.services.gen_ai.embedding.factory import resolve_embedding_correlation_id
def _service_with_fake_client(correlation_id):
    """Build a ``DograhEmbeddingService`` whose ``client`` is replaced by a stub.

    Args:
        correlation_id: Forwarded unchanged as the service's ``correlation_id``
            constructor argument (the tests pass a string or ``None``).

    Returns:
        A ``(service, create)`` tuple, where ``create`` is the ``AsyncMock``
        installed as ``service.client.embeddings.create`` so tests can inspect
        the keyword arguments it was awaited with.
    """
    service = DograhEmbeddingService(
        db_client=None,
        api_key="sk-test",
        model_id="text-embedding-3-small",
        base_url=None,
        correlation_id=correlation_id,
    )
    # Canned response: one item exposing a two-element ``embedding`` list
    # under ``.data``.
    create = AsyncMock(
        return_value=SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2])])
    )
    # Swap the real client for a namespace exposing only ``embeddings.create``.
    service.client = SimpleNamespace(embeddings=SimpleNamespace(create=create))
    return service, create
@pytest.mark.asyncio
async def test_dograh_embedding_forwards_v2_protocol_when_correlation_present():
    """A correlation id is forwarded as v2 billing metadata in ``extra_body``."""
    service, create = _service_with_fake_client("corr-123")

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    kwargs = create.await_args.kwargs
    assert kwargs["input"] == ["hello"]
    assert kwargs["model"] == "text-embedding-3-small"
    # The whole ``extra_body`` payload is pinned, not just the correlation id.
    assert kwargs["extra_body"] == {
        "metadata": {
            "correlation_id": "corr-123",
            "mps_billing_version": "2",
        }
    }
@pytest.mark.asyncio
async def test_dograh_embedding_sends_plain_without_correlation():
    """Without a correlation id, the embeddings call carries no ``extra_body``."""
    service, create = _service_with_fake_client(None)

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    # No correlation id (e.g. a v1 org) → no MPS metadata; MPS accepts plain calls.
    assert "extra_body" not in create.await_args.kwargs
def _fake_mps_client(*, status_return=None, minted="minted"):
return SimpleNamespace(
get_billing_account_status=AsyncMock(return_value=status_return),
create_correlation_id=AsyncMock(return_value={"correlation_id": minted}),
)
@pytest.mark.asyncio
async def test_resolve_correlation_oss_mints_directly(monkeypatch):
    """In ``oss`` mode a correlation id is minted without a billing-status lookup."""
    fake = _fake_mps_client()
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "oss")

    result = await resolve_embedding_correlation_id(
        organization_id=None, service_key="sk-mps"
    )

    assert result == "minted"
    fake.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
    fake.get_billing_account_status.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v2_mints(monkeypatch):
    """In ``hosted`` mode a ``v2`` billing account leads to a minted correlation id."""
    fake = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps", created_by="user-1"
    )

    assert result == "minted"
    # The status lookup receives the organization id positionally and
    # ``created_by`` by keyword.
    fake.get_billing_account_status.assert_awaited_once_with(42, created_by="user-1")
    fake.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v1_returns_none_without_minting(monkeypatch):
    """In ``hosted`` mode a ``v1`` billing account yields ``None`` and mints nothing."""
    fake = _fake_mps_client(status_return={"billing_mode": "v1"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    fake.create_correlation_id.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_no_account_returns_none(monkeypatch):
    """In ``hosted`` mode a ``None`` billing status yields ``None`` and mints nothing."""
    fake = _fake_mps_client(status_return=None)
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    fake.create_correlation_id.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_no_service_key_returns_none(monkeypatch):
    """Without a service key the resolver returns ``None`` and never calls the client."""
    fake = _fake_mps_client(status_return={"billing_mode": "v2"})
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key=None
    )

    assert result is None
    # Neither the status lookup nor the mint call may be awaited.
    fake.get_billing_account_status.assert_not_awaited()
    fake.create_correlation_id.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_correlation_swallows_errors(monkeypatch):
    """An error raised by the billing-status lookup yields ``None`` instead of propagating."""
    fake = SimpleNamespace(
        get_billing_account_status=AsyncMock(side_effect=RuntimeError("mps down")),
        create_correlation_id=AsyncMock(),
    )
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    # A transient MPS failure must not break embeddings — fall back to no protocol.
    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    # Matches the other "returns None" cases: nothing is minted when the
    # status lookup fails.
    fake.create_correlation_id.assert_not_awaited()

View file

@ -0,0 +1,26 @@
import pytest
from api.tasks.knowledge_base_processing import _embed_texts_in_batches
class FakeEmbeddingService:
    """Embedding-service stub that records every batch it is asked to embed."""

    def __init__(self):
        # One entry per ``embed_texts`` call, each a list copy of that batch.
        self.calls = []

    async def embed_texts(self, texts):
        """Record ``texts`` and return one single-element vector per text.

        Each vector holds the text's length as a float, which lets a test map
        every output back to the input that produced it.
        """
        # Copy so later mutation of the caller's sequence cannot alter the record.
        self.calls.append(list(texts))
        return [[float(len(text))] for text in texts]
@pytest.mark.asyncio
async def test_embed_texts_in_batches_preserves_order():
    """Five texts with ``batch_size=2`` are sent as 2+2+1 and returned in input order."""
    service = FakeEmbeddingService()

    embeddings = await _embed_texts_in_batches(
        service,
        ["a", "bb", "ccc", "dddd", "eeeee"],
        batch_size=2,
    )

    # The last batch is the single-text remainder.
    assert service.calls == [["a", "bb"], ["ccc", "dddd"], ["eeeee"]]
    # The fake encodes each text's length, so ascending values prove the order.
    assert embeddings == [[1.0], [2.0], [3.0], [4.0], [5.0]]