Mirror of https://github.com/katanemo/plano.git
Synced 2026-04-28 18:36:34 +02:00
move custom tracer to llm filter (#267)

commit: d3c17c7abd
parent: 1d229cba8f

22 changed files with 335 additions and 133 deletions
@@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
 from openai import OpenAI
 from dotenv import load_dotenv
 
-from common import get_prompt_targets, process_stream_chunk
+from common import format_log, get_llm_models, get_prompt_targets, process_stream_chunk
 
 load_dotenv()
 
@@ -36,20 +36,28 @@ CSS_STYLE = """
     footer {visibility: hidden}
 """
 
-client = OpenAI(
-    api_key="--",
-    base_url=CHAT_COMPLETION_ENDPOINT,
-)
 
 
 def chat(
     query: Optional[str],
     conversation: Optional[List[Tuple[str, str]]],
     history: List[dict],
+    debug_output: str,
+    model_selector: str,
 ):
     history.append({"role": "user", "content": query})
+    if debug_output is None:
+        debug_output = ""
 
     try:
+        headers = {}
+        if model_selector and model_selector != "":
+            headers["x-arch-llm-provider-hint"] = model_selector
+        client = OpenAI(
+            api_key="--",
+            base_url=CHAT_COMPLETION_ENDPOINT,
+            default_headers=headers,
+        )
         response = client.chat.completions.create(
             # we select model from arch_config file
             model="--",
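The override above hinges on the OpenAI client's default_headers option: every request the demo sends now carries x-arch-llm-provider-hint, which the gateway can read before routing. A minimal sketch of the same call made outside the demo; CHAT_COMPLETION_ENDPOINT, the placeholder key, and model="--" follow the diff, while the hint value is a hypothetical model name:

    # Sketch: send a provider hint directly to the gateway.
    from openai import OpenAI

    client = OpenAI(
        api_key="--",  # placeholder, as in the diff; the gateway holds the real keys
        base_url=CHAT_COMPLETION_ENDPOINT,
        # hypothetical override value; the demo fills this from the dropdown
        default_headers={"x-arch-llm-provider-hint": "gpt-4o-mini"},
    )
    stream = client.chat.completions.create(
        model="--",  # the gateway picks the model from its config
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )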
@@ -65,15 +73,20 @@ def chat(
 
         conversation.append((query, ""))
 
+        model_is_set = False
         for chunk in response:
             tokens = process_stream_chunk(chunk, history)
+            if tokens and not model_is_set:
+                model_is_set = True
+                model = history[-1]["model"]
+                debug_output = debug_output + "\n" + format_log(f"model: {model}")
             if tokens:
                 conversation[-1] = (
                     conversation[-1][0],
                     conversation[-1][1] + tokens,
                 )
 
-            yield "", conversation, history
+            yield "", conversation, history, debug_output, model_selector
 
 
 def main():
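This hunk assumes process_stream_chunk (defined in common.py, not part of this diff) records which model actually answered on the last history entry, so history[-1]["model"] is readable once the first tokens arrive. OpenAI streaming chunks do carry a model field, so a plausible sketch of that helper looks like the following; the real implementation in common.py may differ:

    # Assumed shape of the helper this hunk depends on (not shown in the diff).
    def process_stream_chunk(chunk, history):
        delta = chunk.choices[0].delta if chunk.choices else None
        tokens = delta.content if delta is not None else None
        if tokens:
            if not history or history[-1]["role"] != "assistant":
                # open an assistant turn and remember the model that is answering
                history.append({"role": "assistant", "content": "", "model": chunk.model})
            history[-1]["content"] += tokens
        return tokens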
@@ -94,8 +107,17 @@ def main():
                     value=get_prompt_targets(),
                     show_indices=False,
                     elem_classes="json-container",
-                    min_height="95vh",
+                    min_height="50vh",
                 )
+                model_selector_textbox = gr.Dropdown(
+                    get_llm_models(),
+                    label="override model",
+                    elem_classes="dropdown",
+                )
+                debug_output = gr.TextArea(
+                    label="debug output",
+                    elem_classes="debug_output",
+                )
 
             with gr.Column(scale=2):
                 chatbot = gr.Chatbot(
@@ -110,7 +132,9 @@ def main():
         )
 
         textbox.submit(
-            chat, [textbox, chatbot, history], [textbox, chatbot, history]
+            chat,
+            [textbox, chatbot, history, debug_output, model_selector_textbox],
+            [textbox, chatbot, history, debug_output, model_selector_textbox],
         )
 
     demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)
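Note how the input and output lists mirror each other: chat is a generator, so each yield of ("", conversation, history, debug_output, model_selector) writes back into the same five components that fed it, clearing the textbox and streaming tokens into the chatbot. A stripped-down sketch of that wiring, with hypothetical dropdown choices standing in for get_llm_models():

    # Minimal wiring sketch (layout simplified; choices are illustrative).
    import gradio as gr

    with gr.Blocks() as demo:
        history = gr.State([])
        chatbot = gr.Chatbot()
        textbox = gr.Textbox(placeholder="ask something")
        debug_output = gr.TextArea(label="debug output")
        model_selector_textbox = gr.Dropdown(["model-a", "model-b"], label="override model")
        textbox.submit(
            chat,  # the generator defined above
            [textbox, chatbot, history, debug_output, model_selector_textbox],
            [textbox, chatbot, history, debug_output, model_selector_textbox],
        )
    demo.launch()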