mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
Feature/streaming llm phase 1 (#566)
* Tidy up duplicate tech specs in doc directory * Streaming LLM text-completion service tech spec. * text-completion and prompt interfaces * streaming change applied to all LLMs, so far tested with VertexAI * Skip Pinecone unit tests, upstream module issue is affecting things, tests are passing again * Added agent streaming, not working and has broken tests
This commit is contained in:
parent
943a9d83b0
commit
310a2deb06
44 changed files with 2684 additions and 937 deletions
|
|
@ -29,7 +29,7 @@ def output(text, prefix="> ", width=78):
|
|||
|
||||
async def question(
|
||||
url, question, flow_id, user, collection,
|
||||
plan=None, state=None, group=None, verbose=False
|
||||
plan=None, state=None, group=None, verbose=False, streaming=True
|
||||
):
|
||||
|
||||
if not url.endswith("/"):
|
||||
|
|
@ -62,16 +62,17 @@ async def question(
|
|||
"request": {
|
||||
"question": question,
|
||||
"user": user,
|
||||
"history": []
|
||||
"history": [],
|
||||
"streaming": streaming
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Only add optional fields if they have values
|
||||
if state is not None:
|
||||
req["request"]["state"] = state
|
||||
if group is not None:
|
||||
req["request"]["group"] = group
|
||||
|
||||
|
||||
req = json.dumps(req)
|
||||
|
||||
await ws.send(req)
|
||||
|
|
@ -89,14 +90,34 @@ async def question(
|
|||
print("Ignore message")
|
||||
continue
|
||||
|
||||
if "thought" in obj["response"]:
|
||||
think(obj["response"]["thought"])
|
||||
response = obj["response"]
|
||||
|
||||
if "observation" in obj["response"]:
|
||||
observe(obj["response"]["observation"])
|
||||
# Handle streaming format (new format with chunk_type)
|
||||
if "chunk_type" in response:
|
||||
chunk_type = response["chunk_type"]
|
||||
content = response.get("content", "")
|
||||
|
||||
if "answer" in obj["response"]:
|
||||
print(obj["response"]["answer"])
|
||||
if chunk_type == "thought":
|
||||
think(content)
|
||||
elif chunk_type == "observation":
|
||||
observe(content)
|
||||
elif chunk_type == "answer":
|
||||
print(content)
|
||||
elif chunk_type == "error":
|
||||
raise RuntimeError(content)
|
||||
else:
|
||||
# Handle legacy format (backward compatibility)
|
||||
if "thought" in response:
|
||||
think(response["thought"])
|
||||
|
||||
if "observation" in response:
|
||||
observe(response["observation"])
|
||||
|
||||
if "answer" in response:
|
||||
print(response["answer"])
|
||||
|
||||
if "error" in response:
|
||||
raise RuntimeError(response["error"])
|
||||
|
||||
if obj["complete"]: break
|
||||
|
||||
|
|
@ -161,6 +182,12 @@ def main():
|
|||
help=f'Output thinking/observations'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-streaming',
|
||||
action="store_true",
|
||||
help=f'Disable streaming (use legacy mode)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -176,6 +203,7 @@ def main():
|
|||
state = args.state,
|
||||
group = args.group,
|
||||
verbose = args.verbose,
|
||||
streaming = not args.no_streaming,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,17 +6,63 @@ and user prompt. Both arguments are required.
|
|||
import argparse
|
||||
import os
|
||||
import json
|
||||
from trustgraph.api import Api
|
||||
import uuid
|
||||
import asyncio
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
||||
|
||||
def query(url, flow_id, system, prompt):
|
||||
async def query(url, flow_id, system, prompt, streaming=True):
|
||||
|
||||
api = Api(url).flow().id(flow_id)
|
||||
if not url.endswith("/"):
|
||||
url += "/"
|
||||
|
||||
resp = api.text_completion(system=system, prompt=prompt)
|
||||
url = url + "api/v1/socket"
|
||||
|
||||
print(resp)
|
||||
mid = str(uuid.uuid4())
|
||||
|
||||
async with connect(url) as ws:
|
||||
|
||||
req = {
|
||||
"id": mid,
|
||||
"service": "text-completion",
|
||||
"flow": flow_id,
|
||||
"request": {
|
||||
"system": system,
|
||||
"prompt": prompt,
|
||||
"streaming": streaming
|
||||
}
|
||||
}
|
||||
|
||||
await ws.send(json.dumps(req))
|
||||
|
||||
while True:
|
||||
|
||||
msg = await ws.recv()
|
||||
|
||||
obj = json.loads(msg)
|
||||
|
||||
if "error" in obj:
|
||||
raise RuntimeError(obj["error"])
|
||||
|
||||
if obj["id"] != mid:
|
||||
continue
|
||||
|
||||
if "response" in obj["response"]:
|
||||
if streaming:
|
||||
# Stream output to stdout without newline
|
||||
print(obj["response"]["response"], end="", flush=True)
|
||||
else:
|
||||
# Non-streaming: print complete response
|
||||
print(obj["response"]["response"])
|
||||
|
||||
if obj["complete"]:
|
||||
if streaming:
|
||||
# Add final newline after streaming
|
||||
print()
|
||||
break
|
||||
|
||||
await ws.close()
|
||||
|
||||
def main():
|
||||
|
||||
|
|
@ -49,16 +95,23 @@ def main():
|
|||
help=f'Flow ID (default: default)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-streaming',
|
||||
action='store_true',
|
||||
help='Disable streaming (default: streaming enabled)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
query(
|
||||
asyncio.run(query(
|
||||
url=args.url,
|
||||
flow_id = args.flow_id,
|
||||
flow_id=args.flow_id,
|
||||
system=args.system[0],
|
||||
prompt=args.prompt[0],
|
||||
)
|
||||
streaming=not args.no_streaming
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -10,20 +10,76 @@ using key=value arguments on the command line, and these replace
|
|||
import argparse
|
||||
import os
|
||||
import json
|
||||
from trustgraph.api import Api
|
||||
import uuid
|
||||
import asyncio
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
||||
|
||||
def query(url, flow_id, template_id, variables):
|
||||
async def query(url, flow_id, template_id, variables, streaming=True):
|
||||
|
||||
api = Api(url).flow().id(flow_id)
|
||||
if not url.endswith("/"):
|
||||
url += "/"
|
||||
|
||||
resp = api.prompt(id=template_id, variables=variables)
|
||||
url = url + "api/v1/socket"
|
||||
|
||||
if isinstance(resp, str):
|
||||
print(resp)
|
||||
else:
|
||||
print(json.dumps(resp, indent=4))
|
||||
mid = str(uuid.uuid4())
|
||||
|
||||
async with connect(url) as ws:
|
||||
|
||||
req = {
|
||||
"id": mid,
|
||||
"service": "prompt",
|
||||
"flow": flow_id,
|
||||
"request": {
|
||||
"id": template_id,
|
||||
"variables": variables,
|
||||
"streaming": streaming
|
||||
}
|
||||
}
|
||||
|
||||
await ws.send(json.dumps(req))
|
||||
|
||||
full_response = {"text": "", "object": ""}
|
||||
|
||||
while True:
|
||||
|
||||
msg = await ws.recv()
|
||||
|
||||
obj = json.loads(msg)
|
||||
|
||||
if "error" in obj:
|
||||
raise RuntimeError(obj["error"])
|
||||
|
||||
if obj["id"] != mid:
|
||||
continue
|
||||
|
||||
response = obj["response"]
|
||||
|
||||
# Handle text responses (streaming)
|
||||
if "text" in response and response["text"]:
|
||||
if streaming:
|
||||
# Stream output to stdout without newline
|
||||
print(response["text"], end="", flush=True)
|
||||
full_response["text"] += response["text"]
|
||||
else:
|
||||
# Non-streaming: print complete response
|
||||
print(response["text"])
|
||||
|
||||
# Handle object responses (JSON, never streamed)
|
||||
if "object" in response and response["object"]:
|
||||
full_response["object"] = response["object"]
|
||||
|
||||
if obj["complete"]:
|
||||
if streaming and full_response["text"]:
|
||||
# Add final newline after streaming text
|
||||
print()
|
||||
elif full_response["object"]:
|
||||
# Print JSON object (pretty-printed)
|
||||
print(json.dumps(json.loads(full_response["object"]), indent=4))
|
||||
break
|
||||
|
||||
await ws.close()
|
||||
|
||||
def main():
|
||||
|
||||
|
|
@ -59,6 +115,12 @@ def main():
|
|||
specified multiple times''',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-streaming',
|
||||
action='store_true',
|
||||
help='Disable streaming (default: streaming enabled for text responses)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
variables = {}
|
||||
|
|
@ -73,12 +135,13 @@ specified multiple times''',
|
|||
|
||||
try:
|
||||
|
||||
query(
|
||||
asyncio.run(query(
|
||||
url=args.url,
|
||||
flow_id=args.flow_id,
|
||||
template_id=args.id[0],
|
||||
variables=variables,
|
||||
)
|
||||
streaming=not args.no_streaming
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue