mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-10 15:25:14 +02:00
Bulk clients (sync and async) were not forwarding the workspace parameter, causing all bulk operations to hit the default workspace regardless of the Api instance's workspace setting. Also fixes the gateway socket endpoint to pass query parameters (including workspace) to the dispatcher, and prevents the auth handshake from overwriting an explicitly set workspace. Updates knowledge table store tests for paged query interface.
1037 lines
34 KiB
Python
1037 lines
34 KiB
Python
"""
|
|
TrustGraph Synchronous WebSocket Client
|
|
|
|
This module provides synchronous WebSocket-based access to TrustGraph services with
|
|
streaming support for real-time responses from agents, RAG queries, and text completions.
|
|
|
|
Uses a persistent WebSocket connection with a background reader task that
|
|
multiplexes requests by ID.
|
|
"""
|
|
|
|
import json
|
|
import asyncio
|
|
import websockets
|
|
from websockets.exceptions import ConnectionClosed
|
|
from typing import Optional, Dict, Any, Iterator, Union, List
|
|
from threading import Lock
|
|
|
|
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
|
|
from . exceptions import ProtocolException, raise_from_error_dict
|
|
|
|
|
|
def build_term(value: Any, term_type: Optional[str] = None,
|
|
datatype: Optional[str] = None, language: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
Build wire-format Term dict from a value.
|
|
|
|
Auto-detection rules (when term_type is None):
|
|
- Already a dict with 't' key -> return as-is (already a Term)
|
|
- Starts with http://, https://, urn: -> IRI
|
|
- Wrapped in <> (e.g., <http://...>) -> IRI (angle brackets stripped)
|
|
- Anything else -> literal
|
|
|
|
Args:
|
|
value: The term value (string, dict, or None)
|
|
term_type: One of 'iri', 'literal', or None for auto-detect
|
|
datatype: Datatype for literal objects (e.g., xsd:integer)
|
|
language: Language tag for literal objects (e.g., en)
|
|
|
|
Returns:
|
|
dict: Wire-format Term dict, or None if value is None
|
|
"""
|
|
if value is None:
|
|
return None
|
|
|
|
# If already a Term dict, return as-is
|
|
if isinstance(value, dict) and "t" in value:
|
|
return value
|
|
|
|
# Convert to string for processing
|
|
value = str(value)
|
|
|
|
# Auto-detect type if not specified
|
|
if term_type is None:
|
|
if value.startswith("<") and value.endswith(">") and not value.startswith("<<"):
|
|
# Angle-bracket wrapped IRI: <http://...>
|
|
value = value[1:-1] # Strip < and >
|
|
term_type = "iri"
|
|
elif value.startswith(("http://", "https://", "urn:")):
|
|
term_type = "iri"
|
|
else:
|
|
term_type = "literal"
|
|
|
|
if term_type == "iri":
|
|
# Strip angle brackets if present
|
|
if value.startswith("<") and value.endswith(">"):
|
|
value = value[1:-1]
|
|
return {"t": "i", "i": value}
|
|
elif term_type == "literal":
|
|
result = {"t": "l", "v": value}
|
|
if datatype:
|
|
result["dt"] = datatype
|
|
if language:
|
|
result["ln"] = language
|
|
return result
|
|
else:
|
|
raise ValueError(f"Unknown term type: {term_type}")
|
|
|
|
|
|
class SocketClient:
|
|
"""
|
|
Synchronous WebSocket client with persistent connection.
|
|
|
|
Maintains a single websocket connection and multiplexes requests
|
|
by ID via a background reader task. Provides synchronous generators
|
|
for streaming responses.
|
|
"""
|
|
|
|
def __init__(
|
|
self, url: str, timeout: int, token: Optional[str],
|
|
workspace: str = "default",
|
|
) -> None:
|
|
self.url: str = self._convert_to_ws_url(url)
|
|
self.timeout: int = timeout
|
|
self.token: Optional[str] = token
|
|
self.workspace: str = workspace
|
|
self._request_counter: int = 0
|
|
self._lock: Lock = Lock()
|
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
|
self._socket = None
|
|
self._connect_cm = None
|
|
self._reader_task = None
|
|
self._pending: Dict[str, asyncio.Queue] = {}
|
|
self._connected: bool = False
|
|
|
|
def _convert_to_ws_url(self, url: str) -> str:
|
|
if url.startswith("http://"):
|
|
return url.replace("http://", "ws://", 1)
|
|
elif url.startswith("https://"):
|
|
return url.replace("https://", "wss://", 1)
|
|
elif url.startswith("ws://") or url.startswith("wss://"):
|
|
return url
|
|
else:
|
|
return f"ws://{url}"
|
|
|
|
def _build_ws_url(self):
|
|
# /api/v1/socket uses the first-frame auth protocol — the
|
|
# token is sent as the first frame after connecting rather
|
|
# than in the URL.
|
|
return f"{self.url.rstrip('/')}/api/v1/socket"
|
|
|
|
def _get_loop(self):
|
|
"""Get or create the event loop, reusing across calls."""
|
|
if self._loop is None or self._loop.is_closed():
|
|
try:
|
|
loop = asyncio.get_event_loop()
|
|
if loop.is_running():
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
except RuntimeError:
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
self._loop = loop
|
|
return self._loop
|
|
|
|
async def _ensure_connected(self):
|
|
"""Lazily establish the persistent websocket connection and
|
|
run the first-frame auth handshake."""
|
|
if self._connected:
|
|
return
|
|
|
|
ws_url = self._build_ws_url()
|
|
self._connect_cm = websockets.connect(
|
|
ws_url, ping_interval=20, ping_timeout=self.timeout
|
|
)
|
|
self._socket = await self._connect_cm.__aenter__()
|
|
|
|
# First-frame auth — run before starting the reader so the
|
|
# auth-ok / auth-failed response isn't consumed by the reader
|
|
# loop's id-based routing.
|
|
await self._socket.send(json.dumps({
|
|
"type": "auth", "token": self.token or "",
|
|
}))
|
|
try:
|
|
raw = await asyncio.wait_for(
|
|
self._socket.recv(), timeout=self.timeout,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
await self._socket.close()
|
|
raise ProtocolException("Timeout waiting for auth response")
|
|
|
|
try:
|
|
resp = json.loads(raw)
|
|
except Exception:
|
|
await self._socket.close()
|
|
raise ProtocolException(
|
|
f"Unexpected non-JSON auth response: {raw!r}"
|
|
)
|
|
|
|
if resp.get("type") == "auth-ok":
|
|
if self.workspace == "default":
|
|
self.workspace = resp.get("workspace", self.workspace)
|
|
elif resp.get("type") == "auth-failed":
|
|
await self._socket.close()
|
|
raise ProtocolException(
|
|
f"auth failure: {resp.get('error', 'unknown')}"
|
|
)
|
|
else:
|
|
await self._socket.close()
|
|
raise ProtocolException(
|
|
f"Unexpected auth response: {resp!r}"
|
|
)
|
|
|
|
self._connected = True
|
|
self._reader_task = asyncio.create_task(self._reader())
|
|
|
|
async def _reader(self):
|
|
"""Background task to read responses and route by request ID."""
|
|
try:
|
|
async for raw_message in self._socket:
|
|
response = json.loads(raw_message)
|
|
request_id = response.get("id")
|
|
|
|
if request_id and request_id in self._pending:
|
|
await self._pending[request_id].put(response)
|
|
|
|
except ConnectionClosed:
|
|
pass
|
|
except Exception as e:
|
|
for queue in self._pending.values():
|
|
try:
|
|
await queue.put({"error": str(e)})
|
|
except Exception:
|
|
pass
|
|
finally:
|
|
self._connected = False
|
|
|
|
def _next_request_id(self):
|
|
with self._lock:
|
|
self._request_counter += 1
|
|
return f"req-{self._request_counter}"
|
|
|
|
def flow(self, flow_id: str) -> "SocketFlowInstance":
|
|
return SocketFlowInstance(self, flow_id)
|
|
|
|
def _send_request_sync(
|
|
self,
|
|
service: str,
|
|
flow: Optional[str],
|
|
request: Dict[str, Any],
|
|
streaming: bool = False,
|
|
streaming_raw: bool = False,
|
|
include_provenance: bool = False
|
|
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
|
|
"""Synchronous wrapper around async WebSocket communication."""
|
|
loop = self._get_loop()
|
|
|
|
if streaming_raw:
|
|
return self._streaming_generator_raw(service, flow, request, loop)
|
|
elif streaming:
|
|
return self._streaming_generator(service, flow, request, loop, include_provenance)
|
|
else:
|
|
return loop.run_until_complete(self._send_request_async(service, flow, request))
|
|
|
|
def _streaming_generator(
|
|
self,
|
|
service: str,
|
|
flow: Optional[str],
|
|
request: Dict[str, Any],
|
|
loop: asyncio.AbstractEventLoop,
|
|
include_provenance: bool = False
|
|
) -> Iterator[StreamingChunk]:
|
|
"""Generator that yields streaming chunks."""
|
|
async_gen = self._send_request_async_streaming(service, flow, request, include_provenance)
|
|
|
|
try:
|
|
while True:
|
|
try:
|
|
chunk = loop.run_until_complete(async_gen.__anext__())
|
|
yield chunk
|
|
except StopAsyncIteration:
|
|
break
|
|
finally:
|
|
try:
|
|
loop.run_until_complete(async_gen.aclose())
|
|
except Exception:
|
|
pass
|
|
|
|
def _streaming_generator_raw(
|
|
self,
|
|
service: str,
|
|
flow: Optional[str],
|
|
request: Dict[str, Any],
|
|
loop: asyncio.AbstractEventLoop
|
|
) -> Iterator[Dict[str, Any]]:
|
|
"""Generator that yields raw response dicts."""
|
|
async_gen = self._send_request_async_streaming_raw(service, flow, request)
|
|
|
|
try:
|
|
while True:
|
|
try:
|
|
data = loop.run_until_complete(async_gen.__anext__())
|
|
yield data
|
|
except StopAsyncIteration:
|
|
break
|
|
finally:
|
|
try:
|
|
loop.run_until_complete(async_gen.aclose())
|
|
except Exception:
|
|
pass
|
|
|
|
async def _send_request_async_streaming_raw(
|
|
self,
|
|
service: str,
|
|
flow: Optional[str],
|
|
request: Dict[str, Any]
|
|
) -> Iterator[Dict[str, Any]]:
|
|
"""Async streaming that yields raw response dicts."""
|
|
await self._ensure_connected()
|
|
|
|
request_id = self._next_request_id()
|
|
queue = asyncio.Queue()
|
|
self._pending[request_id] = queue
|
|
|
|
try:
|
|
message = {
|
|
"id": request_id,
|
|
"workspace": self.workspace,
|
|
"service": service,
|
|
"request": request
|
|
}
|
|
if flow:
|
|
message["flow"] = flow
|
|
|
|
await self._socket.send(json.dumps(message))
|
|
|
|
while True:
|
|
response = await queue.get()
|
|
|
|
if "error" in response:
|
|
raise_from_error_dict(response["error"])
|
|
|
|
if "response" in response:
|
|
yield response["response"]
|
|
|
|
if response.get("complete"):
|
|
break
|
|
|
|
finally:
|
|
self._pending.pop(request_id, None)
|
|
|
|
async def _send_request_async(
|
|
self,
|
|
service: str,
|
|
flow: Optional[str],
|
|
request: Dict[str, Any]
|
|
) -> Dict[str, Any]:
|
|
"""Async non-streaming request over persistent connection."""
|
|
await self._ensure_connected()
|
|
|
|
request_id = self._next_request_id()
|
|
queue = asyncio.Queue()
|
|
self._pending[request_id] = queue
|
|
|
|
try:
|
|
message = {
|
|
"id": request_id,
|
|
"workspace": self.workspace,
|
|
"service": service,
|
|
"request": request
|
|
}
|
|
if flow:
|
|
message["flow"] = flow
|
|
|
|
await self._socket.send(json.dumps(message))
|
|
|
|
response = await queue.get()
|
|
|
|
if "error" in response:
|
|
raise_from_error_dict(response["error"])
|
|
|
|
if "response" not in response:
|
|
raise ProtocolException("Missing response in message")
|
|
|
|
return response["response"]
|
|
|
|
finally:
|
|
self._pending.pop(request_id, None)
|
|
|
|
async def _send_request_async_streaming(
|
|
self,
|
|
service: str,
|
|
flow: Optional[str],
|
|
request: Dict[str, Any],
|
|
include_provenance: bool = False
|
|
) -> Iterator[StreamingChunk]:
|
|
"""Async streaming request over persistent connection."""
|
|
await self._ensure_connected()
|
|
|
|
request_id = self._next_request_id()
|
|
queue = asyncio.Queue()
|
|
self._pending[request_id] = queue
|
|
|
|
try:
|
|
message = {
|
|
"id": request_id,
|
|
"workspace": self.workspace,
|
|
"service": service,
|
|
"request": request
|
|
}
|
|
if flow:
|
|
message["flow"] = flow
|
|
|
|
await self._socket.send(json.dumps(message))
|
|
|
|
while True:
|
|
response = await queue.get()
|
|
|
|
if "error" in response:
|
|
raise_from_error_dict(response["error"])
|
|
|
|
if "response" in response:
|
|
resp = response["response"]
|
|
|
|
if "error" in resp:
|
|
raise_from_error_dict(resp["error"])
|
|
|
|
chunk = self._parse_chunk(resp, include_provenance=include_provenance)
|
|
if chunk is not None:
|
|
yield chunk
|
|
|
|
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
|
|
break
|
|
|
|
finally:
|
|
self._pending.pop(request_id, None)
|
|
|
|
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
|
|
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
|
message_type = resp.get("message_type")
|
|
|
|
if message_type == "explain":
|
|
if include_provenance:
|
|
return self._build_provenance_event(resp)
|
|
return None
|
|
|
|
if message_type == "thought":
|
|
return AgentThought(
|
|
content=resp.get("content", ""),
|
|
end_of_message=resp.get("end_of_message", False),
|
|
message_id=resp.get("message_id", ""),
|
|
)
|
|
elif message_type == "observation":
|
|
return AgentObservation(
|
|
content=resp.get("content", ""),
|
|
end_of_message=resp.get("end_of_message", False),
|
|
message_id=resp.get("message_id", ""),
|
|
)
|
|
elif message_type == "answer" or message_type == "final-answer":
|
|
return AgentAnswer(
|
|
content=resp.get("content", ""),
|
|
end_of_message=resp.get("end_of_message", False),
|
|
end_of_dialog=resp.get("end_of_dialog", False),
|
|
message_id=resp.get("message_id", ""),
|
|
in_token=resp.get("in_token"),
|
|
out_token=resp.get("out_token"),
|
|
model=resp.get("model"),
|
|
)
|
|
elif message_type == "action":
|
|
return AgentThought(
|
|
content=resp.get("content", ""),
|
|
end_of_message=resp.get("end_of_message", False)
|
|
)
|
|
else:
|
|
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
|
return RAGChunk(
|
|
content=content,
|
|
end_of_stream=resp.get("end_of_stream", False),
|
|
error=None,
|
|
in_token=resp.get("in_token"),
|
|
out_token=resp.get("out_token"),
|
|
model=resp.get("model"),
|
|
)
|
|
|
|
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
|
|
"""Build a ProvenanceEvent from a response dict, parsing inline triples
|
|
into an ExplainEntity if available."""
|
|
explain_id = resp.get("explain_id", "")
|
|
explain_graph = resp.get("explain_graph", "")
|
|
raw_triples = resp.get("explain_triples", [])
|
|
|
|
entity = None
|
|
if raw_triples:
|
|
try:
|
|
from .explainability import ExplainEntity
|
|
# Convert wire-format triple dicts to (s, p, o) tuples
|
|
parsed = []
|
|
for t in raw_triples:
|
|
s = t.get("s", {}).get("i", "") if t.get("s") else ""
|
|
p = t.get("p", {}).get("i", "") if t.get("p") else ""
|
|
o_term = t.get("o", {})
|
|
if o_term:
|
|
if o_term.get("t") == "i":
|
|
o = o_term.get("i", "")
|
|
else:
|
|
o = o_term.get("v", "")
|
|
else:
|
|
o = ""
|
|
parsed.append((s, p, o))
|
|
entity = ExplainEntity.from_triples(explain_id, parsed)
|
|
except Exception:
|
|
pass
|
|
|
|
return ProvenanceEvent(
|
|
explain_id=explain_id,
|
|
explain_graph=explain_graph,
|
|
entity=entity,
|
|
triples=raw_triples,
|
|
)
|
|
|
|
def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]:
|
|
request = {
|
|
"operation": "get-kg-core",
|
|
"workspace": self.workspace,
|
|
"id": id,
|
|
}
|
|
for response in self._send_request_sync(
|
|
"knowledge", None, request, streaming_raw=True,
|
|
):
|
|
if response.get("eos"):
|
|
break
|
|
yield response
|
|
|
|
def put_kg_core(
|
|
self, id: str, triples=None, graph_embeddings=None,
|
|
) -> Dict[str, Any]:
|
|
request = {
|
|
"operation": "put-kg-core",
|
|
"workspace": self.workspace,
|
|
"id": id,
|
|
}
|
|
if triples is not None:
|
|
request["triples"] = triples
|
|
if graph_embeddings is not None:
|
|
request["graph-embeddings"] = graph_embeddings
|
|
return self._send_request_sync("knowledge", None, request)
|
|
|
|
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
|
|
request = {
|
|
"operation": "get-de-core",
|
|
"workspace": self.workspace,
|
|
"id": id,
|
|
}
|
|
for response in self._send_request_sync(
|
|
"knowledge", None, request, streaming_raw=True,
|
|
):
|
|
if response.get("eos"):
|
|
break
|
|
yield response
|
|
|
|
def put_de_core(
|
|
self, id: str, document_embeddings=None,
|
|
) -> Dict[str, Any]:
|
|
request = {
|
|
"operation": "put-de-core",
|
|
"workspace": self.workspace,
|
|
"id": id,
|
|
}
|
|
if document_embeddings is not None:
|
|
request["document-embeddings"] = document_embeddings
|
|
return self._send_request_sync("knowledge", None, request)
|
|
|
|
def close(self) -> None:
|
|
"""Close the persistent WebSocket connection."""
|
|
if self._loop and not self._loop.is_closed():
|
|
try:
|
|
self._loop.run_until_complete(self._close_async())
|
|
except Exception:
|
|
pass
|
|
|
|
async def _close_async(self):
|
|
# Cancel reader task
|
|
if self._reader_task:
|
|
self._reader_task.cancel()
|
|
try:
|
|
await self._reader_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
self._reader_task = None
|
|
|
|
# Exit the websockets context manager — this cleanly shuts down
|
|
# the connection and its keepalive task
|
|
if self._connect_cm:
|
|
try:
|
|
await self._connect_cm.__aexit__(None, None, None)
|
|
except Exception:
|
|
pass
|
|
self._connect_cm = None
|
|
|
|
self._socket = None
|
|
self._connected = False
|
|
self._pending.clear()
|
|
|
|
|
|
class SocketFlowInstance:
|
|
"""
|
|
Synchronous WebSocket flow instance for streaming operations.
|
|
|
|
Provides the same interface as REST FlowInstance but with WebSocket-based
|
|
streaming support for real-time responses.
|
|
"""
|
|
|
|
def __init__(self, client: SocketClient, flow_id: str) -> None:
|
|
self.client: SocketClient = client
|
|
self.flow_id: str = flow_id
|
|
|
|
def agent(
|
|
self,
|
|
question: str,
|
|
state: Optional[Dict[str, Any]] = None,
|
|
group: Optional[str] = None,
|
|
history: Optional[List[Dict[str, Any]]] = None,
|
|
streaming: bool = False,
|
|
**kwargs: Any
|
|
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
|
|
"""Execute an agent operation with streaming support."""
|
|
request = {
|
|
"question": question,
|
|
"streaming": streaming
|
|
}
|
|
if state is not None:
|
|
request["state"] = state
|
|
if group is not None:
|
|
request["group"] = group
|
|
if history is not None:
|
|
request["history"] = history
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("agent", self.flow_id, request, streaming=True)
|
|
|
|
def agent_explain(
|
|
self,
|
|
question: str,
|
|
collection: str,
|
|
state: Optional[Dict[str, Any]] = None,
|
|
group: Optional[str] = None,
|
|
history: Optional[List[Dict[str, Any]]] = None,
|
|
**kwargs: Any
|
|
) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]:
|
|
"""Execute an agent operation with explainability support."""
|
|
request = {
|
|
"question": question,
|
|
"collection": collection,
|
|
"streaming": True
|
|
}
|
|
if state is not None:
|
|
request["state"] = state
|
|
if group is not None:
|
|
request["group"] = group
|
|
if history is not None:
|
|
request["history"] = history
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync(
|
|
"agent", self.flow_id, request,
|
|
streaming=True, include_provenance=True
|
|
)
|
|
|
|
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
|
"""Execute text completion with optional streaming.
|
|
|
|
Non-streaming: returns a TextCompletionResult with text and token counts.
|
|
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
|
"""
|
|
request = {
|
|
"system": system,
|
|
"prompt": prompt,
|
|
"streaming": streaming
|
|
}
|
|
request.update(kwargs)
|
|
|
|
result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming)
|
|
|
|
if streaming:
|
|
return self._text_completion_generator(result)
|
|
else:
|
|
return TextCompletionResult(
|
|
text=result.get("response", ""),
|
|
in_token=result.get("in_token"),
|
|
out_token=result.get("out_token"),
|
|
model=result.get("model"),
|
|
)
|
|
|
|
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
|
|
for chunk in result:
|
|
if isinstance(chunk, RAGChunk):
|
|
yield chunk
|
|
|
|
def graph_rag(
|
|
self,
|
|
query: str,
|
|
collection: str,
|
|
entity_limit: int = 50,
|
|
triple_limit: int = 30,
|
|
max_subgraph_size: int = 1000,
|
|
max_path_length: int = 2,
|
|
edge_score_limit: int = 30,
|
|
edge_limit: int = 25,
|
|
streaming: bool = False,
|
|
**kwargs: Any
|
|
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
|
"""Execute graph-based RAG query with optional streaming.
|
|
|
|
Non-streaming: returns a TextCompletionResult with text and token counts.
|
|
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
|
"""
|
|
request = {
|
|
"query": query,
|
|
"collection": collection,
|
|
"entity-limit": entity_limit,
|
|
"triple-limit": triple_limit,
|
|
"max-subgraph-size": max_subgraph_size,
|
|
"max-path-length": max_path_length,
|
|
"edge-score-limit": edge_score_limit,
|
|
"edge-limit": edge_limit,
|
|
"streaming": streaming
|
|
}
|
|
request.update(kwargs)
|
|
|
|
result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming)
|
|
|
|
if streaming:
|
|
return self._rag_generator(result)
|
|
else:
|
|
return TextCompletionResult(
|
|
text=result.get("response", ""),
|
|
in_token=result.get("in_token"),
|
|
out_token=result.get("out_token"),
|
|
model=result.get("model"),
|
|
)
|
|
|
|
def graph_rag_explain(
|
|
self,
|
|
query: str,
|
|
collection: str,
|
|
entity_limit: int = 50,
|
|
triple_limit: int = 30,
|
|
max_subgraph_size: int = 1000,
|
|
max_path_length: int = 2,
|
|
edge_score_limit: int = 30,
|
|
edge_limit: int = 25,
|
|
**kwargs: Any
|
|
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
|
"""Execute graph-based RAG query with explainability support."""
|
|
request = {
|
|
"query": query,
|
|
"collection": collection,
|
|
"entity-limit": entity_limit,
|
|
"triple-limit": triple_limit,
|
|
"max-subgraph-size": max_subgraph_size,
|
|
"max-path-length": max_path_length,
|
|
"edge-score-limit": edge_score_limit,
|
|
"edge-limit": edge_limit,
|
|
"streaming": True,
|
|
"explainable": True,
|
|
}
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync(
|
|
"graph-rag", self.flow_id, request,
|
|
streaming=True, include_provenance=True
|
|
)
|
|
|
|
def document_rag(
|
|
self,
|
|
query: str,
|
|
collection: str,
|
|
doc_limit: int = 10,
|
|
streaming: bool = False,
|
|
**kwargs: Any
|
|
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
|
"""Execute document-based RAG query with optional streaming.
|
|
|
|
Non-streaming: returns a TextCompletionResult with text and token counts.
|
|
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
|
"""
|
|
request = {
|
|
"query": query,
|
|
"collection": collection,
|
|
"doc-limit": doc_limit,
|
|
"streaming": streaming
|
|
}
|
|
request.update(kwargs)
|
|
|
|
result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming)
|
|
|
|
if streaming:
|
|
return self._rag_generator(result)
|
|
else:
|
|
return TextCompletionResult(
|
|
text=result.get("response", ""),
|
|
in_token=result.get("in_token"),
|
|
out_token=result.get("out_token"),
|
|
model=result.get("model"),
|
|
)
|
|
|
|
def document_rag_explain(
|
|
self,
|
|
query: str,
|
|
collection: str,
|
|
doc_limit: int = 10,
|
|
**kwargs: Any
|
|
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
|
"""Execute document-based RAG query with explainability support."""
|
|
request = {
|
|
"query": query,
|
|
"collection": collection,
|
|
"doc-limit": doc_limit,
|
|
"streaming": True,
|
|
"explainable": True,
|
|
}
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync(
|
|
"document-rag", self.flow_id, request,
|
|
streaming=True, include_provenance=True
|
|
)
|
|
|
|
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
|
|
for chunk in result:
|
|
if isinstance(chunk, RAGChunk):
|
|
yield chunk
|
|
|
|
def prompt(
|
|
self,
|
|
id: str,
|
|
variables: Dict[str, str],
|
|
streaming: bool = False,
|
|
**kwargs: Any
|
|
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
|
"""Execute a prompt template with optional streaming.
|
|
|
|
Non-streaming: returns a TextCompletionResult with text and token counts.
|
|
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
|
"""
|
|
request = {
|
|
"id": id,
|
|
"variables": variables,
|
|
"streaming": streaming
|
|
}
|
|
request.update(kwargs)
|
|
|
|
result = self.client._send_request_sync("prompt", self.flow_id, request, streaming)
|
|
|
|
if streaming:
|
|
return self._rag_generator(result)
|
|
else:
|
|
return TextCompletionResult(
|
|
text=result.get("text", result.get("response", "")),
|
|
in_token=result.get("in_token"),
|
|
out_token=result.get("out_token"),
|
|
model=result.get("model"),
|
|
)
|
|
|
|
def graph_embeddings_query(
|
|
self,
|
|
text: str,
|
|
collection: str,
|
|
limit: int = 10,
|
|
**kwargs: Any
|
|
) -> Dict[str, Any]:
|
|
"""Query knowledge graph entities using semantic similarity."""
|
|
emb_result = self.embeddings(texts=[text])
|
|
vector = emb_result.get("vectors", [[]])[0]
|
|
|
|
request = {
|
|
"vector": vector,
|
|
"collection": collection,
|
|
"limit": limit
|
|
}
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False)
|
|
|
|
def document_embeddings_query(
|
|
self,
|
|
text: str,
|
|
collection: str,
|
|
limit: int = 10,
|
|
**kwargs: Any
|
|
) -> Dict[str, Any]:
|
|
"""Query document chunks using semantic similarity."""
|
|
emb_result = self.embeddings(texts=[text])
|
|
vector = emb_result.get("vectors", [[]])[0]
|
|
|
|
request = {
|
|
"vector": vector,
|
|
"collection": collection,
|
|
"limit": limit
|
|
}
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("document-embeddings", self.flow_id, request, False)
|
|
|
|
def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]:
|
|
"""Generate vector embeddings for one or more texts."""
|
|
request = {"texts": texts}
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
|
|
|
|
def triples_query(
|
|
self,
|
|
s: Optional[Union[str, Dict[str, Any]]] = None,
|
|
p: Optional[Union[str, Dict[str, Any]]] = None,
|
|
o: Optional[Union[str, Dict[str, Any]]] = None,
|
|
g: Optional[str] = None,
|
|
collection: Optional[str] = None,
|
|
limit: int = 100,
|
|
**kwargs: Any
|
|
) -> List[Dict[str, Any]]:
|
|
"""Query knowledge graph triples using pattern matching."""
|
|
request = {"limit": limit}
|
|
|
|
s_term = build_term(s)
|
|
p_term = build_term(p)
|
|
o_term = build_term(o)
|
|
|
|
if s_term is not None:
|
|
request["s"] = s_term
|
|
if p_term is not None:
|
|
request["p"] = p_term
|
|
if o_term is not None:
|
|
request["o"] = o_term
|
|
if g is not None:
|
|
request["g"] = g
|
|
if collection is not None:
|
|
request["collection"] = collection
|
|
request.update(kwargs)
|
|
|
|
result = self.client._send_request_sync("triples", self.flow_id, request, False)
|
|
if isinstance(result, dict) and "response" in result:
|
|
return result["response"]
|
|
return result
|
|
|
|
def triples_query_stream(
|
|
self,
|
|
s: Optional[Union[str, Dict[str, Any]]] = None,
|
|
p: Optional[Union[str, Dict[str, Any]]] = None,
|
|
o: Optional[Union[str, Dict[str, Any]]] = None,
|
|
g: Optional[str] = None,
|
|
collection: Optional[str] = None,
|
|
limit: int = 100,
|
|
batch_size: int = 20,
|
|
**kwargs: Any
|
|
) -> Iterator[List[Dict[str, Any]]]:
|
|
"""Query knowledge graph triples with streaming batches."""
|
|
request = {
|
|
"limit": limit,
|
|
"streaming": True,
|
|
"batch-size": batch_size,
|
|
}
|
|
|
|
s_term = build_term(s)
|
|
p_term = build_term(p)
|
|
o_term = build_term(o)
|
|
|
|
if s_term is not None:
|
|
request["s"] = s_term
|
|
if p_term is not None:
|
|
request["p"] = p_term
|
|
if o_term is not None:
|
|
request["o"] = o_term
|
|
if g is not None:
|
|
request["g"] = g
|
|
if collection is not None:
|
|
request["collection"] = collection
|
|
request.update(kwargs)
|
|
|
|
for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True):
|
|
if isinstance(response, dict) and "response" in response:
|
|
yield response["response"]
|
|
else:
|
|
yield response
|
|
|
|
def sparql_query_stream(
|
|
self,
|
|
query: str,
|
|
collection: str = "default",
|
|
limit: int = 10000,
|
|
batch_size: int = 20,
|
|
**kwargs: Any
|
|
) -> Iterator[Dict[str, Any]]:
|
|
"""Execute a SPARQL query with streaming batches."""
|
|
request = {
|
|
"query": query,
|
|
"collection": collection,
|
|
"limit": limit,
|
|
"streaming": True,
|
|
"batch-size": batch_size,
|
|
}
|
|
request.update(kwargs)
|
|
|
|
for response in self.client._send_request_sync(
|
|
"sparql", self.flow_id, request, streaming_raw=True
|
|
):
|
|
yield response
|
|
|
|
def rows_query(
|
|
self,
|
|
query: str,
|
|
collection: str,
|
|
variables: Optional[Dict[str, Any]] = None,
|
|
operation_name: Optional[str] = None,
|
|
**kwargs: Any
|
|
) -> Dict[str, Any]:
|
|
"""Execute a GraphQL query against structured rows."""
|
|
request = {
|
|
"query": query,
|
|
"collection": collection
|
|
}
|
|
if variables:
|
|
request["variables"] = variables
|
|
if operation_name:
|
|
request["operationName"] = operation_name
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("rows", self.flow_id, request, False)
|
|
|
|
def mcp_tool(
|
|
self,
|
|
name: str,
|
|
parameters: Dict[str, Any],
|
|
**kwargs: Any
|
|
) -> Dict[str, Any]:
|
|
"""Execute a Model Context Protocol (MCP) tool."""
|
|
request = {
|
|
"name": name,
|
|
"parameters": parameters
|
|
}
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
|
|
|
def row_embeddings_query(
|
|
self,
|
|
text: str,
|
|
schema_name: str,
|
|
collection: str = "default",
|
|
index_name: Optional[str] = None,
|
|
limit: int = 10,
|
|
**kwargs: Any
|
|
) -> Dict[str, Any]:
|
|
"""Query row data using semantic similarity on indexed fields."""
|
|
emb_result = self.embeddings(texts=[text])
|
|
vector = emb_result.get("vectors", [[]])[0]
|
|
|
|
request = {
|
|
"vector": vector,
|
|
"schema_name": schema_name,
|
|
"collection": collection,
|
|
"limit": limit
|
|
}
|
|
if index_name:
|
|
request["index_name"] = index_name
|
|
request.update(kwargs)
|
|
|
|
return self.client._send_request_sync("row-embeddings", self.flow_id, request, False)
|