diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py new file mode 100644 index 00000000..5e3c0ce9 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py @@ -0,0 +1,40 @@ + +from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from .. schema import graph_embeddings_request_queue +from .. schema import graph_embeddings_response_queue + +from . endpoint import ServiceEndpoint +from . requestor import ServiceRequestor +from . serialize import serialize_value + +class GraphEmbeddingsQueryRequestor(ServiceRequestor): + def __init__(self, pulsar_host, timeout, auth): + + super(GraphEmbeddingsQueryRequestor, self).__init__( + pulsar_host=pulsar_host, + request_queue=graph_embeddings_request_queue, + response_queue=graph_embeddings_response_queue, + request_schema=GraphEmbeddingsRequest, + response_schema=GraphEmbeddingsResponse, + timeout=timeout, + ) + + def to_request(self, body): + + limit = int(body.get("limit", 20)) + + return GraphEmbeddingsRequest( + vectors = body["vectors"], + limit = limit, + user = body.get("user", "trustgraph"), + collection = body.get("collection", "default"), + ) + + def from_response(self, message): + + return { + "entities": [ + serialize_value(ent) for ent in message.entities + ] + }, True + diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 6a8a62eb..af15e981 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -35,6 +35,7 @@ from . text_completion import TextCompletionRequestor from . prompt import PromptRequestor from . graph_rag import GraphRagRequestor from . triples_query import TriplesQueryRequestor +from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . embeddings import EmbeddingsRequestor from . encyclopedia import EncyclopediaRequestor from . agent import AgentRequestor @@ -95,6 +96,10 @@ class Api: pulsar_host=self.pulsar_host, timeout=self.timeout, auth = self.auth, ), + "graph-embeddings-query": GraphEmbeddingsQueryRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), "embeddings": EmbeddingsRequestor( pulsar_host=self.pulsar_host, timeout=self.timeout, auth = self.auth, @@ -134,6 +139,11 @@ class Api: endpoint_path = "/api/v1/triples-query", auth=self.auth, requestor = self.services["triples-query"], ), + ServiceEndpoint( + endpoint_path = "/api/v1/graph-embeddings-query", + auth=self.auth, + requestor = self.services["graph-embeddings-query"], + ), ServiceEndpoint( endpoint_path = "/api/v1/embeddings", auth=self.auth, requestor = self.services["embeddings"],