mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
don't compute embeddings for names and other fixes see description (#126)
* serialize tools - 2 * fix int tests * fix int test * fix unit tests
This commit is contained in:
parent
0e5ea3d6db
commit
2a747df7c0
16 changed files with 125 additions and 86 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, DefaultHttpxClient
|
||||
import gradio as gr
|
||||
import logging as log
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -13,11 +14,13 @@ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
|||
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
|
||||
|
||||
def predict(message, history):
|
||||
def predict(message, state):
|
||||
if 'history' not in state:
|
||||
state['history'] = []
|
||||
history = state.get("history")
|
||||
history.append({"role": "user", "content": message})
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("history: ", history)
|
||||
|
||||
# Custom headers
|
||||
|
|
@ -27,34 +30,42 @@ def predict(message, history):
|
|||
'x-arch-deterministic-provider': 'openai',
|
||||
}
|
||||
|
||||
metadata = None
|
||||
if 'arch_state' in state:
|
||||
metadata = {"x-arch-state": state['arch_state']}
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(model=MODEL_NAME,
|
||||
messages= history,
|
||||
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
|
||||
messages = history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
extra_headers=custom_headers
|
||||
)
|
||||
except Exception as e:
|
||||
log.info(e)
|
||||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
choices = response.choices
|
||||
message = choices[0].message
|
||||
content = message.content
|
||||
history.append({"role": "assistant", "content": content})
|
||||
history[-1]["model"] = response.model
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
|
||||
if arch_state:
|
||||
state['arch_state'] = arch_state
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
history.append({"role": "assistant", "content": content, "model": response.model})
|
||||
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
|
||||
return messages, history
|
||||
|
||||
return messages, state
|
||||
|
||||
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
|
||||
print("Starting Demo...")
|
||||
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
|
||||
state = gr.State([])
|
||||
state = gr.State({})
|
||||
with gr.Row():
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,4 +5,4 @@ asyncio==3.4.3
|
|||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
pydantic==2.8.2
|
||||
openai==1.46.1
|
||||
openai==1.51.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue