don't compute embeddings for names and other fixes see description (#126)

* serialize tools - 2

* fix int tests

* fix int test

* fix unit tests
This commit is contained in:
Adil Hafeez 2024-10-05 19:25:16 -07:00 committed by GitHub
parent 0e5ea3d6db
commit 2a747df7c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 125 additions and 86 deletions

View file

@ -1,5 +1,6 @@
import json
import os
from openai import OpenAI
from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging as log
from dotenv import load_dotenv
@ -13,11 +14,13 @@ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
def predict(message, history):
def predict(message, state):
if 'history' not in state:
state['history'] = []
history = state.get("history")
history.append({"role": "user", "content": message})
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info("history: ", history)
# Custom headers
@ -27,34 +30,42 @@ def predict(message, history):
'x-arch-deterministic-provider': 'openai',
}
metadata = None
if 'arch_state' in state:
metadata = {"x-arch-state": state['arch_state']}
try:
response = client.chat.completions.create(model=MODEL_NAME,
messages= history,
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
messages = history,
temperature=1.0,
metadata=metadata,
extra_headers=custom_headers
)
except Exception as e:
log.info(e)
# remove last user message in case of exception
history.pop()
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
choices = response.choices
message = choices[0].message
content = message.content
history.append({"role": "assistant", "content": content})
history[-1]["model"] = response.model
response = raw_response.parse()
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
if arch_state:
state['arch_state'] = arch_state
content = response.choices[0].message.content
history.append({"role": "assistant", "content": content, "model": response.model})
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
return messages, history
return messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
state = gr.State([])
state = gr.State({})
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)

View file

@ -5,4 +5,4 @@ asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
openai==1.46.1
openai==1.51.0