Serialize tool calls for Arch FC (#131)

* Serialize tool calls

* fix int tests
This commit is contained in:
Adil Hafeez 2024-10-07 00:03:25 -07:00 committed by GitHub
parent b43f687b85
commit 96686dc606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 166 additions and 57 deletions

View file

@ -11,6 +11,7 @@ OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
ARCH_STATE_HEADER = 'x-arch-state'
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
@ -32,7 +33,7 @@ def predict(message, state):
metadata = None
if 'arch_state' in state:
metadata = {"x-arch-state": state['arch_state']}
metadata = {ARCH_STATE_HEADER: state['arch_state']}
try:
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
@ -48,11 +49,12 @@ def predict(message, state):
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
log.debug("raw_response: ", raw_response.text)
response = raw_response.parse()
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None)
if arch_state:
state['arch_state'] = arch_state