mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
Prompt refactor (#125)
* Prompt manager integrated and working with 6 tests * Updated templates to for prompt-template update
This commit is contained in:
parent
51aef6c730
commit
1e137768ca
19 changed files with 649 additions and 479 deletions
|
|
@ -1,7 +1,9 @@
|
|||
|
||||
import _pulsar
|
||||
import json
|
||||
import dataclasses
|
||||
|
||||
from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
from .. schema import prompt_request_queue
|
||||
from .. schema import prompt_response_queue
|
||||
from . base import BaseClient
|
||||
|
|
@ -12,6 +14,23 @@ WARN=_pulsar.LoggerLevel.Warn
|
|||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Definition:
|
||||
name: str
|
||||
definition: str
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Relationship:
|
||||
s: str
|
||||
p: str
|
||||
o: str
|
||||
o_entity: str
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Topic:
|
||||
topic: str
|
||||
definition: str
|
||||
|
||||
class PromptClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
|
|
@ -38,63 +57,116 @@ class PromptClient(BaseClient):
|
|||
output_schema=PromptResponse,
|
||||
)
|
||||
|
||||
def request(self, id, terms, timeout=300):
|
||||
|
||||
resp = self.call(
|
||||
id=id,
|
||||
terms={
|
||||
k: json.dumps(v)
|
||||
for k, v in terms.items()
|
||||
},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.text: return resp.text
|
||||
|
||||
return json.loads(resp.object)
|
||||
|
||||
def request_definitions(self, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-definitions", chunk=chunk,
|
||||
defs = self.request(
|
||||
id="extract-definitions",
|
||||
terms={
|
||||
"text": chunk
|
||||
},
|
||||
timeout=timeout
|
||||
).definitions
|
||||
|
||||
def request_topics(self, chunk, timeout=300):
|
||||
)
|
||||
|
||||
return self.call(
|
||||
kind="extract-topics", chunk=chunk,
|
||||
timeout=timeout
|
||||
).topics
|
||||
return [
|
||||
Definition(name=d["entity"], definition=d["definition"])
|
||||
for d in defs
|
||||
]
|
||||
|
||||
def request_relationships(self, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-relationships", chunk=chunk,
|
||||
rels = self.request(
|
||||
id="extract-relationships",
|
||||
terms={
|
||||
"text": chunk
|
||||
},
|
||||
timeout=timeout
|
||||
).relationships
|
||||
)
|
||||
|
||||
return [
|
||||
Relationship(
|
||||
s=d["subject"],
|
||||
p=d["predicate"],
|
||||
o=d["object"],
|
||||
o_entity=d["object-entity"]
|
||||
)
|
||||
for d in rels
|
||||
]
|
||||
|
||||
def request_topics(self, chunk, timeout=300):
|
||||
|
||||
topics = self.request(
|
||||
id="extract-topics",
|
||||
terms={
|
||||
"text": chunk
|
||||
},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return [
|
||||
Topic(topic=d["topic"], definition=d["definition"])
|
||||
for d in topics
|
||||
]
|
||||
|
||||
def request_rows(self, schema, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-rows", chunk=chunk,
|
||||
row_schema=RowSchema(
|
||||
name=schema.name,
|
||||
description=schema.description,
|
||||
fields=[
|
||||
Field(
|
||||
name=f.name, type=str(f.type), size=f.size,
|
||||
primary=f.primary, description=f.description,
|
||||
)
|
||||
for f in schema.fields
|
||||
]
|
||||
),
|
||||
return self.request(
|
||||
id="extract-rows",
|
||||
terms={
|
||||
"chunk": chunk,
|
||||
"row-schema": {
|
||||
"name": schema.name,
|
||||
"description": schema.description,
|
||||
"fields": [
|
||||
{
|
||||
"name": f.name, "type": str(f.type),
|
||||
"size": f.size, "primary": f.primary,
|
||||
"description": f.description,
|
||||
}
|
||||
for f in schema.fields
|
||||
]
|
||||
}
|
||||
},
|
||||
timeout=timeout
|
||||
).rows
|
||||
)
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="kg-prompt",
|
||||
query=query,
|
||||
kg=[
|
||||
Fact(s=v[0], p=v[1], o=v[2])
|
||||
for v in kg
|
||||
],
|
||||
return self.request(
|
||||
id="kg-prompt",
|
||||
terms={
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout=timeout
|
||||
).answer
|
||||
)
|
||||
|
||||
def request_document_prompt(self, query, documents, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="document-prompt",
|
||||
query=query,
|
||||
documents=documents,
|
||||
return self.request(
|
||||
id="document-prompt",
|
||||
terms={
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
timeout=timeout
|
||||
).answer
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,20 +39,21 @@ class Fact(Record):
|
|||
# schema, chunk -> rows
|
||||
|
||||
class PromptRequest(Record):
|
||||
kind = String()
|
||||
chunk = String()
|
||||
query = String()
|
||||
kg = Array(Fact())
|
||||
documents = Array(Bytes())
|
||||
row_schema = RowSchema()
|
||||
id = String()
|
||||
|
||||
# JSON encoded values
|
||||
terms = Map(String())
|
||||
|
||||
class PromptResponse(Record):
|
||||
|
||||
# Error case
|
||||
error = Error()
|
||||
answer = String()
|
||||
definitions = Array(Definition())
|
||||
topics = Array(Topic())
|
||||
relationships = Array(Relationship())
|
||||
rows = Array(Map(String()))
|
||||
|
||||
# Just plain text
|
||||
text = String()
|
||||
|
||||
# JSON encoded
|
||||
object = String()
|
||||
|
||||
prompt_request_queue = topic(
|
||||
'prompt', kind='non-persistent', namespace='request'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue