mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-27 17:36:25 +02:00
tool invocation
This commit is contained in:
parent
b131c1768e
commit
b2fd9bf877
7 changed files with 574 additions and 162 deletions
|
|
@ -16,20 +16,20 @@ if __name__ == "__main__":
|
|||
tool_configs = complete_request.get("tools", [])
|
||||
prompt_configs = complete_request.get("prompts", [])
|
||||
start_agent_name = complete_request.get("startAgent", "")
|
||||
|
||||
|
||||
return agent_configs, tool_configs, prompt_configs, start_agent_name
|
||||
|
||||
|
||||
external_tool_mappings = {
|
||||
"raise_error": respond_to_tool_raise_error,
|
||||
"close_chat": respond_to_tool_close_chat
|
||||
}
|
||||
|
||||
|
||||
config_file = sys.argv[sys.argv.index("--config") + 1] if "--config" in sys.argv else "default_config.json"
|
||||
sample_request_file = sys.argv[sys.argv.index("--sample_request") + 1] if "--sample_request" in sys.argv else "default_example.json"
|
||||
|
||||
|
||||
config = read_json_from_file(f"./configs/{config_file}")
|
||||
example_request = read_json_from_file(f"./tests/sample_requests/{sample_request_file}").get("lastRequest", {})
|
||||
|
||||
|
||||
if "--load_messages" in sys.argv:
|
||||
messages = example_request.get("messages", [])
|
||||
messages = order_messages(messages)
|
||||
|
|
@ -57,7 +57,7 @@ if __name__ == "__main__":
|
|||
agent_configs, tool_configs, prompt_configs, start_agent_name = extract_request_fields(complete_request)
|
||||
|
||||
print(f"\nUsing agent: {last_agent_name}")
|
||||
|
||||
|
||||
if user_input_needed:
|
||||
user_inp = input('\nUSER: ')
|
||||
messages.append({
|
||||
|
|
@ -81,7 +81,7 @@ if __name__ == "__main__":
|
|||
"startAgent": start_agent_name
|
||||
}
|
||||
print(json.dumps(request_json, indent=2))
|
||||
|
||||
print(complete_request)
|
||||
resp_messages, resp_tokens_used, resp_state = run_turn(
|
||||
messages=messages,
|
||||
start_agent_name=start_agent_name,
|
||||
|
|
@ -89,7 +89,8 @@ if __name__ == "__main__":
|
|||
tool_configs=tool_configs,
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
state=state,
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL]
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
complete_request=complete_request
|
||||
)
|
||||
state = resp_state
|
||||
resp_messages = order_messages(resp_messages)
|
||||
|
|
@ -101,12 +102,12 @@ if __name__ == "__main__":
|
|||
"tokens_used": resp_tokens_used
|
||||
}
|
||||
print(json.dumps(response_json, indent=2))
|
||||
|
||||
|
||||
last_msg = resp_messages[-1]
|
||||
print(f"\nBOT: {last_msg}\n")
|
||||
tool_calls = last_msg.get("tool_calls", [])
|
||||
sender = last_msg.get("sender", "")
|
||||
|
||||
|
||||
if config.get("return_diff_messages", True):
|
||||
messages.extend(resp_messages)
|
||||
else:
|
||||
|
|
@ -133,7 +134,7 @@ if __name__ == "__main__":
|
|||
current_tool_duration = round((datetime.now() - tool_start_time).total_seconds() * 10) / 10
|
||||
logger.info(f"Tool response duration: {current_tool_duration:.1f}s")
|
||||
tool_duration += current_tool_duration
|
||||
|
||||
|
||||
if tool_name == "close_chat":
|
||||
user_input_needed = False
|
||||
logger.info("Closing chat")
|
||||
|
|
@ -141,10 +142,10 @@ if __name__ == "__main__":
|
|||
|
||||
if should_break:
|
||||
break
|
||||
|
||||
|
||||
else:
|
||||
user_input_needed = True
|
||||
print(f"Turn Duration: {round((datetime.now() - turn_start_time).total_seconds() * 10) / 10:.1f}s\n")
|
||||
print(f"Tool Response Duration: {round(tool_duration * 10) / 10:.1f}s\n")
|
||||
|
||||
|
||||
print("\n" + "-" * 80)
|
||||
Loading…
Add table
Add a link
Reference in a new issue