mirror of
https://github.com/katanemo/plano.git
synced 2026-05-02 20:32:42 +02:00
Serialize tool calls for Arch FC (#131)
* Serialize tool calls * fix int tests
This commit is contained in:
parent
b43f687b85
commit
96686dc606
10 changed files with 166 additions and 57 deletions
|
|
@ -11,6 +11,7 @@ OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
|
|||
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
ARCH_STATE_HEADER = 'x-arch-state'
|
||||
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
|
|
@ -32,7 +33,7 @@ def predict(message, state):
|
|||
|
||||
metadata = None
|
||||
if 'arch_state' in state:
|
||||
metadata = {"x-arch-state": state['arch_state']}
|
||||
metadata = {ARCH_STATE_HEADER: state['arch_state']}
|
||||
|
||||
try:
|
||||
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
|
||||
|
|
@ -48,11 +49,12 @@ def predict(message, state):
|
|||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
log.debug("raw_response: ", raw_response.text)
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None)
|
||||
if arch_state:
|
||||
state['arch_state'] = arch_state
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue