mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 10:56:23 +02:00
Embeddings API scores (#671)
- Put scores in all responses - Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -9,6 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
|
|
@ -39,10 +40,14 @@ class TestDocumentRagIntegration:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client that returns chunk IDs"""
|
||||
"""Mock document embeddings client that returns chunk matches"""
|
||||
client = AsyncMock()
|
||||
# Now returns chunk_ids instead of actual content
|
||||
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
||||
# Returns ChunkMatch objects with chunk_id and score
|
||||
client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc/c2", score=0.90),
|
||||
ChunkMatch(chunk_id="doc/c3", score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -97,7 +102,7 @@ class TestDocumentRagIntegration:
|
|||
mock_embeddings_client.embed.assert_called_once_with([query])
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
|
|
@ -298,7 +303,7 @@ class TestDocumentRagIntegration:
|
|||
assert "DocumentRag initialized" in log_messages
|
||||
assert "Constructing prompt..." in log_messages
|
||||
assert "Computing embeddings..." in log_messages
|
||||
assert "chunk_ids" in log_messages.lower()
|
||||
assert "chunks" in log_messages.lower()
|
||||
assert "Invoking LLM..." in log_messages
|
||||
assert "Query processing complete" in log_messages
|
||||
|
||||
|
|
@ -307,9 +312,9 @@ class TestDocumentRagIntegration:
|
|||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG performance with large document retrieval"""
|
||||
# Arrange - Mock large chunk_id set (100 chunks)
|
||||
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_chunk_ids
|
||||
# Arrange - Mock large chunk match set (100 chunks)
|
||||
large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_chunk_matches
|
||||
|
||||
# Act
|
||||
import time
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ response delivery through the complete pipeline.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -36,10 +37,14 @@ class TestDocumentRagStreaming:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client that returns chunk IDs"""
|
||||
"""Mock document embeddings client that returns chunk matches"""
|
||||
client = AsyncMock()
|
||||
# Now returns chunk_ids instead of actual content
|
||||
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
||||
# Returns ChunkMatch objects with chunk_id and score
|
||||
client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc/c2", score=0.90),
|
||||
ChunkMatch(chunk_id="doc/c3", score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -35,9 +36,9 @@ class TestGraphRagIntegration:
|
|||
"""Mock graph embeddings client that returns realistic entities"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
"http://trustgraph.ai/e/machine-learning",
|
||||
"http://trustgraph.ai/e/artificial-intelligence",
|
||||
"http://trustgraph.ai/e/neural-networks"
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
@ -130,7 +131,7 @@ class TestGraphRagIntegration:
|
|||
# 2. Should query graph embeddings to find relevant entities
|
||||
mock_graph_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_graph_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['limit'] == entity_limit
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ response delivery through the complete pipeline.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -33,7 +34,7 @@ class TestGraphRagStreaming:
|
|||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
"http://trustgraph.ai/e/machine-learning",
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
|
|||
|
|
@ -411,7 +411,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri="http://example.org/entity"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -25,7 +26,10 @@ class TestGraphRagStreamingProtocol:
|
|||
def mock_graph_embeddings_client(self):
|
||||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = ["entity1", "entity2"]
|
||||
client.query.return_value = [
|
||||
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -202,9 +206,12 @@ class TestDocumentRagStreamingProtocol:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client that returns chunk IDs"""
|
||||
"""Mock document embeddings client that returns chunk matches"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = ["doc/c1", "doc/c2"]
|
||||
client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc/c2", score=0.90)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue