mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-06-12 19:55:19 +02:00
275 lines
11 KiB
Python
275 lines
11 KiB
Python
# Standard library imports
|
|
import copy
|
|
import json
|
|
from collections import defaultdict
|
|
from typing import List, Callable, Union
|
|
from datetime import datetime
|
|
|
|
# Package/library imports
|
|
from openai import OpenAI
|
|
import random
|
|
|
|
# Local imports
|
|
from .util import *
|
|
from .types import (
|
|
Agent,
|
|
AgentFunction,
|
|
ChatCompletionMessage,
|
|
ChatCompletionMessageToolCall,
|
|
Function,
|
|
Response,
|
|
Result,
|
|
)
|
|
|
|
# Reserved parameter name used to hand the shared context dict to agent
# functions. It is removed from every tool schema before the schema is sent to
# the model, and injected into a function's arguments at call time when the
# function's code object lists a variable with this name.
__CTX_VARS_NAME__ = "context_variables"
|
|
|
|
|
|
class Swarm:
    """Run a turn of a multi-agent conversation through a chat-completions client.

    A turn repeatedly asks the active agent for a completion, executes the
    tool calls it returns, and follows agent hand-offs (parent -> child and
    child -> parent) until the model stops calling tools, tool execution is
    disabled, or the per-turn message budget is exhausted.
    """

    def __init__(self, client=None):
        """Store the completion client and initialise the history map.

        Args:
            client: Object exposing ``chat.completions.create``. When falsy,
                an ``OpenAI`` client is built with ``OPENAI_API_KEY``.
        """
        if not client:
            client = OpenAI(api_key=OPENAI_API_KEY)
        self.client = client
        # Missing keys yield a fresh empty list.
        self.history = defaultdict(list)

    def get_chat_completion(
        self,
        agent: Agent,
        history: List,
        context_variables: dict,
        model_override: str,
        stream: bool,
        debug: bool,
        temperature: float
    ) -> ChatCompletionMessage:
        """Request one completion for ``agent`` given ``history``.

        Args:
            agent: Agent whose instructions, functions and tools are used.
            history: Messages appended after the generated system message.
            context_variables: Passed to ``agent.instructions`` when that
                attribute is callable; missing keys read as ``""``.
            model_override: Model name used instead of ``agent.model`` when
                truthy.
            stream: Forwarded to the client as ``stream``.
            debug: Forwarded to ``debug_print``.
            temperature: Forwarded to the client as ``temperature``.

        Returns:
            Whatever ``self.client.chat.completions.create`` returns. The
            caller in ``run`` reads ``.choices[0].message`` from it.
        """
        # defaultdict(str) lets instruction templates reference keys that are
        # not present without raising KeyError.
        context_variables = defaultdict(str, context_variables)
        if callable(agent.instructions):
            instructions = agent.instructions(context_variables)
        else:
            instructions = agent.instructions
        messages = [{"role": "system", "content": instructions}] + history
        debug_print(debug, "Getting chat completion for...:", messages)

        # Callable hand-off functions: all child functions plus the parent
        # function when one is set.
        all_functions = list(agent.child_functions.values())
        if agent.parent_function:
            all_functions.append(agent.parent_function)
        all_tools = agent.external_tools + agent.internal_tools
        funcs_and_tools = (
            [function_to_json(f) for f in all_functions] + list(all_tools)
        )

        # Hide context_variables from the model.
        # NOTE: entries that come from the agent's tool lists are the same
        # objects the agent holds, so this removal is done in place on them.
        for tool in funcs_and_tools:
            params = tool["function"]["parameters"]
            params["properties"].pop(__CTX_VARS_NAME__, None)
            if __CTX_VARS_NAME__ in params.get("required", []):
                params["required"].remove(__CTX_VARS_NAME__)

        create_params = {
            "model": model_override or agent.model,
            "messages": messages,
            "tools": funcs_and_tools or None,
            "tool_choice": agent.tool_choice,
            "stream": stream,
            "temperature": temperature
        }

        # Only sent when at least one tool is offered.
        if funcs_and_tools:
            create_params["parallel_tool_calls"] = agent.parallel_tool_calls

        return self.client.chat.completions.create(**create_params)

    def handle_function_result(self, result, debug) -> Result:
        """Normalise the return value of an agent function into a ``Result``.

        Args:
            result: Raw value returned by the called function.
            debug: Forwarded to ``debug_print`` on failure.

        Returns:
            ``result`` itself when it already is a ``Result``; a ``Result``
            carrying the agent (and a JSON string naming it) when it is an
            ``Agent``; otherwise a ``Result`` wrapping ``str(result)``.

        Raises:
            TypeError: If the value cannot be converted with ``str``.
        """
        # Check if result is already a Result instance
        if isinstance(result, Result):
            return result

        # An Agent return value signals a hand-off to that agent.
        if isinstance(result, Agent):
            return Result(
                value=json.dumps({"assistant": result.name}),
                agent=result,
            )

        # Handle all other cases
        try:
            return Result(value=str(result))
        except Exception as e:
            error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
            debug_print(debug, error_message)
            # Chain explicitly so the original failure stays attached.
            raise TypeError(error_message) from e

    def handle_function_calls(
        self,
        tool_calls: List[ChatCompletionMessageToolCall],
        functions: List[AgentFunction],
        context_variables: dict,
        debug: bool,
    ) -> Response:
        """Execute the requested tool calls and collect their outputs.

        Args:
            tool_calls: Tool calls taken from the model's message.
            functions: Callables available to the agent, looked up by
                ``__name__``.
            context_variables: Passed to any function whose code object
                lists a variable named ``context_variables``.
            debug: Forwarded to ``debug_print``.

        Returns:
            A ``Response`` holding one ``"tool"`` message per call, the merged
            context-variable updates, and the last agent returned by any
            call (``None`` if no call returned one).

        Raises:
            json.JSONDecodeError: If a call's arguments are not valid JSON
                (``json.loads`` is not guarded here).
            TypeError: Propagated from ``handle_function_result``.
        """
        function_map = {f.__name__: f for f in functions}
        partial_response = Response(
            messages=[], agent=None, context_variables={})

        for tool_call in tool_calls:
            name = tool_call.function.name

            # Unknown tool: report the error back as the tool's output and
            # move on to the next call.
            if name not in function_map:
                debug_print(debug, f"Tool {name} not found in function map.")
                partial_response.messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "tool_name": name,
                        "content": f"Error: Tool {name} not found.",
                    }
                )
                continue

            args = json.loads(tool_call.function.arguments)
            debug_print(
                debug, f"Processing tool call: {name} with arguments {args}")

            func = function_map[name]
            # Pass context_variables to agent functions that name it.
            if __CTX_VARS_NAME__ in func.__code__.co_varnames:
                args[__CTX_VARS_NAME__] = context_variables
            raw_result = func(**args)

            result: Result = self.handle_function_result(raw_result, debug)
            partial_response.messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "tool_name": name,
                    "content": result.value,
                }
            )
            partial_response.context_variables.update(result.context_variables)
            # A later call returning an agent overrides an earlier one.
            if result.agent:
                partial_response.agent = result.agent

        return partial_response

    def run(
        self,
        agent: Agent,
        messages: List,
        context_variables: dict = None,
        model_override: str = None,
        stream: bool = False,
        debug: bool = False,
        max_messages_per_turn: int = 10,
        execute_tools: bool = True,
        external_tools: List[str] = None,
        localize_history: bool = True,
        parent_has_child_history: bool = True,
        tokens_used: dict = None,
        temperature: float = 0.0
    ) -> Response:
        """Run one conversational turn starting from ``agent``.

        Args:
            agent: Agent that handles the first completion of the turn.
            messages: Conversation so far. Deep-copied; the caller's list is
                not modified.
            context_variables: Shared context for instructions and agent
                functions. Deep-copied; defaults to an empty dict.
            model_override: Model name used instead of each agent's own model
                when truthy.
            stream: Forwarded to the completion request.
                NOTE(review): the loop reads ``completion.choices[0].message``,
                which presumably requires a non-streamed response - confirm.
            debug: Enables ``debug_print`` output.
            max_messages_per_turn: The loop stops once at least this many
                messages have been added during the turn.
            execute_tools: When false, the turn ends after the first
                completion without running any tool call.
            external_tools: Names of tools that are not executed here. Calls
                to them are separated from the calls that are executed.
                Defaults to an empty list.
            localize_history: Use each agent's own ``history`` for the
                completion instead of the global history.
            parent_has_child_history: When localising, also record the
                active agent's messages on its most recent parent.
            tokens_used: Accumulator passed through ``update_tokens_used``.
                Defaults to a new dict per call.
            temperature: Sampling temperature forwarded to the request.

        Returns:
            A ``Response`` with the messages added during this turn, the
            agent active at the end, the final context variables, an error
            message (non-empty only when the message budget was reached) and
            the token usage accumulator.

        Raises:
            AssertionError: If an agent hands off to an agent that is neither
                one of its children nor its most recent parent.
        """
        # Fresh objects per call: the previous `{}` / `[]` defaults were
        # shared between calls, so token counts could leak across runs.
        if tokens_used is None:
            tokens_used = {}
        if external_tools is None:
            external_tools = []
        if context_variables is None:
            context_variables = {}

        active_agent = agent
        context_variables = copy.deepcopy(context_variables)
        global_history = copy.deepcopy(messages)
        init_len = len(messages)

        while len(global_history) - init_len < max_messages_per_turn and active_agent:
            history = active_agent.history if localize_history else global_history
            history = arrange_messages_keys_in_order(history)

            parent = active_agent.most_recent_parent

            # Snapshot the child wiring: the helper below may alter it for
            # this completion only, and it is restored right afterwards.
            children_names_backup = copy.deepcopy(active_agent.children_names)
            children_backup = copy.deepcopy(active_agent.children)
            child_functions_backup = copy.deepcopy(active_agent.child_functions)

            active_agent = check_and_remove_repeat_tool_call_to_child(active_agent, history)

            # Get completion with current history and agent.
            completion = self.get_chat_completion(
                agent=active_agent,
                history=history,
                context_variables=context_variables,
                model_override=model_override,
                stream=stream,
                debug=debug,
                temperature=temperature
            )
            tokens_used = update_tokens_used(
                provider="openai",
                model=model_override or active_agent.model,
                tokens_used=tokens_used,
                completion=completion,
            )

            # Restore children and child functions.
            active_agent.children_names = children_names_backup
            active_agent.children = children_backup
            active_agent.child_functions = child_functions_backup

            message = completion.choices[0].message
            debug_print(debug, "Received completion:", message)
            message.sender = active_agent.name
            message_json = json.loads(message.model_dump_json())
            message_json = add_message_metadata(message_json, active_agent)

            if localize_history:
                active_agent = update_histories(active_agent, message_json)
                if parent and parent_has_child_history:
                    parent = update_histories(parent, message_json)
            global_history.append(message_json)

            external_tool_calls = []
            internal_tool_calls = []

            if message.tool_calls:
                message_json["response_type"] = "internal"
                # Split the calls: only internal ones are executed here.
                for tool_call in message.tool_calls:
                    tool_name = tool_call.function.name
                    if tool_name in external_tools:
                        external_tool_calls.append(tool_call)
                    else:
                        internal_tool_calls.append(tool_call)
                message.tool_calls = internal_tool_calls

            # Nothing to execute locally (or execution disabled): end turn.
            if not message.tool_calls or not execute_tools:
                if external_tool_calls:
                    message.tool_calls.extend(external_tool_calls)
                debug_print(debug, "Ending turn.")
                break

            # Handle function calls, updating context_variables and
            # switching agents.
            all_functions = list(active_agent.child_functions.values())
            if active_agent.parent_function:
                all_functions.append(active_agent.parent_function)
            partial_response = self.handle_function_calls(
                message.tool_calls, all_functions, context_variables, debug
            )
            for msg in partial_response.messages:
                msg = add_message_metadata(msg, active_agent)
                if localize_history:
                    active_agent = update_histories(active_agent, msg)
                    if parent and parent_has_child_history:
                        parent = update_histories(parent, msg)

            global_history.extend(partial_response.messages)
            context_variables.update(partial_response.context_variables)

            # Agent hand-off requested by one of the tool calls.
            if partial_response.agent:
                prev_agent = active_agent
                active_agent = partial_response.agent

                # Parent to child transfer
                if active_agent.name in prev_agent.children_names:
                    active_agent.most_recent_parent = prev_agent
                    active_agent.parent_function = active_agent.candidate_parent_functions[active_agent.most_recent_parent.name]
                    if localize_history:
                        if not parent_has_child_history:
                            prev_agent.history = remove_irrelevant_messages(prev_agent.history)
                        # Seed the child's history from the current turn.
                        new_active_agent_history = get_current_turn_messages(global_history, only_user = True)
                        active_agent.history.extend(new_active_agent_history)

                # Child to parent transfer
                else:
                    # NOTE: `assert` is skipped when Python runs with -O.
                    assert parent == active_agent, "Parent and active agent do not match when active agent is not a child of previous agent"
                    child = prev_agent
                    if localize_history:
                        child.history = remove_irrelevant_messages(child.history)

        return_messages = global_history[init_len:]
        error_msg = ""

        if len(global_history) - init_len >= max_messages_per_turn:
            error_msg = "Max messages per turn reached"

        return Response(
            messages=return_messages,
            agent=active_agent,
            context_variables=context_variables,
            error_msg=error_msg,
            tokens_used=tokens_used
        )
|