Graph embedding query exposed through gateway (#208)

This commit is contained in:
cybermaggedon 2024-12-10 22:15:56 +00:00 committed by GitHub
parent 8d326d34b3
commit cd8d0c8cbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 50 additions and 0 deletions

View file

@ -0,0 +1,40 @@
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import graph_embeddings_request_queue
from .. schema import graph_embeddings_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import serialize_value
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
super(GraphEmbeddingsQueryRequestor, self).__init__(
pulsar_host=pulsar_host,
request_queue=graph_embeddings_request_queue,
response_queue=graph_embeddings_response_queue,
request_schema=GraphEmbeddingsRequest,
response_schema=GraphEmbeddingsResponse,
timeout=timeout,
)
def to_request(self, body):
limit = int(body.get("limit", 20))
return GraphEmbeddingsRequest(
vectors = body["vectors"],
limit = limit,
user = body.get("user", "trustgraph"),
collection = body.get("collection", "default"),
)
def from_response(self, message):
return {
"entities": [
serialize_value(ent) for ent in message.entities
]
}, True

View file

@ -35,6 +35,7 @@ from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . triples_query import TriplesQueryRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . embeddings import EmbeddingsRequestor
from . encyclopedia import EncyclopediaRequestor
from . agent import AgentRequestor
@ -95,6 +96,10 @@ class Api:
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"graph-embeddings-query": GraphEmbeddingsQueryRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"embeddings": EmbeddingsRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
@ -134,6 +139,11 @@ class Api:
endpoint_path = "/api/v1/triples-query", auth=self.auth,
requestor = self.services["triples-query"],
),
ServiceEndpoint(
endpoint_path = "/api/v1/graph-embeddings-query",
auth=self.auth,
requestor = self.services["graph-embeddings-query"],
),
ServiceEndpoint(
endpoint_path = "/api/v1/embeddings", auth=self.auth,
requestor = self.services["embeddings"],