mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Embeddings API scores (#671)
- Put scores in all responses - Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vectors, limit=20, user="trustgraph",
|
||||
async def query(self, vector, limit=20, user="trustgraph",
|
||||
collection="default", timeout=30):
|
||||
|
||||
resp = await self.request(
|
||||
DocumentEmbeddingsRequest(
|
||||
vectors = vectors,
|
||||
vector = vector,
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection
|
||||
|
|
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.chunk_ids
|
||||
# Return ChunkMatch objects with chunk_id and score
|
||||
return resp.chunks
|
||||
|
||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
docs = await self.query_document_embeddings(request)
|
||||
|
||||
logger.debug("Sending document embeddings query response...")
|
||||
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
|
||||
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Document embeddings query request completed")
|
||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
type = "document-embeddings-query-error",
|
||||
message = str(e),
|
||||
),
|
||||
chunk_ids=[],
|
||||
chunks=[],
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
|
|||
|
|
@ -19,12 +19,12 @@ def to_value(x):
|
|||
return Literal(x.value or x.iri)
|
||||
|
||||
class GraphEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vectors, limit=20, user="trustgraph",
|
||||
async def query(self, vector, limit=20, user="trustgraph",
|
||||
collection="default", timeout=30):
|
||||
|
||||
resp = await self.request(
|
||||
GraphEmbeddingsRequest(
|
||||
vectors = vectors,
|
||||
vector = vector,
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection
|
||||
|
|
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return [
|
||||
to_value(v)
|
||||
for v in resp.entities
|
||||
]
|
||||
# Return EntityMatch objects with entity and score
|
||||
return resp.entities
|
||||
|
||||
class GraphEmbeddingsClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
|||
|
||||
class RowEmbeddingsQueryClient(RequestResponse):
|
||||
async def row_embeddings_query(
|
||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
||||
self, vector, schema_name, user="trustgraph", collection="default",
|
||||
index_name=None, limit=10, timeout=600
|
||||
):
|
||||
request = RowEmbeddingsRequest(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
schema_name=schema_name,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue