From 6d200c79c5796de8fd9e04b7802c83041335a711 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 2 Dec 2024 17:41:30 +0000 Subject: [PATCH] Feature/wikipedia ddg (#185) API-side support for Wikipedia, DBpedia and internet search functions This incorporates a refactor of the API code to break it up, separate classes for endpoints to reduce duplication --- templates/components/azure-openai.jsonnet | 2 +- templates/components/azure.jsonnet | 2 +- templates/components/bedrock.jsonnet | 2 +- templates/components/claude.jsonnet | 2 +- templates/components/cohere.jsonnet | 2 +- templates/components/document-rag.jsonnet | 2 +- templates/components/googleaistudio.jsonnet | 2 +- templates/components/graph-rag.jsonnet | 2 +- templates/components/llamafile.jsonnet | 2 +- templates/components/ollama.jsonnet | 2 +- templates/components/openai.jsonnet | 2 +- templates/components/prompt-template.jsonnet | 6 +- templates/components/trustgraph.jsonnet | 2 +- templates/components/vertexai.jsonnet | 2 +- test-api/test-agent2-api | 28 + test-api/test-dbpedia | 30 + test-api/test-encyclopedia | 30 + test-api/test-internet-search | 30 + test-api/test-prompt-api | 1 - test-api/test-prompt2-api | 1 - test-api/test-triples-query-api | 5 +- trustgraph-base/trustgraph/schema/__init__.py | 2 + trustgraph-base/trustgraph/schema/lookup.py | 42 + trustgraph-cli/scripts/tg-load-kg-core | 1 - trustgraph-flow/scripts/wikipedia-lookup | 6 + trustgraph-flow/setup.py | 1 + .../trustgraph/api/gateway/agent.py | 30 + .../trustgraph/api/gateway/dbpedia.py | 29 + .../trustgraph/api/gateway/embeddings.py | 27 + .../trustgraph/api/gateway/encyclopedia.py | 29 + .../trustgraph/api/gateway/endpoint.py | 153 +++ .../api/gateway/graph_embeddings_load.py | 60 ++ .../api/gateway/graph_embeddings_stream.py | 56 ++ .../trustgraph/api/gateway/graph_rag.py | 30 + .../trustgraph/api/gateway/internet_search.py | 29 + .../trustgraph/api/gateway/prompt.py | 41 + .../trustgraph/api/gateway/publisher.py | 41 + .../trustgraph/api/gateway/running.py | 5 + .../trustgraph/api/gateway/serialize.py | 57 ++ .../trustgraph/api/gateway/service.py | 873 ++---------------- .../trustgraph/api/gateway/socket.py | 68 ++ .../trustgraph/api/gateway/subscriber.py | 68 ++ .../trustgraph/api/gateway/text_completion.py | 28 + .../trustgraph/api/gateway/triples_load.py | 59 ++ .../trustgraph/api/gateway/triples_query.py | 53 ++ .../trustgraph/api/gateway/triples_stream.py | 56 ++ .../trustgraph/external/__init__.py | 0 .../trustgraph/external/wikipedia/__init__.py | 3 + .../trustgraph/external/wikipedia/__main__.py | 7 + .../trustgraph/external/wikipedia/service.py | 102 ++ 50 files changed, 1287 insertions(+), 826 deletions(-) create mode 100755 test-api/test-agent2-api create mode 100755 test-api/test-dbpedia create mode 100755 test-api/test-encyclopedia create mode 100755 test-api/test-internet-search create mode 100644 trustgraph-base/trustgraph/schema/lookup.py create mode 100755 trustgraph-flow/scripts/wikipedia-lookup create mode 100644 trustgraph-flow/trustgraph/api/gateway/agent.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/dbpedia.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/embeddings.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/encyclopedia.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/endpoint.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/graph_rag.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/internet_search.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/prompt.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/publisher.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/running.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/serialize.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/socket.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/subscriber.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/text_completion.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/triples_load.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/triples_query.py create mode 100644 trustgraph-flow/trustgraph/api/gateway/triples_stream.py create mode 100644 trustgraph-flow/trustgraph/external/__init__.py create mode 100644 trustgraph-flow/trustgraph/external/wikipedia/__init__.py create mode 100644 trustgraph-flow/trustgraph/external/wikipedia/__main__.py create mode 100644 trustgraph-flow/trustgraph/external/wikipedia/service.py diff --git a/templates/components/azure-openai.jsonnet b/templates/components/azure-openai.jsonnet index cc3847c0..8afcaf11 100644 --- a/templates/components/azure-openai.jsonnet +++ b/templates/components/azure-openai.jsonnet @@ -48,7 +48,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/azure.jsonnet b/templates/components/azure.jsonnet index 82b79133..cf10dc66 100644 --- a/templates/components/azure.jsonnet +++ b/templates/components/azure.jsonnet @@ -46,7 +46,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/bedrock.jsonnet b/templates/components/bedrock.jsonnet index 93978a59..6ccaa1c5 100644 --- a/templates/components/bedrock.jsonnet +++ b/templates/components/bedrock.jsonnet @@ -53,7 +53,7 @@ local chunker = import "chunker-recursive.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/claude.jsonnet b/templates/components/claude.jsonnet index c6c94e21..00e4ec79 100644 --- a/templates/components/claude.jsonnet +++ b/templates/components/claude.jsonnet @@ -45,7 +45,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/cohere.jsonnet b/templates/components/cohere.jsonnet index 11c30fbd..5bc9b39c 100644 --- a/templates/components/cohere.jsonnet +++ b/templates/components/cohere.jsonnet @@ -43,7 +43,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); diff --git a/templates/components/document-rag.jsonnet b/templates/components/document-rag.jsonnet index ac5c11ec..0a68dd52 100644 --- a/templates/components/document-rag.jsonnet +++ b/templates/components/document-rag.jsonnet @@ -19,7 +19,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "--prompt-request-queue", "non-persistent://tg/request/prompt-rag", "--prompt-response-queue", - "non-persistent://tg/response/prompt-rag-response", + "non-persistent://tg/response/prompt-rag", ]) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); diff --git a/templates/components/googleaistudio.jsonnet b/templates/components/googleaistudio.jsonnet index b6ee1d85..4088ceef 100644 --- a/templates/components/googleaistudio.jsonnet +++ b/templates/components/googleaistudio.jsonnet @@ -50,7 +50,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/graph-rag.jsonnet b/templates/components/graph-rag.jsonnet index c0200d1e..860152c9 100644 --- a/templates/components/graph-rag.jsonnet +++ b/templates/components/graph-rag.jsonnet @@ -112,7 +112,7 @@ local url = import "values/url.jsonnet"; "--prompt-request-queue", "non-persistent://tg/request/prompt-rag", "--prompt-response-queue", - "non-persistent://tg/response/prompt-rag-response", + "non-persistent://tg/response/prompt-rag", "--entity-limit", std.toString($["graph-rag-entity-limit"]), "--triple-limit", diff --git a/templates/components/llamafile.jsonnet b/templates/components/llamafile.jsonnet index d51cda61..bc1a011c 100644 --- a/templates/components/llamafile.jsonnet +++ b/templates/components/llamafile.jsonnet @@ -40,7 +40,7 @@ local prompts = import "prompts/slm.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/ollama.jsonnet b/templates/components/ollama.jsonnet index 2ae696b4..8da00848 100644 --- a/templates/components/ollama.jsonnet +++ b/templates/components/ollama.jsonnet @@ -40,7 +40,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/openai.jsonnet b/templates/components/openai.jsonnet index 83cbd406..27725cb6 100644 --- a/templates/components/openai.jsonnet +++ b/templates/components/openai.jsonnet @@ -50,7 +50,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") diff --git a/templates/components/prompt-template.jsonnet b/templates/components/prompt-template.jsonnet index ac820df6..3dadf337 100644 --- a/templates/components/prompt-template.jsonnet +++ b/templates/components/prompt-template.jsonnet @@ -53,7 +53,7 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; "--text-completion-request-queue", "non-persistent://tg/request/text-completion", "--text-completion-response-queue", - "non-persistent://tg/response/text-completion-response", + "non-persistent://tg/response/text-completion", "--system-prompt", $["prompts"]["system-template"], @@ -92,11 +92,11 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; "-i", "non-persistent://tg/request/prompt-rag", "-o", - "non-persistent://tg/response/prompt-rag-response", + "non-persistent://tg/response/prompt-rag", "--text-completion-request-queue", "non-persistent://tg/request/text-completion-rag", "--text-completion-response-queue", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", "--system-prompt", $["prompts"]["system-template"], diff --git a/templates/components/trustgraph.jsonnet b/templates/components/trustgraph.jsonnet index 37c05dae..6c60921c 100644 --- a/templates/components/trustgraph.jsonnet +++ b/templates/components/trustgraph.jsonnet @@ -186,7 +186,7 @@ local prompt = import "prompt-template.jsonnet"; "-p", url.pulsar, "-i", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); diff --git a/templates/components/vertexai.jsonnet b/templates/components/vertexai.jsonnet index 44fe27c6..ef193156 100644 --- a/templates/components/vertexai.jsonnet +++ b/templates/components/vertexai.jsonnet @@ -93,7 +93,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-i", "non-persistent://tg/request/text-completion-rag", "-o", - "non-persistent://tg/response/text-completion-rag-response", + "non-persistent://tg/response/text-completion-rag", ]) .with_limits("0.5", "256M") .with_reservations("0.1", "256M") diff --git a/test-api/test-agent2-api b/test-api/test-agent2-api new file mode 100755 index 00000000..766b16c9 --- /dev/null +++ b/test-api/test-agent2-api @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +import requests +import json +import sys + +url = "http://localhost:8088/api/v1/" + +############################################################################ + +input = { + "question": "What is 14 plus 12. Justify your answer.", +} + +resp = requests.post( + f"{url}agent", + json=input, +) + +resp = resp.json() + +if "error" in resp: + print(f"Error: {resp['error']}") + sys.exit(1) + +print(resp["answer"]) + + diff --git a/test-api/test-dbpedia b/test-api/test-dbpedia new file mode 100755 index 00000000..e361f533 --- /dev/null +++ b/test-api/test-dbpedia @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import requests +import json +import sys + +url = "http://localhost:8088/api/v1/" + +############################################################################ + +input = { + "term": "Cornwall", +} + +resp = requests.post( + f"{url}dbpedia", + json=input, +) + +resp = resp.json() + +if "error" in resp: + print(f"Error: {resp['error']}") + sys.exit(1) + +print(resp["text"]) + +sys.exit(0) +############################################################################ + diff --git a/test-api/test-encyclopedia b/test-api/test-encyclopedia new file mode 100755 index 00000000..ad4e5b36 --- /dev/null +++ b/test-api/test-encyclopedia @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import requests +import json +import sys + +url = "http://localhost:8088/api/v1/" + +############################################################################ + +input = { + "term": "Cornwall", +} + +resp = requests.post( + f"{url}encyclopedia", + json=input, +) + +resp = resp.json() + +if "error" in resp: + print(f"Error: {resp['error']}") + sys.exit(1) + +print(resp["text"]) + +sys.exit(0) +############################################################################ + diff --git a/test-api/test-internet-search b/test-api/test-internet-search new file mode 100755 index 00000000..8c854c77 --- /dev/null +++ b/test-api/test-internet-search @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import requests +import json +import sys + +url = "http://localhost:8088/api/v1/" + +############################################################################ + +input = { + "term": "Cornwall", +} + +resp = requests.post( + f"{url}internet-search", + json=input, +) + +resp = resp.json() + +if "error" in resp: + print(f"Error: {resp['error']}") + sys.exit(1) + +print(resp["text"]) + +sys.exit(0) +############################################################################ + diff --git a/test-api/test-prompt-api b/test-api/test-prompt-api index 1005bc90..4f69f09a 100755 --- a/test-api/test-prompt-api +++ b/test-api/test-prompt-api @@ -22,7 +22,6 @@ resp = requests.post( resp = resp.json() -print(resp) if "error" in resp: print(f"Error: {resp['error']}") sys.exit(1) diff --git a/test-api/test-prompt2-api b/test-api/test-prompt2-api index f1b80c48..1e641439 100755 --- a/test-api/test-prompt2-api +++ b/test-api/test-prompt2-api @@ -22,7 +22,6 @@ resp = requests.post( resp = resp.json() -print(resp) if "error" in resp: print(f"Error: {resp['error']}") sys.exit(1) diff --git a/test-api/test-triples-query-api b/test-api/test-triples-query-api index e2895a28..1aa8a0b1 100755 --- a/test-api/test-triples-query-api +++ b/test-api/test-triples-query-api @@ -9,7 +9,10 @@ url = "http://localhost:8088/api/v1/" ############################################################################ input = { - "p": "http://www.w3.org/2000/01/rdf-schema#label", + "p": { + "v": "http://www.w3.org/2000/01/rdf-schema#label", + "e": True, + }, "limit": 10 } diff --git a/trustgraph-base/trustgraph/schema/__init__.py b/trustgraph-base/trustgraph/schema/__init__.py index 3196691b..be41b670 100644 --- a/trustgraph-base/trustgraph/schema/__init__.py +++ b/trustgraph-base/trustgraph/schema/__init__.py @@ -9,4 +9,6 @@ from . graph import * from . retrieval import * from . metadata import * from . agent import * +from . lookup import * + diff --git a/trustgraph-base/trustgraph/schema/lookup.py b/trustgraph-base/trustgraph/schema/lookup.py new file mode 100644 index 00000000..d0a0517c --- /dev/null +++ b/trustgraph-base/trustgraph/schema/lookup.py @@ -0,0 +1,42 @@ + +from pulsar.schema import Record, String + +from . types import Error, Value, Triple +from . topic import topic +from . metadata import Metadata + +############################################################################ + +# Lookups + +class LookupRequest(Record): + kind = String() + term = String() + +class LookupResponse(Record): + text = String() + error = Error() + +encyclopedia_lookup_request_queue = topic( + 'encyclopedia', kind='non-persistent', namespace='request' +) +encyclopedia_lookup_response_queue = topic( + 'encyclopedia', kind='non-persistent', namespace='response', +) + +dbpedia_lookup_request_queue = topic( + 'dbpedia', kind='non-persistent', namespace='request' +) +dbpedia_lookup_response_queue = topic( + 'dbpedia', kind='non-persistent', namespace='response', +) + +internet_search_request_queue = topic( + 'internet-search', kind='non-persistent', namespace='request' +) +internet_search_response_queue = topic( + 'internet-search', kind='non-persistent', namespace='response', +) + +############################################################################ + diff --git a/trustgraph-cli/scripts/tg-load-kg-core b/trustgraph-cli/scripts/tg-load-kg-core index e2d0a405..4e207cf1 100755 --- a/trustgraph-cli/scripts/tg-load-kg-core +++ b/trustgraph-cli/scripts/tg-load-kg-core @@ -93,7 +93,6 @@ async def loader(ge_queue, t_queue, path, format, user, collection): if collection: unpacked["metadata"]["collection"] = collection - if unpacked[0] == "t": await t_queue.put(unpacked[1]) t_counts += 1 diff --git a/trustgraph-flow/scripts/wikipedia-lookup b/trustgraph-flow/scripts/wikipedia-lookup new file mode 100755 index 00000000..a89b1009 --- /dev/null +++ b/trustgraph-flow/scripts/wikipedia-lookup @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.external.wikipedia import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index 8e81e12c..65bb7326 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -106,5 +106,6 @@ setuptools.setup( "scripts/triples-query-neo4j", "scripts/triples-write-cassandra", "scripts/triples-write-neo4j", + "scripts/wikipedia-lookup", ] ) diff --git a/trustgraph-flow/trustgraph/api/gateway/agent.py b/trustgraph-flow/trustgraph/api/gateway/agent.py new file mode 100644 index 00000000..28a1e185 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/agent.py @@ -0,0 +1,30 @@ + +from ... schema import AgentRequest, AgentResponse +from ... schema import agent_request_queue +from ... schema import agent_response_queue + +from . endpoint import MultiResponseServiceEndpoint + +class AgentEndpoint(MultiResponseServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(AgentEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=agent_request_queue, + response_queue=agent_response_queue, + request_schema=AgentRequest, + response_schema=AgentResponse, + endpoint_path="/api/v1/agent", + timeout=timeout, + ) + + def to_request(self, body): + return AgentRequest( + question=body["question"] + ) + + def from_response(self, message): + if message.answer: + return { "answer": message.answer }, True + else: + return {}, False diff --git a/trustgraph-flow/trustgraph/api/gateway/dbpedia.py b/trustgraph-flow/trustgraph/api/gateway/dbpedia.py new file mode 100644 index 00000000..0ccb3d6b --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/dbpedia.py @@ -0,0 +1,29 @@ + +from ... schema import LookupRequest, LookupResponse +from ... schema import dbpedia_lookup_request_queue +from ... schema import dbpedia_lookup_response_queue + +from . endpoint import ServiceEndpoint + +class DbpediaEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(DbpediaEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=dbpedia_lookup_request_queue, + response_queue=dbpedia_lookup_response_queue, + request_schema=LookupRequest, + response_schema=LookupResponse, + endpoint_path="/api/v1/dbpedia", + timeout=timeout, + ) + + def to_request(self, body): + return LookupRequest( + term=body["term"], + kind=body.get("kind", None), + ) + + def from_response(self, message): + return { "text": message.text } + diff --git a/trustgraph-flow/trustgraph/api/gateway/embeddings.py b/trustgraph-flow/trustgraph/api/gateway/embeddings.py new file mode 100644 index 00000000..b5fcc0a4 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/embeddings.py @@ -0,0 +1,27 @@ + +from ... schema import EmbeddingsRequest, EmbeddingsResponse +from ... schema import embeddings_request_queue +from ... schema import embeddings_response_queue + +from . endpoint import ServiceEndpoint + +class EmbeddingsEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(EmbeddingsEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=embeddings_request_queue, + response_queue=embeddings_response_queue, + request_schema=EmbeddingsRequest, + response_schema=EmbeddingsResponse, + endpoint_path="/api/v1/embeddings", + timeout=timeout, + ) + + def to_request(self, body): + return EmbeddingsRequest( + text=body["text"] + ) + + def from_response(self, message): + return { "vectors": message.vectors } diff --git a/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py b/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py new file mode 100644 index 00000000..e379d7d4 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/encyclopedia.py @@ -0,0 +1,29 @@ + +from ... schema import LookupRequest, LookupResponse +from ... schema import encyclopedia_lookup_request_queue +from ... schema import encyclopedia_lookup_response_queue + +from . endpoint import ServiceEndpoint + +class EncyclopediaEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(EncyclopediaEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=encyclopedia_lookup_request_queue, + response_queue=encyclopedia_lookup_response_queue, + request_schema=LookupRequest, + response_schema=LookupResponse, + endpoint_path="/api/v1/encyclopedia", + timeout=timeout, + ) + + def to_request(self, body): + return LookupRequest( + term=body["term"], + kind=body.get("kind", None), + ) + + def from_response(self, message): + return { "text": message.text } + diff --git a/trustgraph-flow/trustgraph/api/gateway/endpoint.py b/trustgraph-flow/trustgraph/api/gateway/endpoint.py new file mode 100644 index 00000000..075e4a0e --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/endpoint.py @@ -0,0 +1,153 @@ + +import asyncio +from pulsar.schema import JsonSchema +from aiohttp import web +import uuid +import logging + +from . publisher import Publisher +from . subscriber import Subscriber + +logger = logging.getLogger("endpoint") +logger.setLevel(logging.INFO) + +class ServiceEndpoint: + + def __init__( + self, + pulsar_host, + request_queue, request_schema, + response_queue, response_schema, + endpoint_path, + subscription="api-gateway", consumer_name="api-gateway", + timeout=600, + ): + + self.pub = Publisher( + pulsar_host, request_queue, + schema=JsonSchema(request_schema) + ) + + self.sub = Subscriber( + pulsar_host, response_queue, + subscription, consumer_name, + JsonSchema(response_schema) + ) + + self.path = endpoint_path + self.timeout = timeout + + async def start(self): + + self.pub_task = asyncio.create_task(self.pub.run()) + self.sub_task = asyncio.create_task(self.sub.run()) + + def add_routes(self, app): + + app.add_routes([ + web.post(self.path, self.handle), + ]) + + def to_request(self, request): + raise RuntimeError("Not defined") + + def from_response(self, response): + raise RuntimeError("Not defined") + + async def handle(self, request): + + id = str(uuid.uuid4()) + + try: + + data = await request.json() + + q = await self.sub.subscribe(id) + + print(data) + + await self.pub.send( + id, + self.to_request(data), + ) + + try: + resp = await asyncio.wait_for(q.get(), self.timeout) + except: + raise RuntimeError("Timeout waiting for response") + + print(resp) + + if resp.error: + return web.json_response( + { "error": resp.error.message } + ) + + return web.json_response( + self.from_response(resp) + ) + + except Exception as e: + logging.error(f"Exception: {e}") + + return web.json_response( + { "error": str(e) } + ) + + finally: + await self.sub.unsubscribe(id) + + +class MultiResponseServiceEndpoint(ServiceEndpoint): + + async def handle(self, request): + + id = str(uuid.uuid4()) + + try: + + data = await request.json() + + q = await self.sub.subscribe(id) + + print(data) + + await self.pub.send( + id, + self.to_request(data), + ) + + # Keeps looking at responses... + + while True: + + try: + resp = await asyncio.wait_for(q.get(), self.timeout) + except: + raise RuntimeError("Timeout waiting for response") + + print(resp) + + if resp.error: + return web.json_response( + { "error": resp.error.message } + ) + + # Until from_response says we have a finished answer + resp, fin = self.from_response(resp) + + + if fin: + return web.json_response(resp) + + # Not finished, so loop round and continue + + except Exception as e: + logging.error(f"Exception: {e}") + + return web.json_response( + { "error": str(e) } + ) + + finally: + await self.sub.unsubscribe(id) diff --git a/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py new file mode 100644 index 00000000..3cc3f533 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_load.py @@ -0,0 +1,60 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid +from aiohttp import WSMsgType + +from ... schema import Metadata +from ... schema import GraphEmbeddings +from ... schema import graph_embeddings_store_queue + +from . publisher import Publisher +from . socket import SocketEndpoint +from . serialize import to_subgraph, to_value + +class GraphEmbeddingsLoadEndpoint(SocketEndpoint): + + def __init__(self, pulsar_host, path="/api/v1/load/graph-embeddings"): + + super(GraphEmbeddingsLoadEndpoint, self).__init__( + endpoint_path=path + ) + + self.pulsar_host=pulsar_host + + self.publisher = Publisher( + self.pulsar_host, graph_embeddings_store_queue, + schema=JsonSchema(GraphEmbeddings) + ) + + async def start(self): + + self.task = asyncio.create_task( + self.publisher.run() + ) + + async def listener(self, ws, running): + + async for msg in ws: + # On error, finish + if msg.type == WSMsgType.ERROR: + break + else: + + data = msg.json() + + elt = GraphEmbeddings( + metadata=Metadata( + id=data["metadata"]["id"], + metadata=to_subgraph(data["metadata"]["metadata"]), + user=data["metadata"]["user"], + collection=data["metadata"]["collection"], + ), + entity=to_value(data["entity"]), + vectors=data["vectors"], + ) + + await self.publisher.send(None, elt) + + + running.stop() diff --git a/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py new file mode 100644 index 00000000..978684cf --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/graph_embeddings_stream.py @@ -0,0 +1,56 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid + +from ... schema import GraphEmbeddings +from ... schema import graph_embeddings_store_queue + +from . subscriber import Subscriber +from . socket import SocketEndpoint +from . serialize import serialize_graph_embeddings + +class GraphEmbeddingsStreamEndpoint(SocketEndpoint): + + def __init__(self, pulsar_host, path="/api/v1/stream/graph-embeddings"): + + super(GraphEmbeddingsStreamEndpoint, self).__init__( + endpoint_path=path + ) + + self.pulsar_host=pulsar_host + + self.subscriber = Subscriber( + self.pulsar_host, graph_embeddings_store_queue, + "api-gateway", "api-gateway", + schema=JsonSchema(GraphEmbeddings) + ) + + async def start(self): + + self.task = asyncio.create_task( + self.subscriber.run() + ) + + async def async_thread(self, ws, running): + + id = str(uuid.uuid4()) + + q = await self.subscriber.subscribe_all(id) + + while running.get(): + try: + resp = await asyncio.wait_for(q.get(), 0.5) + await ws.send_json(serialize_graph_embeddings(resp)) + + except TimeoutError: + continue + + except Exception as e: + print(f"Exception: {str(e)}", flush=True) + break + + await self.subscriber.unsubscribe_all(id) + + running.stop() + diff --git a/trustgraph-flow/trustgraph/api/gateway/graph_rag.py b/trustgraph-flow/trustgraph/api/gateway/graph_rag.py new file mode 100644 index 00000000..1381dc23 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/graph_rag.py @@ -0,0 +1,30 @@ + +from ... schema import GraphRagQuery, GraphRagResponse +from ... schema import graph_rag_request_queue +from ... schema import graph_rag_response_queue + +from . endpoint import ServiceEndpoint + +class GraphRagEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(GraphRagEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=graph_rag_request_queue, + response_queue=graph_rag_response_queue, + request_schema=GraphRagQuery, + response_schema=GraphRagResponse, + endpoint_path="/api/v1/graph-rag", + timeout=timeout, + ) + + def to_request(self, body): + return GraphRagQuery( + query=body["query"], + user=body.get("user", "trustgraph"), + collection=body.get("collection", "default"), + ) + + def from_response(self, message): + return { "response": message.response } + diff --git a/trustgraph-flow/trustgraph/api/gateway/internet_search.py b/trustgraph-flow/trustgraph/api/gateway/internet_search.py new file mode 100644 index 00000000..c84ed82a --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/internet_search.py @@ -0,0 +1,29 @@ + +from ... schema import LookupRequest, LookupResponse +from ... schema import internet_search_request_queue +from ... schema import internet_search_response_queue + +from . endpoint import ServiceEndpoint + +class InternetSearchEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(InternetSearchEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=internet_search_request_queue, + response_queue=internet_search_response_queue, + request_schema=LookupRequest, + response_schema=LookupResponse, + endpoint_path="/api/v1/internet-search", + timeout=timeout, + ) + + def to_request(self, body): + return LookupRequest( + term=body["term"], + kind=body.get("kind", None), + ) + + def from_response(self, message): + return { "text": message.text } + diff --git a/trustgraph-flow/trustgraph/api/gateway/prompt.py b/trustgraph-flow/trustgraph/api/gateway/prompt.py new file mode 100644 index 00000000..e02effb9 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/prompt.py @@ -0,0 +1,41 @@ + +import json + +from ... schema import PromptRequest, PromptResponse +from ... schema import prompt_request_queue +from ... schema import prompt_response_queue + +from . endpoint import ServiceEndpoint + +class PromptEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(PromptEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=prompt_request_queue, + response_queue=prompt_response_queue, + request_schema=PromptRequest, + response_schema=PromptResponse, + endpoint_path="/api/v1/prompt", + timeout=timeout, + ) + + def to_request(self, body): + return PromptRequest( + id=body["id"], + terms={ + k: json.dumps(v) + for k, v in body["variables"].items() + } + ) + + def from_response(self, message): + if message.object: + return { + "object": message.object + } + else: + return { + "text": message.text + } + diff --git a/trustgraph-flow/trustgraph/api/gateway/publisher.py b/trustgraph-flow/trustgraph/api/gateway/publisher.py new file mode 100644 index 00000000..1bff44dd --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/publisher.py @@ -0,0 +1,41 @@ + +import asyncio +import aiopulsar + +class Publisher: + + def __init__(self, pulsar_host, topic, schema=None, max_size=10, + chunking_enabled=False): + self.pulsar_host = pulsar_host + self.topic = topic + self.schema = schema + self.q = asyncio.Queue(maxsize=max_size) + self.chunking_enabled = chunking_enabled + + async def run(self): + + while True: + + try: + async with aiopulsar.connect(self.pulsar_host) as client: + async with client.create_producer( + topic=self.topic, + schema=self.schema, + chunking_enabled=self.chunking_enabled, + ) as producer: + while True: + id, item = await self.q.get() + + if id: + await producer.send(item, { "id": id }) + else: + await producer.send(item) + + except Exception as e: + print("Exception:", e, flush=True) + + # If handler drops out, sleep a retry + await asyncio.sleep(2) + + async def send(self, id, msg): + await self.q.put((id, msg)) diff --git a/trustgraph-flow/trustgraph/api/gateway/running.py b/trustgraph-flow/trustgraph/api/gateway/running.py new file mode 100644 index 00000000..e6a91e66 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/running.py @@ -0,0 +1,5 @@ + +class Running: + def __init__(self): self.running = True + def get(self): return self.running + def stop(self): self.running = False diff --git a/trustgraph-flow/trustgraph/api/gateway/serialize.py b/trustgraph-flow/trustgraph/api/gateway/serialize.py new file mode 100644 index 00000000..2b955645 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/serialize.py @@ -0,0 +1,57 @@ +from ... schema import Value, Triple + +def to_value(x): + return Value(value=x["v"], is_uri=x["e"]) + +def to_subgraph(x): + return [ + Triple( + s=to_value(t["s"]), + p=to_value(t["p"]), + o=to_value(t["o"]) + ) + for t in x + ] + +def serialize_value(v): + return { + "v": v.value, + "e": v.is_uri, + } + +def serialize_triple(t): + return { + "s": serialize_value(t.s), + "p": serialize_value(t.p), + "o": serialize_value(t.o) + } + +def serialize_subgraph(sg): + return [ + serialize_triple(t) + for t in sg + ] + +def serialize_triples(message): + return { + "metadata": { + "id": message.metadata.id, + "metadata": serialize_subgraph(message.metadata.metadata), + "user": message.metadata.user, + "collection": message.metadata.collection, + }, + "triples": serialize_subgraph(message.triples), + } + +def serialize_graph_embeddings(message): + return { + "metadata": { + "id": message.metadata.id, + "metadata": serialize_subgraph(message.metadata.metadata), + "user": message.metadata.user, + "collection": message.metadata.collection, + }, + "vectors": message.vectors, + "entity": serialize_value(message.entity), + } + diff --git a/trustgraph-flow/trustgraph/api/gateway/service.py b/trustgraph-flow/trustgraph/api/gateway/service.py index 7b12e1a2..dcdd9779 100755 --- a/trustgraph-flow/trustgraph/api/gateway/service.py +++ b/trustgraph-flow/trustgraph/api/gateway/service.py @@ -1,4 +1,3 @@ - """ API gateway. Offers HTTP services which are translated to interaction on the Pulsar bus. @@ -14,57 +13,39 @@ module = ".".join(__name__.split(".")[1:-1]) import asyncio import argparse -from aiohttp import web, WSMsgType -import json +from aiohttp import web import logging -import uuid import os import base64 import pulsar -from pulsar.asyncio import Client from pulsar.schema import JsonSchema -import _pulsar -import aiopulsar from prometheus_client import start_http_server from ... log_level import LogLevel -from trustgraph.clients.llm_client import LlmClient -from trustgraph.clients.prompt_client import PromptClient - -from ... schema import Value, Metadata, Document, TextDocument, Triple - -from ... schema import TextCompletionRequest, TextCompletionResponse -from ... schema import text_completion_request_queue -from ... schema import text_completion_response_queue - -from ... schema import PromptRequest, PromptResponse -from ... schema import prompt_request_queue -from ... schema import prompt_response_queue - -from ... schema import GraphRagQuery, GraphRagResponse -from ... schema import graph_rag_request_queue -from ... schema import graph_rag_response_queue - -from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples -from ... schema import triples_request_queue -from ... schema import triples_response_queue -from ... schema import triples_store_queue - -from ... schema import GraphEmbeddings -from ... schema import graph_embeddings_store_queue - -from ... schema import AgentRequest, AgentResponse -from ... schema import agent_request_queue -from ... schema import agent_response_queue - -from ... schema import EmbeddingsRequest, EmbeddingsResponse -from ... schema import embeddings_request_queue -from ... schema import embeddings_response_queue - +from ... schema import Metadata, Document, TextDocument from ... schema import document_ingest_queue, text_ingest_queue +from . serialize import to_subgraph +from . running import Running +from . publisher import Publisher +from . subscriber import Subscriber +from . endpoint import ServiceEndpoint, MultiResponseServiceEndpoint +from . text_completion import TextCompletionEndpoint +from . prompt import PromptEndpoint +from . graph_rag import GraphRagEndpoint +from . triples_query import TriplesQueryEndpoint +from . embeddings import EmbeddingsEndpoint +from . encyclopedia import EncyclopediaEndpoint +from . agent import AgentEndpoint +from . dbpedia import DbpediaEndpoint +from . internet_search import InternetSearchEndpoint +from . triples_stream import TriplesStreamEndpoint +from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint +from . triples_load import TriplesLoadEndpoint +from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint + logger = logging.getLogger("api") logger.setLevel(logging.INFO) @@ -72,168 +53,6 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") default_timeout = 600 default_port = 8088 -def to_value(x): - return Value(value=x["v"], is_uri=x["e"]) - -def to_subgraph(x): - return [ - Triple( - s=to_value(t["s"]), - p=to_value(t["p"]), - o=to_value(t["o"]) - ) - for t in x - ] - -class Running: - def __init__(self): self.running = True - def get(self): return self.running - def stop(self): self.running = False - -class Publisher: - - def __init__(self, pulsar_host, topic, schema=None, max_size=10, - chunking_enabled=False): - self.pulsar_host = pulsar_host - self.topic = topic - self.schema = schema - self.q = asyncio.Queue(maxsize=max_size) - self.chunking_enabled = chunking_enabled - - async def run(self): - - while True: - - try: - async with aiopulsar.connect(self.pulsar_host) as client: - async with client.create_producer( - topic=self.topic, - schema=self.schema, - chunking_enabled=self.chunking_enabled, - ) as producer: - while True: - id, item = await self.q.get() - - if id: - await producer.send(item, { "id": id }) - else: - await producer.send(item) - - except Exception as e: - print("Exception:", e, flush=True) - - # If handler drops out, sleep a retry - await asyncio.sleep(2) - - async def send(self, id, msg): - await self.q.put((id, msg)) - -class Subscriber: - - def __init__(self, pulsar_host, topic, subscription, consumer_name, - schema=None, max_size=10): - self.pulsar_host = pulsar_host - self.topic = topic - self.subscription = subscription - self.consumer_name = consumer_name - self.schema = schema - self.q = {} - self.full = {} - - async def run(self): - while True: - try: - async with aiopulsar.connect(self.pulsar_host) as client: - async with client.subscribe( - topic=self.topic, - subscription_name=self.subscription, - consumer_name=self.consumer_name, - schema=self.schema, - ) as consumer: - while True: - msg = await consumer.receive() - - # Acknowledge successful reception of the message - await consumer.acknowledge(msg) - - try: - id = msg.properties()["id"] - except: - id = None - - value = msg.value() - if id in self.q: - await self.q[id].put(value) - - for q in self.full.values(): - await q.put(value) - - except Exception as e: - print("Exception:", e, flush=True) - - # If handler drops out, sleep a retry - await asyncio.sleep(2) - - async def subscribe(self, id): - q = asyncio.Queue() - self.q[id] = q - return q - - async def unsubscribe(self, id): - if id in self.q: - del self.q[id] - - async def subscribe_all(self, id): - q = asyncio.Queue() - self.full[id] = q - return q - - async def unsubscribe_all(self, id): - if id in self.full: - del self.full[id] - -def serialize_value(v): - return { - "v": v.value, - "e": v.is_uri, - } - -def serialize_triple(t): - return { - "s": serialize_value(t.s), - "p": serialize_value(t.p), - "o": serialize_value(t.o) - } - -def serialize_subgraph(sg): - return [ - serialize_triple(t) - for t in sg - ] - -def serialize_triples(message): - return { - "metadata": { - "id": message.metadata.id, - "metadata": serialize_subgraph(message.metadata.metadata), - "user": message.metadata.user, - "collection": message.metadata.collection, - }, - "triples": serialize_subgraph(message.triples), - } - -def serialize_graph_embeddings(message): - return { - "metadata": { - "id": message.metadata.id, - "metadata": serialize_subgraph(message.metadata.metadata), - "user": message.metadata.user, - "collection": message.metadata.collection, - }, - "vectors": message.vectors, - "entity": message.entity, - } - class Api: def __init__(self, **config): @@ -247,93 +66,47 @@ class Api: self.timeout = int(config.get("timeout", default_timeout)) self.pulsar_host = config.get("pulsar_host", default_pulsar_host) - self.llm_out = Publisher( - self.pulsar_host, text_completion_request_queue, - schema=JsonSchema(TextCompletionRequest) - ) - - self.llm_in = Subscriber( - self.pulsar_host, text_completion_response_queue, - "api-gateway", "api-gateway", - JsonSchema(TextCompletionResponse) - ) - - self.prompt_out = Publisher( - self.pulsar_host, prompt_request_queue, - schema=JsonSchema(PromptRequest) - ) - - self.prompt_in = Subscriber( - self.pulsar_host, prompt_response_queue, - "api-gateway", "api-gateway", - JsonSchema(PromptResponse) - ) - - self.graph_rag_out = Publisher( - self.pulsar_host, graph_rag_request_queue, - schema=JsonSchema(GraphRagQuery) - ) - - self.graph_rag_in = Subscriber( - self.pulsar_host, graph_rag_response_queue, - "api-gateway", "api-gateway", - JsonSchema(GraphRagResponse) - ) - - self.triples_query_out = Publisher( - self.pulsar_host, triples_request_queue, - schema=JsonSchema(TriplesQueryRequest) - ) - - self.triples_query_in = Subscriber( - self.pulsar_host, triples_response_queue, - "api-gateway", "api-gateway", - JsonSchema(TriplesQueryResponse) - ) - - self.agent_out = Publisher( - self.pulsar_host, agent_request_queue, - schema=JsonSchema(AgentRequest) - ) - - self.agent_in = Subscriber( - self.pulsar_host, agent_response_queue, - "api-gateway", "api-gateway", - JsonSchema(AgentResponse) - ) - - self.embeddings_out = Publisher( - self.pulsar_host, embeddings_request_queue, - schema=JsonSchema(EmbeddingsRequest) - ) - - self.embeddings_in = Subscriber( - self.pulsar_host, embeddings_response_queue, - "api-gateway", "api-gateway", - JsonSchema(EmbeddingsResponse) - ) - - self.triples_tap = Subscriber( - self.pulsar_host, triples_store_queue, - "api-gateway", "api-gateway", - schema=JsonSchema(Triples) - ) - - self.triples_pub = Publisher( - self.pulsar_host, triples_store_queue, - schema=JsonSchema(Triples) - ) - - self.graph_embeddings_tap = Subscriber( - self.pulsar_host, graph_embeddings_store_queue, - "api-gateway", "api-gateway", - schema=JsonSchema(GraphEmbeddings) - ) - - self.graph_embeddings_pub = Publisher( - self.pulsar_host, graph_embeddings_store_queue, - schema=JsonSchema(GraphEmbeddings) - ) + self.endpoints = [ + TextCompletionEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + PromptEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + GraphRagEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + TriplesQueryEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + EmbeddingsEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + AgentEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + EncyclopediaEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + DbpediaEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + InternetSearchEndpoint( + pulsar_host=self.pulsar_host, timeout=self.timeout, + ), + TriplesStreamEndpoint( + pulsar_host=self.pulsar_host + ), + GraphEmbeddingsStreamEndpoint( + pulsar_host=self.pulsar_host + ), + TriplesLoadEndpoint( + pulsar_host=self.pulsar_host + ), + GraphEmbeddingsLoadEndpoint( + pulsar_host=self.pulsar_host + ), + ] self.document_out = Publisher( self.pulsar_host, document_ingest_queue, @@ -347,323 +120,14 @@ class Api: chunking_enabled=True, ) + for ep in self.endpoints: + ep.add_routes(self.app) + self.app.add_routes([ - web.post("/api/v1/text-completion", self.llm), - web.post("/api/v1/prompt", self.prompt), - web.post("/api/v1/graph-rag", self.graph_rag), - web.post("/api/v1/triples-query", self.triples_query), - web.post("/api/v1/agent", self.agent), - web.post("/api/v1/embeddings", self.embeddings), web.post("/api/v1/load/document", self.load_document), web.post("/api/v1/load/text", self.load_text), - web.get("/api/v1/ws", self.socket), - - web.get("/api/v1/stream/triples", self.stream_triples), - web.get( - "/api/v1/stream/graph-embeddings", - self.stream_graph_embeddings - ), - - web.get("/api/v1/load/triples", self.load_triples), - web.get( - "/api/v1/load/graph-embeddings", - self.load_graph_embeddings - ), - ]) - async def llm(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = await self.llm_in.subscribe(id) - - await self.llm_out.send( - id, - TextCompletionRequest( - system=data["system"], - prompt=data["prompt"] - ) - ) - - try: - resp = await asyncio.wait_for(q.get(), self.timeout) - except: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - return web.json_response( - { "response": resp.response } - ) - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - await self.llm_in.unsubscribe(id) - - async def prompt(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = await self.prompt_in.subscribe(id) - - terms = { - k: json.dumps(v) - for k, v in data["variables"].items() - } - - await self.prompt_out.send( - id, - PromptRequest( - id=data["id"], - terms=terms - ) - ) - - try: - resp = await asyncio.wait_for(q.get(), self.timeout) - except: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - if resp.object: - return web.json_response( - { "object": resp.object } - ) - - return web.json_response( - { "text": resp.text } - ) - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - await self.prompt_in.unsubscribe(id) - - async def graph_rag(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = await self.graph_rag_in.subscribe(id) - - await self.graph_rag_out.send( - id, - GraphRagQuery( - query=data["query"], - user=data.get("user", "trustgraph"), - collection=data.get("collection", "default"), - ) - ) - - try: - resp = await asyncio.wait_for(q.get(), self.timeout) - except: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - return web.json_response( - { "response": resp.response } - ) - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - await self.graph_rag_in.unsubscribe(id) - - async def triples_query(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = await self.triples_query_in.subscribe(id) - - if "s" in data: - s = to_value(data["s"]) - else: - s = None - - if "p" in data: - p = to_value(data["p"]) - else: - p = None - - if "o" in data: - o = to_value(data["o"]) - else: - o = None - - limit = int(data.get("limit", 10000)) - - await self.triples_query_out.send( - id, - TriplesQueryRequest( - s = s, p = p, o = o, - limit = limit, - user = data.get("user", "trustgraph"), - collection = data.get("collection", "default"), - ) - ) - - try: - resp = await asyncio.wait_for(q.get(), self.timeout) - except: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - return web.json_response( - { - "response": serialize_subgraph(resp.triples), - } - ) - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - await self.graph_rag_in.unsubscribe(id) - - async def agent(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = await self.agent_in.subscribe(id) - - await self.agent_out.send( - id, - AgentRequest( - question=data["question"], - ) - ) - - while True: - try: - resp = await asyncio.wait_for(q.get(), self.timeout) - except: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - if resp.answer: break - - if resp.thought: print("thought:", resp.thought) - if resp.observation: print("observation:", resp.observation) - - if resp.answer: - return web.json_response( - { "answer": resp.answer } - ) - - # Can't happen, ook at the logic - raise RuntimeError("Strange state") - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - await self.agent_in.unsubscribe(id) - - async def embeddings(self, request): - - id = str(uuid.uuid4()) - - try: - - data = await request.json() - - q = await self.embeddings_in.subscribe(id) - - await self.embeddings_out.send( - id, - EmbeddingsRequest( - text=data["text"], - ) - ) - - try: - resp = await asyncio.wait_for(q.get(), self.timeout) - except: - raise RuntimeError("Timeout waiting for response") - - if resp.error: - return web.json_response( - { "error": resp.error.message } - ) - - return web.json_response( - { "vectors": resp.vectors } - ) - - except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - - finally: - await self.embeddings_in.unsubscribe(id) - async def load_document(self, request): try: @@ -750,215 +214,12 @@ class Api: { "error": str(e) } ) - async def socket(self, request): - - ws = web.WebSocketResponse() - await ws.prepare(request) - - async for msg in ws: - if msg.type == WSMsgType.TEXT: - if msg.data == 'close': - await ws.close() - else: - await ws.send_str(msg.data + '/answer') - elif msg.type == WSMsgType.ERROR: - print('ws connection closed with exception %s' % - ws.exception()) - - print('websocket connection closed') - - return ws - - async def stream(self, q, ws, running, fn): - - while running.get(): - try: - resp = await asyncio.wait_for(q.get(), 0.5) - await ws.send_json(fn(resp)) - - except TimeoutError: - continue - - except Exception as e: - print(f"Exception: {str(e)}", flush=True) - - async def stream_triples(self, request): - - id = str(uuid.uuid4()) - - q = await self.triples_tap.subscribe_all(id) - running = Running() - - ws = web.WebSocketResponse() - await ws.prepare(request) - - tsk = asyncio.create_task(self.stream( - q, - ws, - running, - serialize_triples, - )) - - async for msg in ws: - if msg.type == WSMsgType.ERROR: - break - else: - # Ignore incoming messages - pass - - running.stop() - - await self.triples_tap.unsubscribe_all(id) - await tsk - - return ws - - async def stream_graph_embeddings(self, request): - - id = str(uuid.uuid4()) - - q = await self.graph_embeddings_tap.subscribe_all(id) - running = Running() - - ws = web.WebSocketResponse() - await ws.prepare(request) - - tsk = asyncio.create_task(self.stream( - q, - ws, - running, - serialize_graph_embeddings, - )) - - async for msg in ws: - if msg.type == WSMsgType.ERROR: - break - else: - # Ignore incoming messages - pass - - running.stop() - - await self.graph_embeddings_tap.unsubscribe_all(id) - await tsk - - return ws - - async def load_triples(self, request): - - ws = web.WebSocketResponse() - await ws.prepare(request) - - async for msg in ws: - - try: - - if msg.type == WSMsgType.TEXT: - - data = msg.json() - - elt = Triples( - metadata=Metadata( - id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), - user=data["metadata"]["user"], - collection=data["metadata"]["collection"], - ), - triples=to_subgraph(data["triples"]), - ) - - await self.triples_pub.send(None, elt) - - elif msg.type == WSMsgType.ERROR: - break - - except Exception as e: - - print("Exception:", e) - - return ws - - async def load_graph_embeddings(self, request): - - ws = web.WebSocketResponse() - await ws.prepare(request) - - async for msg in ws: - - try: - - if msg.type == WSMsgType.TEXT: - - data = msg.json() - - elt = GraphEmbeddings( - metadata=Metadata( - id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), - user=data["metadata"]["user"], - collection=data["metadata"]["collection"], - ), - entity=to_value(data["entity"]), - vectors=data["vectors"], - ) - - await self.graph_embeddings_pub.send(None, elt) - - elif msg.type == WSMsgType.ERROR: - break - - except Exception as e: - - print("Exception:", e) - - return ws - async def app_factory(self): - self.llm_pub_task = asyncio.create_task(self.llm_in.run()) - self.llm_sub_task = asyncio.create_task(self.llm_out.run()) - - self.prompt_pub_task = asyncio.create_task(self.prompt_in.run()) - self.prompt_sub_task = asyncio.create_task(self.prompt_out.run()) - - self.graph_rag_pub_task = asyncio.create_task(self.graph_rag_in.run()) - self.graph_rag_sub_task = asyncio.create_task(self.graph_rag_out.run()) - - self.triples_query_pub_task = asyncio.create_task( - self.triples_query_in.run() - ) - self.triples_query_sub_task = asyncio.create_task( - self.triples_query_out.run() - ) - - self.agent_pub_task = asyncio.create_task(self.agent_in.run()) - self.agent_sub_task = asyncio.create_task(self.agent_out.run()) - - self.embeddings_pub_task = asyncio.create_task( - self.embeddings_in.run() - ) - self.embeddings_sub_task = asyncio.create_task( - self.embeddings_out.run() - ) - - self.triples_tap_task = asyncio.create_task( - self.triples_tap.run() - ) - - self.triples_pub_task = asyncio.create_task( - self.triples_pub.run() - ) - - self.graph_embeddings_tap_task = asyncio.create_task( - self.graph_embeddings_tap.run() - ) - - self.graph_embeddings_pub_task = asyncio.create_task( - self.graph_embeddings_pub.run() - ) + for ep in self.endpoints: + await ep.start() self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run()) - self.text_ingest_pub_task = asyncio.create_task(self.text_out.run()) return self.app diff --git a/trustgraph-flow/trustgraph/api/gateway/socket.py b/trustgraph-flow/trustgraph/api/gateway/socket.py new file mode 100644 index 00000000..235bfd21 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/socket.py @@ -0,0 +1,68 @@ + +import asyncio +from aiohttp import web, WSMsgType +import logging + +from . running import Running + +logger = logging.getLogger("socket") +logger.setLevel(logging.INFO) + +class SocketEndpoint: + + def __init__( + self, + endpoint_path="/api/v1/socket", + ): + + self.path = endpoint_path + + async def listener(self, ws, running): + + async for msg in ws: + # On error, finish + if msg.type == WSMsgType.ERROR: + break + else: + # Ignore incoming messages + pass + + running.stop() + + async def async_thread(self, ws, running): + + while running.get(): + try: + await asyncio.sleep(1) + + except TimeoutError: + continue + + except Exception as e: + print(f"Exception: {str(e)}", flush=True) + + async def handle(self, request): + + running = Running() + ws = web.WebSocketResponse() + await ws.prepare(request) + + task = asyncio.create_task(self.async_thread(ws, running)) + + await self.listener(ws, running) + + await task + + running.stop() + + return ws + + async def start(self): + pass + + def add_routes(self, app): + + app.add_routes([ + web.get(self.path, self.handle), + ]) + diff --git a/trustgraph-flow/trustgraph/api/gateway/subscriber.py b/trustgraph-flow/trustgraph/api/gateway/subscriber.py new file mode 100644 index 00000000..3d8840f6 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/subscriber.py @@ -0,0 +1,68 @@ + +import asyncio +import aiopulsar + +class Subscriber: + + def __init__(self, pulsar_host, topic, subscription, consumer_name, + schema=None, max_size=10): + self.pulsar_host = pulsar_host + self.topic = topic + self.subscription = subscription + self.consumer_name = consumer_name + self.schema = schema + self.q = {} + self.full = {} + + async def run(self): + while True: + try: + async with aiopulsar.connect(self.pulsar_host) as client: + async with client.subscribe( + topic=self.topic, + subscription_name=self.subscription, + consumer_name=self.consumer_name, + schema=self.schema, + ) as consumer: + while True: + msg = await consumer.receive() + + # Acknowledge successful reception of the message + await consumer.acknowledge(msg) + + try: + id = msg.properties()["id"] + except: + id = None + + value = msg.value() + if id in self.q: + await self.q[id].put(value) + + for q in self.full.values(): + await q.put(value) + + except Exception as e: + print("Exception:", e, flush=True) + + # If handler drops out, sleep a retry + await asyncio.sleep(2) + + async def subscribe(self, id): + q = asyncio.Queue() + self.q[id] = q + return q + + async def unsubscribe(self, id): + if id in self.q: + del self.q[id] + + async def subscribe_all(self, id): + q = asyncio.Queue() + self.full[id] = q + return q + + async def unsubscribe_all(self, id): + if id in self.full: + del self.full[id] + diff --git a/trustgraph-flow/trustgraph/api/gateway/text_completion.py b/trustgraph-flow/trustgraph/api/gateway/text_completion.py new file mode 100644 index 00000000..04dbc9c8 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/text_completion.py @@ -0,0 +1,28 @@ + +from ... schema import TextCompletionRequest, TextCompletionResponse +from ... schema import text_completion_request_queue +from ... schema import text_completion_response_queue + +from . endpoint import ServiceEndpoint + +class TextCompletionEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(TextCompletionEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=text_completion_request_queue, + response_queue=text_completion_response_queue, + request_schema=TextCompletionRequest, + response_schema=TextCompletionResponse, + endpoint_path="/api/v1/text-completion", + timeout=timeout, + ) + + def to_request(self, body): + return TextCompletionRequest( + system=body["system"], + prompt=body["prompt"] + ) + + def from_response(self, message): + return { "response": message.response } diff --git a/trustgraph-flow/trustgraph/api/gateway/triples_load.py b/trustgraph-flow/trustgraph/api/gateway/triples_load.py new file mode 100644 index 00000000..d835a363 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/triples_load.py @@ -0,0 +1,59 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid +from aiohttp import WSMsgType + +from ... schema import Metadata +from ... schema import Triples +from ... schema import triples_store_queue + +from . publisher import Publisher +from . socket import SocketEndpoint +from . serialize import to_subgraph + +class TriplesLoadEndpoint(SocketEndpoint): + + def __init__(self, pulsar_host, path="/api/v1/load/triples"): + + super(TriplesLoadEndpoint, self).__init__( + endpoint_path=path + ) + + self.pulsar_host=pulsar_host + + self.publisher = Publisher( + self.pulsar_host, triples_store_queue, + schema=JsonSchema(Triples) + ) + + async def start(self): + + self.task = asyncio.create_task( + self.publisher.run() + ) + + async def listener(self, ws, running): + + async for msg in ws: + # On error, finish + if msg.type == WSMsgType.ERROR: + break + else: + + data = msg.json() + + elt = Triples( + metadata=Metadata( + id=data["metadata"]["id"], + metadata=to_subgraph(data["metadata"]["metadata"]), + user=data["metadata"]["user"], + collection=data["metadata"]["collection"], + ), + triples=to_subgraph(data["triples"]), + ) + + await self.publisher.send(None, elt) + + + running.stop() diff --git a/trustgraph-flow/trustgraph/api/gateway/triples_query.py b/trustgraph-flow/trustgraph/api/gateway/triples_query.py new file mode 100644 index 00000000..8b4192d8 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/triples_query.py @@ -0,0 +1,53 @@ + +from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples +from ... schema import triples_request_queue +from ... schema import triples_response_queue + +from . endpoint import ServiceEndpoint +from . serialize import to_value, serialize_subgraph + +class TriplesQueryEndpoint(ServiceEndpoint): + def __init__(self, pulsar_host, timeout): + + super(TriplesQueryEndpoint, self).__init__( + pulsar_host=pulsar_host, + request_queue=triples_request_queue, + response_queue=triples_response_queue, + request_schema=TriplesQueryRequest, + response_schema=TriplesQueryResponse, + endpoint_path="/api/v1/triples-query", + timeout=timeout, + ) + + def to_request(self, body): + + if "s" in body: + s = to_value(body["s"]) + else: + s = None + + if "p" in body: + p = to_value(body["p"]) + else: + p = None + + if "o" in body: + o = to_value(body["o"]) + else: + o = None + + limit = int(body.get("limit", 10000)) + + return TriplesQueryRequest( + s = s, p = p, o = o, + limit = limit, + user = body.get("user", "trustgraph"), + collection = body.get("collection", "default"), + ) + + def from_response(self, message): + print(message) + return { + "response": serialize_subgraph(message.triples) + } + diff --git a/trustgraph-flow/trustgraph/api/gateway/triples_stream.py b/trustgraph-flow/trustgraph/api/gateway/triples_stream.py new file mode 100644 index 00000000..e8b538a4 --- /dev/null +++ b/trustgraph-flow/trustgraph/api/gateway/triples_stream.py @@ -0,0 +1,56 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid + +from ... schema import Triples +from ... schema import triples_store_queue + +from . subscriber import Subscriber +from . socket import SocketEndpoint +from . serialize import serialize_triples + +class TriplesStreamEndpoint(SocketEndpoint): + + def __init__(self, pulsar_host, path="/api/v1/stream/triples"): + + super(TriplesStreamEndpoint, self).__init__( + endpoint_path=path + ) + + self.pulsar_host=pulsar_host + + self.subscriber = Subscriber( + self.pulsar_host, triples_store_queue, + "api-gateway", "api-gateway", + schema=JsonSchema(Triples) + ) + + async def start(self): + + self.task = asyncio.create_task( + self.subscriber.run() + ) + + async def async_thread(self, ws, running): + + id = str(uuid.uuid4()) + + q = await self.subscriber.subscribe_all(id) + + while running.get(): + try: + resp = await asyncio.wait_for(q.get(), 0.5) + await ws.send_json(serialize_triples(resp)) + + except TimeoutError: + continue + + except Exception as e: + print(f"Exception: {str(e)}", flush=True) + break + + await self.subscriber.unsubscribe_all(id) + + running.stop() + diff --git a/trustgraph-flow/trustgraph/external/__init__.py b/trustgraph-flow/trustgraph/external/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trustgraph-flow/trustgraph/external/wikipedia/__init__.py b/trustgraph-flow/trustgraph/external/wikipedia/__init__.py new file mode 100644 index 00000000..ba844705 --- /dev/null +++ b/trustgraph-flow/trustgraph/external/wikipedia/__init__.py @@ -0,0 +1,3 @@ + +from . service import * + diff --git a/trustgraph-flow/trustgraph/external/wikipedia/__main__.py b/trustgraph-flow/trustgraph/external/wikipedia/__main__.py new file mode 100644 index 00000000..e9136855 --- /dev/null +++ b/trustgraph-flow/trustgraph/external/wikipedia/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/external/wikipedia/service.py b/trustgraph-flow/trustgraph/external/wikipedia/service.py new file mode 100644 index 00000000..932e1213 --- /dev/null +++ b/trustgraph-flow/trustgraph/external/wikipedia/service.py @@ -0,0 +1,102 @@ + +""" +Wikipedia lookup service. Fetchs an extract from the Wikipedia page +using the API. +""" + +from trustgraph.schema import LookupRequest, LookupResponse, Error +from trustgraph.schema import encyclopedia_lookup_request_queue +from trustgraph.schema import encyclopedia_lookup_response_queue +from trustgraph.log_level import LogLevel +from trustgraph.base import ConsumerProducer +import requests + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = encyclopedia_lookup_request_queue +default_output_queue = encyclopedia_lookup_response_queue +default_subscriber = module +default_url="https://en.wikipedia.org/" + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + url = params.get("url", default_url) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": LookupRequest, + "output_schema": LookupResponse, + } + ) + + self.url = url + + def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + print(f"Handling {v.kind} / {v.term}...", flush=True) + + try: + + url = f"{self.url}/api/rest_v1/page/summary/{v.term}" + + resp = Result = requests.get(url).json() + resp = resp["extract"] + + r = LookupResponse( + error=None, + text=resp + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + return + + except Exception as e: + + r = LookupResponse( + error=Error( + type = "lookup-error", + message = str(e), + ), + text=None, + ) + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + return + + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'LLM model (default: {default_url})' + ) + +def run(): + + Processor.start(module, __doc__) +