move custom tracer to llm filter (#267)

This commit is contained in:
Adil Hafeez 2024-11-15 10:44:01 -08:00 committed by GitHub
parent 1d229cba8f
commit d3c17c7abd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 335 additions and 133 deletions

View file

@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
from openai import OpenAI
from dotenv import load_dotenv
from common import get_prompt_targets, process_stream_chunk
from common import format_log, get_llm_models, get_prompt_targets, process_stream_chunk
load_dotenv()
@ -36,20 +36,28 @@ CSS_STYLE = """
footer {visibility: hidden}
"""
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
)
def chat(
query: Optional[str],
conversation: Optional[List[Tuple[str, str]]],
history: List[dict],
debug_output: str,
model_selector: str,
):
history.append({"role": "user", "content": query})
if debug_output is None:
debug_output = ""
try:
headers = {}
if model_selector and model_selector != "":
headers["x-arch-llm-provider-hint"] = model_selector
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
default_headers=headers,
)
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
@ -65,15 +73,20 @@ def chat(
conversation.append((query, ""))
model_is_set = False
for chunk in response:
tokens = process_stream_chunk(chunk, history)
if tokens and not model_is_set:
model_is_set = True
model = history[-1]["model"]
debug_output = debug_output + "\n" + format_log(f"model: {model}")
if tokens:
conversation[-1] = (
conversation[-1][0],
conversation[-1][1] + tokens,
)
yield "", conversation, history
yield "", conversation, history, debug_output, model_selector
def main():
@ -94,8 +107,17 @@ def main():
value=get_prompt_targets(),
show_indices=False,
elem_classes="json-container",
min_height="95vh",
min_height="50vh",
)
model_selector_textbox = gr.Dropdown(
get_llm_models(),
label="override model",
elem_classes="dropdown",
)
debug_output = gr.TextArea(
label="debug output",
elem_classes="debug_output",
)
with gr.Column(scale=2):
chatbot = gr.Chatbot(
@ -110,7 +132,9 @@ def main():
)
textbox.submit(
chat, [textbox, chatbot, history], [textbox, chatbot, history]
chat,
[textbox, chatbot, history, debug_output, model_selector_textbox],
[textbox, chatbot, history, debug_output, model_selector_textbox],
)
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)