Embeddings API scores (#671)

- Put scores in all responses
- Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
cybermaggedon 2026-03-09 10:53:44 +00:00 committed by GitHub
parent 4fa7cc7d7c
commit f2ae0e8623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 1339 additions and 1292 deletions

View file

@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
import pytest
from unittest.mock import MagicMock
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
from trustgraph.messaging.translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator,
DocumentEmbeddingsResponseTranslator
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
"""Test that DocumentEmbeddingsRequest has expected fields"""
# Create a request
request = DocumentEmbeddingsRequest(
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
# Verify all expected fields exist
assert hasattr(request, 'vectors')
assert hasattr(request, 'vector')
assert hasattr(request, 'limit')
assert hasattr(request, 'user')
assert hasattr(request, 'collection')
# Verify field values
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10
assert request.user == "test_user"
assert request.collection == "test_collection"
@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2], [0.3, 0.4]],
"vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5,
"user": "custom_user",
"collection": "custom_collection"
@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5
assert result.user == "custom_user"
assert result.collection == "custom_collection"
@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2]]
"vector": [0.1, 0.2]
# No limit, user, or collection provided
}
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2]]
assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default
assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator()
request = DocumentEmbeddingsRequest(
vectors=[[0.5, 0.6]],
vector=[0.5, 0.6],
limit=20,
user="test_user",
collection="test_collection"
@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.from_pulsar(request)
assert isinstance(result, dict)
assert result["vectors"] == [[0.5, 0.6]]
assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20
assert result["user"] == "test_user"
assert result["collection"] == "test_collection"
@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_schema_fields(self):
"""Test that DocumentEmbeddingsResponse has expected fields"""
# Create a response with chunk_ids
# Create a response with chunks
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["chunk1", "chunk2", "chunk3"]
chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8),
ChunkMatch(chunk_id="chunk3", score=0.7)
]
)
# Verify all expected fields exist
assert hasattr(response, 'error')
assert hasattr(response, 'chunk_ids')
assert hasattr(response, 'chunks')
# Verify field values
assert response.error is None
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"]
assert len(response.chunks) == 3
assert response.chunks[0].chunk_id == "chunk1"
assert response.chunks[0].score == 0.9
def test_response_schema_with_error(self):
"""Test response schema with error"""
@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse(
error=error,
chunk_ids=[]
chunks=[]
)
assert response.error == error
assert response.chunk_ids == []
assert response.chunks == []
def test_response_translator_from_pulsar_with_chunk_ids(self):
"""Test response translator converts Pulsar schema with chunk_ids to dict"""
def test_response_translator_from_pulsar_with_chunks(self):
"""Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"]
chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85),
ChunkMatch(chunk_id="doc3/c3", score=0.75)
]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" in result
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"]
assert "chunks" in result
assert len(result["chunks"]) == 3
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunk_ids(self):
"""Test response translator handles empty chunk_ids list"""
def test_response_translator_from_pulsar_with_empty_chunks(self):
"""Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=[]
chunks=[]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" in result
assert result["chunk_ids"] == []
assert "chunks" in result
assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunk_ids(self):
"""Test response translator handles None chunk_ids"""
def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunk_ids = None
response.chunks = None
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" not in result or result.get("chunk_ids") is None
assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag"""
@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["chunk1", "chunk2"]
chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8)
]
)
result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict)
assert "chunk_ids" in result
assert result["chunk_ids"] == ["chunk1", "chunk2"]
assert "chunks" in result
assert len(result["chunks"]) == 2
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self):
@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract:
translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunk_ids": ["test"]})
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility:
@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
"""Test complete request-response flow maintains data integrity"""
# Create request
request_data = {
"vectors": [[0.1, 0.2, 0.3]],
"vector": [0.1, 0.2, 0.3],
"limit": 5,
"user": "test_user",
"collection": "test_collection"
@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Simulate service processing and creating response
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["doc1/c1", "doc2/c2"]
chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85)
]
)
# Convert response back to dict
@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict)
assert "chunk_ids" in response_data
assert len(response_data["chunk_ids"]) == 2
assert "chunks" in response_data
assert len(response_data["chunks"]) == 2
def test_error_response_flow(self):
"""Test error response flow"""
@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
response = DocumentEmbeddingsResponse(
error=error,
chunk_ids=[]
chunks=[]
)
# Convert response to dict
@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify error handling
assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunk_ids
assert "chunk_ids" in response_data
assert response_data["chunk_ids"] == []
# The translator doesn't include error in the dict, only chunks
assert "chunks" in response_data
assert response_data["chunks"] == []