mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-05 13:22:37 +02:00
Feature/pkgsplit (#83)
* Starting to spawn base package * More package hacking * Bedrock and VertexAI * Parquet split * Updated templates * Utils
This commit is contained in:
parent
3fb75c617b
commit
9b91d5eee3
262 changed files with 630 additions and 420 deletions
0
trustgraph-flow/trustgraph/model/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/prompt/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/prompt/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/prompt/generic/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/prompt/generic/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
176
trustgraph-flow/trustgraph/model/prompt/generic/prompts.py
Normal file
176
trustgraph-flow/trustgraph/model/prompt/generic/prompts.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
|
||||
def to_relationships(text):
    """Build the relationship-extraction prompt for a text chunk.

    The returned prompt instructs the LLM to emit a JSON array of
    {subject, predicate, object, object-entity} records.
    """

    return f"""You are a helpful assistant that performs information extraction tasks for a provided text.

Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.

Information Network Rules:
- An information network has subjects connected by predicates to objects.
- A subject is a named-entity or a conceptual topic.
- One subject can have many predicates and objects.
- An object is a property or attribute of a subject.
- A subject can be connected by a predicate to another subject.

Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.

Here is the text:
{text}

Response Instructions:
- Obey the information network rules.
- Do not return special characters.
- Respond only with well-formed JSON.
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
- The JSON response shall use the following structure:

```json
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
```

- The key "object-entity" is TRUE only if the "object" is a subject.
- Do not write any additional text or explanations.
"""
||||
def to_topics(text):
    """Build the topic-extraction prompt for a text chunk.

    The returned prompt instructs the LLM to emit a JSON array of
    {topic, definition} records.
    """

    return f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.

Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.

Here is the text:
{text}

Response Instructions:
- Do not respond with special characters.
- Return only topics that are concepts and unique to the provided text.
- Respond only with well-formed JSON.
- The JSON response shall be an array of objects with keys "topic" and "definition".
- The JSON response shall use the following structure:

```json
[{{"topic": string, "definition": string}}]
```

- Do not write any additional text or explanations.
"""
||||
def to_definitions(text):
    """Build the entity-definition extraction prompt for a text chunk.

    The returned prompt instructs the LLM to emit a JSON array of
    {entity, definition} records for named entities.

    Fix: corrected the typo "commodotities" -> "commodities" in the
    prompt text sent to the LLM.
    """

    prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.

Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.

Here is the text:
{text}

Response Instructions:
- Do not respond with special characters.
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodities, or substances.
- Respond only with well-formed JSON.
- The JSON response shall be an array of objects with keys "entity" and "definition".
- The JSON response shall use the following structure:

```json
[{{"entity": string, "definition": string}}]
```

- Do not write any additional text or explanations.
"""

    return prompt
||||
def to_rows(schema, text):
    """Build the row-extraction prompt for `schema` applied to `text`.

    `schema` must expose .name, .description and .fields, where each
    field exposes .name, .type and .description.
    """

    field_lines = "\n".join(
        f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
        for f in schema.fields
    )

    schema_block = f"""Object name: {schema.name}
Description: {schema.description}

Fields:
{field_lines}"""

    return f"""<instructions>
Study the following text and derive objects which match the schema provided.

You must output an array of JSON objects for each object you discover
which matches the schema. For each object, output a JSON object whose fields
carry the name field specified in the schema.
</instructions>

<schema>
{schema_block}
</schema>

<text>
{text}
</text>

<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not add markdown formatting or headers or prefixes.
</requirements>"""
||||
def get_cypher(kg):
    """Render an iterable of triples as Cypher-style edge statements.

    Each triple's .s/.p/.o become "(s)-[p]->(o)", one per line.
    Backslashes in the rendered output are replaced with dashes.

    Fix: removed stray debug print() calls, for consistency with the
    template variant of this helper which has none.
    """
    sg2 = []
    for f in kg:
        sg2.append(f"({f.s})-[{f.p}]->({f.o})")
    kg = "\n".join(sg2)
    kg = kg.replace("\\", "-")
    return kg
||||
def to_kg_query(query, kg):
    """Build a prompt answering `query` strictly from graph knowledge.

    The knowledge triples are rendered to Cypher-style statements via
    get_cypher before being embedded in the prompt.
    """

    statements = get_cypher(kg)

    return f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.

Here's the knowledge statements:
{statements}

Use only the provided knowledge statements to respond to the following:
{query}
"""
||||
def to_document_query(query, documents):
    """Build a prompt answering `query` strictly from the given documents.

    `documents` is an iterable of strings, joined with blank lines.
    """

    joined = "\n\n".join(documents)

    return f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.

Here is the context:
{joined}

Use only the provided knowledge statements to respond to the following:
{query}
"""
473
trustgraph-flow/trustgraph/model/prompt/generic/service.py
Executable file
473
trustgraph-flow/trustgraph/model/prompt/generic/service.py
Executable file
|
|
@ -0,0 +1,473 @@
|
|||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_topics
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
    """Prompt service built on fixed (hard-coded) prompt builders.

    Consumes PromptRequest messages, renders a prompt for the request
    kind, calls the LLM via the text-completion queues, parses the
    answer, and emits a PromptResponse (or an error response).
    """

    def __init__(self, **params):

        # Queue/subscriber configuration, falling back to module defaults.
        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        tc_request_queue = params.get(
            "text_completion_request_queue", text_completion_request_queue
        )
        tc_response_queue = params.get(
            "text_completion_response_queue", text_completion_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": PromptRequest,
                "output_schema": PromptResponse,
                "text_completion_request_queue": tc_request_queue,
                "text_completion_response_queue": tc_response_queue,
            }
        )

        # Client used to send rendered prompts to the text-completion
        # service; pulsar_host is provided by the ConsumerProducer base.
        self.llm = LlmClient(
            subscriber=subscriber,
            input_queue=tc_request_queue,
            output_queue=tc_response_queue,
            pulsar_host = self.pulsar_host
        )

    def parse_json(self, text):
        """Extract and parse JSON from an LLM reply.

        Accepts either a ```json ... ``` fenced block or raw JSON;
        propagates json.loads errors on malformed input.
        """
        json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)

        if json_match:
            json_str = json_match.group(1).strip()
        else:
            # If no delimiters, assume the entire output is JSON
            json_str = text.strip()

        return json.loads(json_str)

    def handle(self, msg):
        """Dispatch an incoming PromptRequest to the matching handler."""

        v = msg.value()

        # Sender-produced ID, echoed back on the response so the caller
        # can correlate request and response.
        id = msg.properties()["id"]

        kind = v.kind

        print(f"Handling kind {kind}...", flush=True)

        if kind == "extract-definitions":
            self.handle_extract_definitions(id, v)
            return
        elif kind == "extract-topics":
            self.handle_extract_topics(id, v)
            return
        elif kind == "extract-relationships":
            self.handle_extract_relationships(id, v)
            return
        elif kind == "extract-rows":
            self.handle_extract_rows(id, v)
            return
        elif kind == "kg-prompt":
            self.handle_kg_prompt(id, v)
            return
        elif kind == "document-prompt":
            self.handle_document_prompt(id, v)
            return
        else:
            # Unknown kind: dropped without a response.
            print("Invalid kind.", flush=True)
            return

    def handle_extract_definitions(self, id, v):
        """Extract {entity, definition} records from v.chunk via the LLM."""

        try:

            prompt = to_definitions(v.chunk)

            ans = self.llm.request(prompt)

            # Silently ignore JSON parse error
            # NOTE(review): bare except also swallows KeyboardInterrupt.
            try:
                defs = self.parse_json(ans)
            except:
                print("JSON parse error, ignored", flush=True)
                defs = []

            output = []

            for defn in defs:
                try:
                    e = defn["entity"]
                    d = defn["definition"]

                    # Skip records with empty or missing values.
                    if e == "": continue
                    if e is None: continue
                    if d == "": continue
                    if d is None: continue

                    output.append(
                        Definition(
                            name=e, definition=d
                        )
                    )
                except:
                    print("definition fields missing, ignored", flush=True)

            print("Send response...", flush=True)
            r = PromptResponse(definitions=output, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = PromptResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
            )

            self.producer.send(r, properties={"id": id})

    def handle_extract_topics(self, id, v):
        """Extract {topic, definition} records from v.chunk via the LLM."""

        try:

            prompt = to_topics(v.chunk)

            ans = self.llm.request(prompt)

            # Silently ignore JSON parse error
            try:
                defs = self.parse_json(ans)
            except:
                print("JSON parse error, ignored", flush=True)
                defs = []

            output = []

            for defn in defs:
                try:
                    e = defn["topic"]
                    d = defn["definition"]

                    # Skip records with empty or missing values.
                    if e == "": continue
                    if e is None: continue
                    if d == "": continue
                    if d is None: continue

                    output.append(
                        Topic(
                            name=e, definition=d
                        )
                    )
                except:
                    print("definition fields missing, ignored", flush=True)

            print("Send response...", flush=True)
            r = PromptResponse(topics=output, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = PromptResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
            )

            self.producer.send(r, properties={"id": id})

    def handle_extract_relationships(self, id, v):
        """Extract subject/predicate/object relationships from v.chunk."""

        try:

            prompt = to_relationships(v.chunk)

            ans = self.llm.request(prompt)

            # Silently ignore JSON parse error
            try:
                defs = self.parse_json(ans)
            except:
                print("JSON parse error, ignored", flush=True)
                defs = []

            output = []

            for defn in defs:

                try:

                    s = defn["subject"]
                    p = defn["predicate"]
                    o = defn["object"]
                    o_entity = defn["object-entity"]

                    # Skip records with empty or missing core values.
                    if s == "": continue
                    if s is None: continue

                    if p == "": continue
                    if p is None: continue

                    if o == "": continue
                    if o is None: continue

                    # Missing object-entity flag defaults to False.
                    if o_entity == "" or o_entity is None:
                        o_entity = False

                    output.append(
                        Relationship(
                            s = s,
                            p = p,
                            o = o,
                            o_entity = o_entity,
                        )
                    )

                except Exception as e:
                    print("relationship fields missing, ignored", flush=True)

            print("Send response...", flush=True)
            r = PromptResponse(relationships=output, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = PromptResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
            )

            self.producer.send(r, properties={"id": id})

    def handle_extract_rows(self, id, v):
        """Extract schema-conforming row dicts from v.chunk via the LLM."""

        try:

            fields = v.row_schema.fields

            prompt = to_rows(v.row_schema, v.chunk)

            print(prompt)

            ans = self.llm.request(prompt)

            print(ans)

            # Silently ignore JSON parse error
            try:
                objs = self.parse_json(ans)
            except:
                print("JSON parse error, ignored", flush=True)
                objs = []

            output = []

            for obj in objs:

                try:

                    row = {}

                    for f in fields:

                        # Drop any object missing one of the schema fields.
                        if f.name not in obj:
                            print(f"Object ignored, missing field {f.name}")
                            row = {}
                            break

                        row[f.name] = obj[f.name]

                    if row == {}:
                        continue

                    output.append(row)

                except Exception as e:
                    print("row fields missing, ignored", flush=True)

            for row in output:
                print(row)

            print("Send response...", flush=True)
            r = PromptResponse(rows=output, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = PromptResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
            )

            self.producer.send(r, properties={"id": id})

    def handle_kg_prompt(self, id, v):
        """Answer v.query from graph knowledge v.kg; reply with raw text."""

        try:

            prompt = to_kg_query(v.query, v.kg)

            print(prompt)

            ans = self.llm.request(prompt)

            print(ans)

            print("Send response...", flush=True)
            r = PromptResponse(answer=ans, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = PromptResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
            )

            self.producer.send(r, properties={"id": id})

    def handle_document_prompt(self, id, v):
        """Answer v.query from document context v.documents; reply raw text."""

        try:

            prompt = to_document_query(v.query, v.documents)

            print("prompt")
            print(prompt)

            print("Call LLM...")

            ans = self.llm.request(prompt)

            print(ans)

            print("Send response...", flush=True)
            r = PromptResponse(answer=ans, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = PromptResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
            )

            self.producer.send(r, properties={"id": id})

    @staticmethod
    def add_args(parser):
        """Register command-line options for this service."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '--text-completion-request-queue',
            default=text_completion_request_queue,
            help=f'Text completion request queue (default: {text_completion_request_queue})',
        )

        parser.add_argument(
            '--text-completion-response-queue',
            default=text_completion_response_queue,
            help=f'Text completion response queue (default: {text_completion_response_queue})',
        )
|
||||
|
||||
def run():
    """Console entry point: start the processor under this module's name."""
    Processor.start(module, __doc__)
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/prompt/template/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/prompt/template/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
47
trustgraph-flow/trustgraph/model/prompt/template/prompts.py
Normal file
47
trustgraph-flow/trustgraph/model/prompt/template/prompts.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
|
||||
def to_relationships(template, text):
    """Render the relationship-extraction template; substitutes {text}."""
    return template.format(text=text)
||||
|
||||
def to_definitions(template, text):
    """Render the entity-definition template; substitutes {text}."""
    return template.format(text=text)
||||
|
||||
def to_topics(template, text):
    """Render the topic-extraction template; substitutes {text}."""
    return template.format(text=text)
||||
|
||||
def to_rows(template, schema, text):
    """Render the row-extraction prompt from `template`.

    Substitutes {schema} with a textual description of the row schema
    (object name, description, and per-field name/type/definition) and
    {text} with the chunk to extract from.

    Fix: the original computed the per-field description list but never
    used it, passed the raw schema object to template.format(), and had
    unreachable dead code after the return (which also substituted the
    schema object into its own Fields section). The intended textual
    schema block is now built and substituted; the dead code is removed.
    """

    field_schema = [
        f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
        for f in schema.fields
    ]

    field_schema = "\n".join(field_schema)

    schema = f"""Object name: {schema.name}
Description: {schema.description}

Fields:
{field_schema}"""

    return template.format(schema=schema, text=text)
|
||||
def get_cypher(kg):
    """Render an iterable of triples as Cypher-style edge statements.

    Each triple's .s/.p/.o become "(s)-[p]->(o)", one per line;
    backslashes in the result are replaced with dashes.
    """
    statements = [f"({t.s})-[{t.p}]->({t.o})" for t in kg]
    rendered = "\n".join(statements)
    return rendered.replace("\\", "-")
|
||||
def to_kg_query(template, query, kg):
    """Render the knowledge-query template.

    Substitutes {graph} with Cypher-style statements built from the
    triples in `kg`, and {query} with the user query.
    """
    statements = get_cypher(kg)
    return template.format(query=query, graph=statements)
|
||||
def to_document_query(template, query, docs):
    """Render the document-query template.

    Substitutes {documents} with the documents joined by blank lines,
    and {query} with the user query.
    """
    joined = "\n\n".join(docs)
    return template.format(query=query, documents=joined)
|
||||
523
trustgraph-flow/trustgraph/model/prompt/template/service.py
Executable file
523
trustgraph-flow/trustgraph/model/prompt/template/service.py
Executable file
|
|
@ -0,0 +1,523 @@
|
|||
|
||||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_rows
|
||||
from . prompts import to_kg_query, to_document_query, to_topics
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
definition_template = params.get("definition_template")
|
||||
relationship_template = params.get("relationship_template")
|
||||
topic_template = params.get("topic_template")
|
||||
rows_template = params.get("rows_template")
|
||||
knowledge_query_template = params.get("knowledge_query_template")
|
||||
document_query_template = params.get("document_query_template")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
self.definition_template = definition_template
|
||||
self.topic_template = topic_template
|
||||
self.relationship_template = relationship_template
|
||||
self.rows_template = rows_template
|
||||
self.knowledge_query_template = knowledge_query_template
|
||||
self.document_query_template = document_query_template
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(self.definition_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(self.topic_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(self.relationship_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
    """Answer a document-grounded query: render the document prompt
    template, invoke the LLM, and publish a PromptResponse keyed on the
    sender-supplied request id."""

    try:

        rendered = to_document_query(
            self.document_query_template, v.query, v.documents
        )
        print(rendered)

        answer_text = self.llm.request(rendered)
        print(answer_text)

        print("Send response...", flush=True)
        reply = PromptResponse(answer=answer_text, error=None)
        self.producer.send(reply, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # NOTE(review): the success path populates `answer=`, but this
        # error path passes `response=None` — confirm PromptResponse
        # actually declares a `response` field.
        reply = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
        )

        self.producer.send(reply, properties={"id": id})
|
||||
|
||||
@staticmethod
def add_args(parser):
    """Register the command-line arguments for the prompt service:
    the base consumer/producer queue options, the text-completion
    request/response queues, and the six mandatory prompt templates."""

    ConsumerProducer.add_args(
        parser, default_input_queue, default_subscriber,
        default_output_queue,
    )

    parser.add_argument(
        '--text-completion-request-queue',
        default=text_completion_request_queue,
        help=f'Text completion request queue (default: {text_completion_request_queue})',
    )

    parser.add_argument(
        '--text-completion-response-queue',
        default=text_completion_response_queue,
        help=f'Text completion response queue (default: {text_completion_response_queue})',
    )

    # All template arguments are required and share the same shape,
    # so register them from a table instead of six repeated calls.
    for flag, description in [
        ('--definition-template', 'Definition extraction template'),
        ('--topic-template', 'Topic extraction template'),
        ('--rows-template', 'Rows extraction template'),
        ('--relationship-template', 'Relationship extraction template'),
        ('--knowledge-query-template', 'Knowledge query template'),
        ('--document-query-template', 'Document query template'),
    ]:
        parser.add_argument(flag, required=True, help=description)
|
||||
|
||||
def run():
    """Module entry point: start the processor service loop."""
    Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/azure/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/azure/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
226
trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
Executable file
226
trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
Executable file
|
|
@ -0,0 +1,226 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using the Azure
|
||||
serverless endpoint service. Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4192
|
||||
default_model = "AzureAI"
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
endpoint = params.get("endpoint")
|
||||
token = params.get("token")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = default_model
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.token = token
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
|
||||
def build_prompt(self, system, content):
    """Serialise a chat-completion request body for the Azure
    serverless endpoint: a system message plus one user message,
    with the configured max_tokens / temperature."""

    payload = {
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": content},
        ],
        "max_tokens": self.max_output,
        "temperature": self.temperature,
        "top_p": 1,
    }

    return json.dumps(payload)
|
||||
|
||||
def call_llm(self, body):
    """POST the serialised request body to the Azure endpoint and
    return the decoded JSON response.

    Raises TooManyRequests on HTTP 429 and RuntimeError on any other
    non-200 status."""

    # The bearer token is the primary/secondary key, AMLToken, or
    # Microsoft Entra ID token for the endpoint.
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {self.token}',
    }

    resp = requests.post(self.endpoint, data=body, headers=headers)

    if resp.status_code == 429:
        raise TooManyRequests()

    if resp.status_code != 200:
        raise RuntimeError("LLM failure")

    return resp.json()
|
||||
|
||||
def handle(self, msg):
    """Process one TextCompletionRequest from the queue: call the Azure
    endpoint, publish the response (or an error / rate-limit response)
    back with the sender-produced correlation id."""

    v = msg.value()

    # Sender-produced ID

    id = msg.properties()["id"]

    print(f"Handling prompt {id}...", flush=True)

    try:

        prompt = self.build_prompt(
            "You are a helpful chatbot",
            v.prompt
        )

        with __class__.text_completion_metric.time():
            response = self.call_llm(prompt)

        resp = response['choices'][0]['message']['content']
        inputtokens = response['usage']['prompt_tokens']
        outputtokens = response['usage']['completion_tokens']

        print(resp, flush=True)
        print(f"Input Tokens: {inputtokens}", flush=True)
        print(f"Output Tokens: {outputtokens}", flush=True)

        print("Send response...", flush=True)

        r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
        self.producer.send(r, properties={"id": id})

    except TooManyRequests as e:

        # BUG FIX: this handler previously caught TooManyRequests
        # without binding `as e` but then referenced str(e), which
        # raised a NameError instead of delivering the response.
        print("Send rate limit response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "rate-limit",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--endpoint',
|
||||
help=f'LLM model endpoint'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--token',
|
||||
help=f'LLM model token'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/claude/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/claude/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
199
trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
Executable file
199
trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
Executable file
|
|
@ -0,0 +1,199 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using Claude.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import anthropic
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'claude-3-5-sonnet-20240620'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 8192
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.claude = anthropic.Anthropic(api_key=api_key)
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
    """Process one TextCompletionRequest: send the prompt to the Claude
    Messages API and publish the response (or an error / rate-limit
    response) with the sender-produced correlation id."""

    v = msg.value()

    # Sender-produced ID

    id = msg.properties()["id"]

    print(f"Handling prompt {id}...", flush=True)

    prompt = v.prompt

    try:

        # FIXME: Rate limits?

        with __class__.text_completion_metric.time():

            # (Dropped an unused `message =` alias that previously
            # shadowed this assignment.)
            response = self.claude.messages.create(
                model=self.model,
                max_tokens=self.max_output,
                temperature=self.temperature,
                system = "You are a helpful chatbot.",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt
                            }
                        ]
                    }
                ]
            )

        resp = response.content[0].text
        inputtokens = response.usage.input_tokens
        outputtokens = response.usage.output_tokens
        print(resp, flush=True)
        print(f"Input Tokens: {inputtokens}", flush=True)
        print(f"Output Tokens: {outputtokens}", flush=True)

        print("Send response...", flush=True)
        r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
        self.send(r, properties={"id": id})

        print("Done.", flush=True)

    # FIXME: Wrong exception, don't know what this LLM throws
    # for a rate limit
    except TooManyRequests as e:

        # BUG FIX: previously caught without `as e` but referenced
        # str(e), raising a NameError instead of sending the response.
        print("Send rate limit response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "rate-limit",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="claude-3-5-sonnet-20240620",
|
||||
help=f'LLM model (default: claude-3-5-sonnet-20240620)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'Claude API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/cohere/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/cohere/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
179
trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
Executable file
179
trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
Executable file
|
|
@ -0,0 +1,179 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using Cohere.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import cohere
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'c4ai-aya-23-8b'
|
||||
default_temperature = 0.0
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.cohere = cohere.Client(api_key=api_key)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
    """Process one TextCompletionRequest: send the prompt to the Cohere
    chat API and publish the response (or an error / rate-limit
    response) with the sender-produced correlation id."""

    v = msg.value()

    # Sender-produced ID

    id = msg.properties()["id"]

    print(f"Handling prompt {id}...", flush=True)

    prompt = v.prompt

    try:

        with __class__.text_completion_metric.time():

            output = self.cohere.chat(
                model=self.model,
                message=prompt,
                preamble = "You are a helpful AI-assistant.",
                temperature=self.temperature,
                chat_history=[],
                prompt_truncation='auto',
                connectors=[]
            )

        resp = output.text
        inputtokens = int(output.meta.billed_units.input_tokens)
        outputtokens = int(output.meta.billed_units.output_tokens)

        print(resp, flush=True)
        print(f"Input Tokens: {inputtokens}", flush=True)
        print(f"Output Tokens: {outputtokens}", flush=True)

        print("Send response...", flush=True)
        r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
        self.send(r, properties={"id": id})

        print("Done.", flush=True)

    # FIXME: Wrong exception, don't know what this LLM throws
    # for a rate limit
    except TooManyRequests as e:

        # BUG FIX: previously caught without `as e` but referenced
        # str(e), raising a NameError instead of sending the response.
        print("Send rate limit response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "rate-limit",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="c4ai-aya-23-8b",
|
||||
help=f'Cohere model (default: c4ai-aya-23-8b)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'Cohere API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/llamafile/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/llamafile/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
209
trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
Executable file
209
trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
Executable file
|
|
@ -0,0 +1,209 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using OpenAI.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'LLaMA_CPP'
|
||||
default_llamafile = 'http://localhost:8080/v1'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4096
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
llamafile = params.get("llamafile", default_llamafile)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"llamafile" : llamafile,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.llamafile=llamafile
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.openai = OpenAI(
|
||||
base_url=self.llamafile,
|
||||
api_key = "sk-no-key-required",
|
||||
)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
    """Process one TextCompletionRequest: send the prompt to the local
    llamafile (OpenAI-compatible) endpoint and publish the response
    (or an error / rate-limit response) with the sender-produced
    correlation id."""

    v = msg.value()

    # Sender-produced ID

    id = msg.properties()["id"]

    print(f"Handling prompt {id}...", flush=True)

    prompt = v.prompt

    try:

        # FIXME: Rate limits

        with __class__.text_completion_metric.time():

            resp = self.openai.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "user", "content": prompt}
                ]
                # Deliberately disabled parameters, kept for reference:
                #temperature=self.temperature,
                #max_tokens=self.max_output,
                #top_p=1,
                #frequency_penalty=0,
                #presence_penalty=0,
                #response_format={
                #    "type": "text"
                #}
            )

        inputtokens = resp.usage.prompt_tokens
        outputtokens = resp.usage.completion_tokens

        print(resp.choices[0].message.content, flush=True)
        print(f"Input Tokens: {inputtokens}", flush=True)
        print(f"Output Tokens: {outputtokens}", flush=True)

        print("Send response...", flush=True)
        r = TextCompletionResponse(
            response=resp.choices[0].message.content,
            error=None,
            in_token=inputtokens,
            out_token=outputtokens,
            model="llama.cpp"
        )
        self.send(r, properties={"id": id})

        print("Done.", flush=True)

    # FIXME: Wrong exception, don't know what this LLM throws
    # for a rate limit
    except TooManyRequests as e:

        # BUG FIX: previously caught without `as e` but referenced
        # str(e), raising a NameError instead of sending the response.
        print("Send rate limit response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "rate-limit",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'LLM model (default: LLaMA_CPP)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--llamafile',
|
||||
default=default_llamafile,
|
||||
help=f'ollama (default: {default_llamafile})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/ollama/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/ollama/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
168
trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
Executable file
168
trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
Executable file
|
|
@ -0,0 +1,168 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using an Ollama service.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from ollama import Client
|
||||
from prometheus_client import Histogram, Info
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'gemma2'
|
||||
default_ollama = 'http://localhost:11434'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
ollama = params.get("ollama", default_ollama)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"model": model,
|
||||
"ollama": ollama,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "model_metric"):
|
||||
__class__.model_metric = Info(
|
||||
'model', 'Model information'
|
||||
)
|
||||
|
||||
__class__.model_metric.info({
|
||||
"model": model,
|
||||
"ollama": ollama,
|
||||
})
|
||||
|
||||
self.model = model
|
||||
self.llm = Client(host=ollama)
|
||||
|
||||
def handle(self, msg):
    """Process one TextCompletionRequest: run the prompt through the
    Ollama service and publish the response (or an error / rate-limit
    response) with the sender-produced correlation id."""

    v = msg.value()

    # Sender-produced ID
    id = msg.properties()["id"]

    print(f"Handling prompt {id}...", flush=True)

    prompt = v.prompt

    try:

        with __class__.text_completion_metric.time():
            response = self.llm.generate(self.model, prompt)

        response_text = response['response']
        print("Send response...", flush=True)
        print(response_text, flush=True)

        inputtokens = int(response['prompt_eval_count'])
        outputtokens = int(response['eval_count'])

        r = TextCompletionResponse(response=response_text, error=None, in_token=inputtokens, out_token=outputtokens, model="ollama")

        self.send(r, properties={"id": id})

        print("Done.", flush=True)

    # FIXME: Wrong exception, don't know what this LLM throws
    # for a rate limit
    except TooManyRequests as e:

        # BUG FIX: previously caught without `as e` but referenced
        # str(e), raising a NameError instead of sending the response.
        print("Send rate limit response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "rate-limit",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        r = TextCompletionResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            response=None,
            in_token=None,
            out_token=None,
            model=None,
        )

        self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
def add_args(parser):
    """
    Register this service's command-line arguments on *parser*:
    the common consumer/producer queue options plus the Ollama
    model name and the Ollama host URL.
    """

    ConsumerProducer.add_args(
        parser, default_input_queue, default_subscriber,
        default_output_queue,
    )

    parser.add_argument(
        '-m', '--model',
        default="gemma2",
        # Plain string: the previous f-string had no placeholders (F541).
        help='LLM model (default: gemma2)'
    )

    parser.add_argument(
        '-r', '--ollama',
        default=default_ollama,
        help=f'ollama (default: {default_ollama})'
    )
||||
def run():
    """Module entry point: start the Processor service for this module."""

    Processor.start(module, __doc__)
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/openai/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/openai/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3

# Executable entry point: delegate to the service's run() function.
from . llm import run

if __name__ == '__main__':
    run()
||||
|
||||
209
trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
Executable file
209
trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
Executable file
|
|
@ -0,0 +1,209 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using OpenAI.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
# Module name derived from the package path (drops the leading package
# and the trailing module file name); used as the default subscriber.
module = ".".join(__name__.split(".")[1:-1])

# Defaults; all overridable via constructor params or CLI arguments.
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
default_subscriber = module
default_model = 'gpt-3.5-turbo'
default_temperature = 0.0
default_max_output = 4096
||||
class Processor(ConsumerProducer):
    """
    Consumer/producer service that performs text prompt completion
    using the OpenAI chat completions API.  Consumes
    TextCompletionRequest messages and produces TextCompletionResponse
    messages, correlated via the "id" message property.
    """

    def __init__(self, **params):
        """
        Initialise the service.

        Recognised params (all optional unless noted):
          input_queue / output_queue / subscriber: Pulsar wiring.
          model: OpenAI model name (default gpt-3.5-turbo).
          api_key: OpenAI API key (passed straight to the OpenAI client).
          temperature / max_output: generation parameters.
        """

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        model = params.get("model", default_model)
        api_key = params.get("api_key")
        temperature = params.get("temperature", default_temperature)
        max_output = params.get("max_output", default_max_output)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": TextCompletionRequest,
                "output_schema": TextCompletionResponse,
                "model": model,
                "temperature": temperature,
                "max_output": max_output,
            }
        )

        # Class-level metric, created once no matter how many
        # Processor instances exist in the process.
        if not hasattr(__class__, "text_completion_metric"):
            __class__.text_completion_metric = Histogram(
                'text_completion_duration',
                'Text completion duration (seconds)',
                buckets=[
                    0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
                    8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                    17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
                    30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
                    120.0
                ]
            )

        self.model = model
        self.temperature = temperature
        self.max_output = max_output
        self.openai = OpenAI(api_key=api_key)

        print("Initialised", flush=True)

    def handle(self, msg):
        """
        Service one text-completion request: call the OpenAI chat
        completions API and publish the response (or an Error) keyed
        by the sender-supplied "id" property.
        """

        v = msg.value()

        # Sender-produced ID; echoed back so the requester can
        # correlate the response with its request.
        id = msg.properties()["id"]

        print(f"Handling prompt {id}...", flush=True)

        prompt = v.prompt

        try:

            # FIXME: Rate limits

            # Time the API round-trip for the Prometheus histogram.
            with __class__.text_completion_metric.time():

                resp = self.openai.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": prompt
                                }
                            ]
                        }
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_output,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    response_format={
                        "type": "text"
                    }
                )

            # Token usage as reported by the API response.
            inputtokens = resp.usage.prompt_tokens
            outputtokens = resp.usage.completion_tokens
            print(resp.choices[0].message.content, flush=True)
            print(f"Input Tokens: {inputtokens}", flush=True)
            print(f"Output Tokens: {outputtokens}", flush=True)

            print("Send response...", flush=True)
            r = TextCompletionResponse(
                response=resp.choices[0].message.content,
                error=None,
                in_token=inputtokens,
                out_token=outputtokens,
                model=self.model
            )
            self.send(r, properties={"id": id})

            print("Done.", flush=True)

        # FIXME: Wrong exception, don't know what this LLM throws
        # for a rate limit
        except TooManyRequests as e:
            # BUG FIX: the exception was previously not bound to a
            # name, so str(e) below raised NameError when this fired.

            print("Send rate limit response...", flush=True)

            r = TextCompletionResponse(
                error=Error(
                    type = "rate-limit",
                    message = str(e),
                ),
                response=None,
                in_token=None,
                out_token=None,
                model=None,
            )

            self.producer.send(r, properties={"id": id})

            # Acknowledge explicitly so the broker does not redeliver.
            self.consumer.acknowledge(msg)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = TextCompletionResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                response=None,
                in_token=None,
                out_token=None,
                model=None,
            )

            self.producer.send(r, properties={"id": id})

            self.consumer.acknowledge(msg)

    @staticmethod
    def add_args(parser):
        """
        Register this service's command-line arguments on *parser*:
        the common consumer/producer queue options plus OpenAI model,
        API key, temperature and max output tokens.
        """

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-m', '--model',
            # Consistency: use the module-level constant instead of a
            # duplicated literal, and report that same value in help.
            default=default_model,
            help=f'LLM model (default: {default_model})'
        )

        parser.add_argument(
            '-k', '--api-key',
            # Plain string: the previous f-string had no placeholders.
            help='OpenAI API key'
        )

        parser.add_argument(
            '-t', '--temperature',
            type=float,
            default=default_temperature,
            help=f'LLM temperature parameter (default: {default_temperature})'
        )

        parser.add_argument(
            '-x', '--max-output',
            type=int,
            default=default_max_output,
            help=f'LLM max output tokens (default: {default_max_output})'
        )
|
||||
def run():
    """Module entry point: start the Processor service for this module."""

    Processor.start(module, __doc__)
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue