mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: enable knowledge base with Dograh config v2
This commit is contained in:
parent
d675fd1fda
commit
efb25a0cc5
19 changed files with 557 additions and 113 deletions
|
|
@ -55,7 +55,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
|||
assert effective.stt.provider == "dograh"
|
||||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
assert effective.embeddings.model == "dograh_embedding_v1"
|
||||
assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
|
|
|
|||
162
api/tests/test_dograh_embedding_service.py
Normal file
162
api/tests/test_dograh_embedding_service.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Tests for the Dograh-managed embedding service and its correlation resolver."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.gen_ai.embedding.dograh_service import DograhEmbeddingService
|
||||
from api.services.gen_ai.embedding.factory import resolve_embedding_correlation_id
|
||||
|
||||
|
||||
def _service_with_fake_client(correlation_id):
    """Build a ``DograhEmbeddingService`` whose client is replaced by a stub.

    Returns a ``(service, create)`` pair, where ``create`` is the AsyncMock
    installed as ``service.client.embeddings.create`` so a test can inspect
    how it was awaited.
    """
    svc = DograhEmbeddingService(
        db_client=None,
        api_key="sk-test",
        model_id="text-embedding-3-small",
        base_url=None,
        correlation_id=correlation_id,
    )

    # Canned response: a single item carrying a two-element embedding vector.
    canned_item = SimpleNamespace(embedding=[0.1, 0.2])
    fake_create = AsyncMock(return_value=SimpleNamespace(data=[canned_item]))

    # Swap whatever client the constructor set up for the stub.
    stub_embeddings = SimpleNamespace(create=fake_create)
    svc.client = SimpleNamespace(embeddings=stub_embeddings)
    return svc, fake_create
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dograh_embedding_forwards_v2_protocol_when_correlation_present():
    """A service holding a correlation id forwards it in ``extra_body`` metadata."""
    service, create = _service_with_fake_client("corr-123")
    expected_extra_body = {
        "metadata": {
            "correlation_id": "corr-123",
            "mps_billing_version": "2",
        }
    }

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    sent = create.await_args.kwargs
    assert sent["input"] == ["hello"]
    assert sent["model"] == "text-embedding-3-small"
    assert sent["extra_body"] == expected_extra_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dograh_embedding_sends_plain_without_correlation():
    """Without a correlation id the embeddings call carries no ``extra_body``."""
    service, create = _service_with_fake_client(None)

    await service.embed_texts(["hello"])

    create.assert_awaited_once()
    sent_kwargs = create.await_args.kwargs
    # No correlation id (e.g. a v1 org) means no MPS metadata is attached;
    # MPS accepts plain calls.
    assert "extra_body" not in sent_kwargs
|
||||
|
||||
|
||||
def _fake_mps_client(*, status_return=None, minted="minted"):
    """Return a stub exposing the two async methods these tests patch in.

    ``get_billing_account_status`` resolves to ``status_return`` and
    ``create_correlation_id`` resolves to ``{"correlation_id": minted}``.
    Both are AsyncMocks, so tests can assert on how they were awaited.
    """
    billing_status = AsyncMock(return_value=status_return)
    mint = AsyncMock(return_value={"correlation_id": minted})
    return SimpleNamespace(
        get_billing_account_status=billing_status,
        create_correlation_id=mint,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_oss_mints_directly(monkeypatch):
    """In ``oss`` mode a correlation id is minted with no billing-status lookup."""
    mps = _fake_mps_client()
    client_path = "api.services.mps_service_key_client.mps_service_key_client"
    monkeypatch.setattr(client_path, mps)
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "oss")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=None,
        service_key="sk-mps",
    )

    assert correlation_id == "minted"
    mps.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
    mps.get_billing_account_status.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v2_mints(monkeypatch):
    """In ``hosted`` mode a ``v2`` billing status leads to a minted correlation id."""
    mps = _fake_mps_client(status_return={"billing_mode": "v2"})
    client_path = "api.services.mps_service_key_client.mps_service_key_client"
    monkeypatch.setattr(client_path, mps)
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=42,
        service_key="sk-mps",
        created_by="user-1",
    )

    assert correlation_id == "minted"
    # The billing lookup receives the organization id positionally and the
    # creator as a keyword; minting receives only the service key.
    mps.get_billing_account_status.assert_awaited_once_with(42, created_by="user-1")
    mps.create_correlation_id.assert_awaited_once_with(service_key="sk-mps")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_v1_returns_none_without_minting(monkeypatch):
    """In ``hosted`` mode a ``v1`` billing status yields ``None`` and no mint call."""
    mps = _fake_mps_client(status_return={"billing_mode": "v1"})
    client_path = "api.services.mps_service_key_client.mps_service_key_client"
    monkeypatch.setattr(client_path, mps)
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=42,
        service_key="sk-mps",
    )

    assert correlation_id is None
    mps.create_correlation_id.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_hosted_no_account_returns_none(monkeypatch):
    """In ``hosted`` mode a ``None`` billing status yields ``None`` and no mint call."""
    mps = _fake_mps_client(status_return=None)
    client_path = "api.services.mps_service_key_client.mps_service_key_client"
    monkeypatch.setattr(client_path, mps)
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=42,
        service_key="sk-mps",
    )

    assert correlation_id is None
    mps.create_correlation_id.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_no_service_key_returns_none(monkeypatch):
    """With ``service_key=None`` the resolver returns ``None`` and awaits nothing."""
    mps = _fake_mps_client(status_return={"billing_mode": "v2"})
    client_path = "api.services.mps_service_key_client.mps_service_key_client"
    monkeypatch.setattr(client_path, mps)
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    correlation_id = await resolve_embedding_correlation_id(
        organization_id=42,
        service_key=None,
    )

    assert correlation_id is None
    # Neither the billing lookup nor the mint call may be reached.
    mps.get_billing_account_status.assert_not_awaited()
    mps.create_correlation_id.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_correlation_swallows_errors(monkeypatch):
    """A failing billing-status lookup yields ``None`` instead of raising.

    The stubbed ``get_billing_account_status`` raises ``RuntimeError``; the
    resolver is expected to return ``None`` and, like the other ``None``
    outcomes covered in this file, not to mint a correlation id.
    """
    fake = SimpleNamespace(
        get_billing_account_status=AsyncMock(side_effect=RuntimeError("mps down")),
        create_correlation_id=AsyncMock(),
    )
    monkeypatch.setattr(
        "api.services.mps_service_key_client.mps_service_key_client", fake
    )
    monkeypatch.setattr("api.constants.DEPLOYMENT_MODE", "hosted")

    # A transient MPS failure must not break embeddings — fall back to no protocol.
    result = await resolve_embedding_correlation_id(
        organization_id=42, service_key="sk-mps"
    )

    assert result is None
    # Falling back to "no protocol" also means nothing was minted after the
    # lookup failed.
    fake.create_correlation_id.assert_not_awaited()
|
||||
26
api/tests/test_knowledge_base_processing_embeddings.py
Normal file
26
api/tests/test_knowledge_base_processing_embeddings.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import pytest
|
||||
|
||||
from api.tasks.knowledge_base_processing import _embed_texts_in_batches
|
||||
|
||||
|
||||
class FakeEmbeddingService:
    """Embedding-service stub that records every batch it is asked to embed."""

    def __init__(self):
        # One entry per embed_texts call, each a list copy of that call's texts.
        self.calls = []

    async def embed_texts(self, texts):
        """Record ``texts`` and return one single-element vector per text.

        Each vector holds the text's length as a float, which makes the
        output easy to match against the input order.
        """
        recorded = list(texts)
        self.calls.append(recorded)

        vectors = []
        for item in texts:
            vectors.append([float(len(item))])
        return vectors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_embed_texts_in_batches_preserves_order():
    """Five texts with ``batch_size=2`` go out as 2/2/1 and come back in order."""
    fake_service = FakeEmbeddingService()
    texts = ["a", "bb", "ccc", "dddd", "eeeee"]

    result = await _embed_texts_in_batches(fake_service, texts, batch_size=2)

    # The fake records each batch, so this pins both batch size and call order.
    expected_batches = [["a", "bb"], ["ccc", "dddd"], ["eeeee"]]
    assert fake_service.calls == expected_batches
    # The fake returns each text's length, so ascending values prove ordering.
    assert result == [[1.0], [2.0], [3.0], [4.0], [5.0]]
|
||||
Loading…
Add table
Add a link
Reference in a new issue