# rowboat/apps/agents/src/swarm/core.py
# Snapshot: 2025-02-18 14:15:56 +05:30 (275 lines, 11 KiB, Python)

# Standard library imports
import copy
import inspect
import json
import random
from collections import defaultdict
from datetime import datetime
from typing import List, Callable, Union

# Package/library imports
from openai import OpenAI

# Local imports
from .util import *
from .types import (
    Agent,
    AgentFunction,
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
    Function,
    Response,
    Result,
)
# Name of the keyword argument through which agent functions receive the
# run's context variables. Swarm injects it into calls of functions that
# declare it, and strips it from every tool schema shown to the model.
__CTX_VARS_NAME__ = "context_variables"
class Swarm:
    """Run a hierarchy of agents against an OpenAI-compatible chat API.

    Agents hand control to one of their children, or back to their most
    recent parent, through tool calls. ``run`` drives that loop for a
    single conversational turn.
    """

    def __init__(self, client=None):
        """Create the runner.

        Args:
            client: An OpenAI-compatible client. When omitted, an ``OpenAI``
                client is built from ``OPENAI_API_KEY``.
        """
        if not client:
            client = OpenAI(api_key=OPENAI_API_KEY)
        self.client = client
        # Per-key message lists; no method of this class reads it.
        self.history = defaultdict(list)

    def get_chat_completion(
        self,
        agent: Agent,
        history: List,
        context_variables: dict,
        model_override: str,
        stream: bool,
        debug: bool,
        temperature: float
    ) -> ChatCompletionMessage:
        """Request one chat completion for ``agent`` given ``history``.

        The agent's instructions (called with the context variables when they
        are callable) become the system message. The model is offered the
        agent's child functions, its parent function (if any), and its
        external/internal tool schemas. The ``context_variables`` parameter is
        removed from every schema so the model never sees it.

        Returns:
            Whatever ``client.chat.completions.create`` returns: a completion
            object, or a stream when ``stream`` is true. Despite the return
            annotation, this is not a bare message.
        """
        context_variables = defaultdict(str, context_variables)
        instructions = (
            agent.instructions(context_variables)
            if callable(agent.instructions)
            else agent.instructions
        )
        messages = [{"role": "system", "content": instructions}] + history
        debug_print(debug, "Getting chat completion for...:", messages)

        all_functions = list(agent.child_functions.values())
        if agent.parent_function:
            all_functions.append(agent.parent_function)
        all_tools = agent.external_tools + agent.internal_tools
        # Copy the agent's own schemas: the stripping loop below edits them,
        # and must not rewrite the agent's tool definitions in place.
        funcs_and_tools = [function_to_json(f) for f in all_functions] + [
            copy.deepcopy(t) for t in all_tools
        ]

        # hide context_variables from model (tolerating schemas that omit
        # "parameters", "properties" or "required")
        for tool in funcs_and_tools:
            params = (tool.get("function") or {}).get("parameters") or {}
            (params.get("properties") or {}).pop(__CTX_VARS_NAME__, None)
            required = params.get("required") or []
            if __CTX_VARS_NAME__ in required:
                required.remove(__CTX_VARS_NAME__)

        create_params = {
            "model": model_override or agent.model,
            "messages": messages,
            "tools": funcs_and_tools or None,
            "tool_choice": agent.tool_choice,
            "stream": stream,
            "temperature": temperature
        }
        # parallel_tool_calls is only meaningful when tools are offered.
        if funcs_and_tools:
            create_params["parallel_tool_calls"] = agent.parallel_tool_calls
        return self.client.chat.completions.create(**create_params)

    def handle_function_result(self, result, debug) -> Result:
        """Normalize an agent function's return value into a ``Result``.

        A ``Result`` is returned unchanged. An ``Agent`` becomes a hand-off
        whose value is ``{"assistant": <agent name>}`` as JSON. Anything else
        is converted with ``str``.

        Raises:
            TypeError: if the value cannot be converted to a string.
        """
        if isinstance(result, Result):
            return result

        if isinstance(result, Agent):
            return Result(
                value=json.dumps({"assistant": result.name}),
                agent=result,
            )

        try:
            return Result(value=str(result))
        except Exception as e:
            error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
            debug_print(debug, error_message)
            raise TypeError(error_message)

    @staticmethod
    def _tool_message(tool_call, name, content) -> dict:
        """Build the "tool" role message answering ``tool_call``."""
        return {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "tool_name": name,
            "content": content,
        }

    @staticmethod
    def _parse_tool_arguments(raw):
        """Decode a tool call's JSON arguments; return None unless they form a JSON object."""
        try:
            args = json.loads(raw or "{}")
        except json.JSONDecodeError:
            return None
        return args if isinstance(args, dict) else None

    @staticmethod
    def _accepts_context_variables(func) -> bool:
        """Return True when ``func``'s signature declares a ``context_variables`` parameter."""
        # inspect.signature only sees real parameters. ``__code__.co_varnames``
        # also lists locals, and is missing on callables such as partials.
        try:
            return __CTX_VARS_NAME__ in inspect.signature(func).parameters
        except (TypeError, ValueError):
            return False

    def handle_function_calls(
        self,
        tool_calls: List[ChatCompletionMessageToolCall],
        functions: List[AgentFunction],
        context_variables: dict,
        debug: bool,
    ) -> Response:
        """Execute the model's tool calls against ``functions``.

        Every call produces one "tool" message. An unknown tool name, or an
        argument payload that is not a JSON object, yields an error message
        instead of raising, so the model can see what went wrong. Functions
        that declare a ``context_variables`` parameter receive the current
        context variables.

        Returns:
            A ``Response`` containing the tool messages, the merged
            context-variable updates from all results, and the last agent any
            result handed off to (None if there was none).
        """
        function_map = {f.__name__: f for f in functions}
        partial_response = Response(messages=[], agent=None, context_variables={})

        for tool_call in tool_calls:
            name = tool_call.function.name
            # handle missing tool case, skip to next tool
            if name not in function_map:
                debug_print(debug, f"Tool {name} not found in function map.")
                partial_response.messages.append(
                    self._tool_message(tool_call, name, f"Error: Tool {name} not found.")
                )
                continue

            # The arguments are model-generated and may be malformed. Report
            # that to the model instead of aborting the whole turn.
            args = self._parse_tool_arguments(tool_call.function.arguments)
            if args is None:
                error = f"Error: Invalid arguments for tool {name}; expected a JSON object."
                debug_print(debug, error)
                partial_response.messages.append(self._tool_message(tool_call, name, error))
                continue

            debug_print(
                debug, f"Processing tool call: {name} with arguments {args}")

            func = function_map[name]
            # pass context_variables to agent functions that declare it
            if self._accepts_context_variables(func):
                args[__CTX_VARS_NAME__] = context_variables
            raw_result = func(**args)

            result: Result = self.handle_function_result(raw_result, debug)
            partial_response.messages.append(
                self._tool_message(tool_call, name, result.value)
            )
            partial_response.context_variables.update(result.context_variables)
            if result.agent:
                partial_response.agent = result.agent

        return partial_response

    def run(
        self,
        agent: Agent,
        messages: List,
        context_variables: dict = None,
        model_override: str = None,
        stream: bool = False,
        debug: bool = False,
        max_messages_per_turn: int = 10,
        execute_tools: bool = True,
        external_tools: List[str] = None,
        localize_history: bool = True,
        parent_has_child_history: bool = True,
        tokens_used: dict = None,
        temperature: float = 0.0
    ) -> Response:
        """Drive ``agent``, and any agents it hands off to, through one turn.

        The loop repeatedly requests a completion from the active agent,
        executes its internal tool calls, and follows parent/child transfers.
        It stops when the agent replies without internal tool calls, when
        ``execute_tools`` is false, or when ``max_messages_per_turn`` new
        messages have been produced.

        Args:
            agent: Agent that starts the turn.
            messages: Conversation so far; deep-copied before use.
            context_variables: Values passed to agent instructions and to
                functions declaring a ``context_variables`` parameter.
                Deep-copied; defaults to an empty dict.
            model_override: Model used instead of each agent's ``model``.
            stream: Forwarded to the completion request.
            debug: Enable ``debug_print`` output.
            max_messages_per_turn: Cap on messages appended this turn.
            execute_tools: When false, stop after the first completion
                without running its tool calls.
            external_tools: Tool names handled by the caller. Calls to them
                are never executed here; when a reply contains only such
                calls, the turn ends.
            localize_history: Use each agent's own ``history`` rather than
                one shared message list.
            parent_has_child_history: Also record the active agent's messages
                in its parent's history. When false, the parent's history is
                pruned with ``remove_irrelevant_messages`` as it hands off to
                a child.
            tokens_used: Token accounting threaded through
                ``update_tokens_used``; a fresh dict is used when omitted.
            temperature: Sampling temperature.

        Returns:
            A ``Response`` holding the messages generated this turn, the agent
            active at the end, the updated context variables, ``error_msg``
            (set when the message cap was reached) and ``tokens_used``.
        """
        # Fresh containers per call. The former ``{}`` / ``[]`` defaults were
        # shared between calls, which matters because ``tokens_used`` is passed
        # through ``update_tokens_used`` on every iteration.
        if context_variables is None:
            context_variables = {}
        if external_tools is None:
            external_tools = []
        if tokens_used is None:
            tokens_used = {}

        active_agent = agent
        context_variables = copy.deepcopy(context_variables)
        global_history = copy.deepcopy(messages)
        init_len = len(messages)

        while len(global_history) - init_len < max_messages_per_turn and active_agent:
            history = active_agent.history if localize_history else global_history
            history = arrange_messages_keys_in_order(history)
            parent = active_agent.most_recent_parent

            # check_and_remove_repeat_tool_call_to_child (defined elsewhere)
            # presumably prunes children for this one completion. Snapshot the
            # containers so they can be restored afterwards. Shallow copies keep
            # the child Agent objects themselves; a deep copy would swap in
            # clones on restore. NOTE(review): assumes the helper edits only
            # these containers, not the child agents; confirm.
            children_names_backup = copy.copy(active_agent.children_names)
            children_backup = copy.copy(active_agent.children)
            child_functions_backup = copy.copy(active_agent.child_functions)
            active_agent = check_and_remove_repeat_tool_call_to_child(active_agent, history)

            # get completion with current history, agent
            completion = self.get_chat_completion(
                agent=active_agent,
                history=history,
                context_variables=context_variables,
                model_override=model_override,
                stream=stream,
                debug=debug,
                temperature=temperature
            )
            tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)

            # Restore children and child functions
            active_agent.children_names = children_names_backup
            active_agent.children = children_backup
            active_agent.child_functions = child_functions_backup

            # NOTE(review): this reads ``choices[0].message``, which assumes a
            # non-streamed completion even though ``stream`` is forwarded.
            message = completion.choices[0].message
            debug_print(debug, "Received completion:", message)
            message.sender = active_agent.name
            message_json = json.loads(message.model_dump_json())
            message_json = add_message_metadata(message_json, active_agent)
            if localize_history:
                active_agent = update_histories(active_agent, message_json)
                if parent and parent_has_child_history:
                    parent = update_histories(parent, message_json)
            global_history.append(message_json)

            # Split tool calls into those the caller handles (external) and
            # those executed here (internal).
            external_tool_calls = []
            internal_tool_calls = []
            if message.tool_calls:
                # message_json is already in global_history; this mutation is
                # visible there too.
                message_json["response_type"] = "internal"
                for tool_call in message.tool_calls:
                    if tool_call.function.name in external_tools:
                        external_tool_calls.append(tool_call)
                    else:
                        internal_tool_calls.append(tool_call)
                message.tool_calls = internal_tool_calls

            # NOTE(review): when a reply mixes internal and external calls and
            # execute_tools is true, the external calls get no tool response
            # here; confirm that is intended.
            if not message.tool_calls or not execute_tools:
                if external_tool_calls:
                    message.tool_calls.extend(external_tool_calls)
                debug_print(debug, "Ending turn.")
                break

            # handle function calls, updating context_variables, and switching agents
            all_functions = list(active_agent.child_functions.values())
            if active_agent.parent_function:
                all_functions.append(active_agent.parent_function)
            partial_response = self.handle_function_calls(
                message.tool_calls, all_functions, context_variables, debug
            )

            for msg in partial_response.messages:
                msg = add_message_metadata(msg, active_agent)
                if localize_history:
                    active_agent = update_histories(active_agent, msg)
                    if parent and parent_has_child_history:
                        parent = update_histories(parent, msg)

            global_history.extend(partial_response.messages)
            context_variables.update(partial_response.context_variables)

            # Agent transfer
            if partial_response.agent:
                prev_agent = active_agent
                active_agent = partial_response.agent

                if active_agent.name in prev_agent.children_names:
                    # Parent to child transfer
                    active_agent.most_recent_parent = prev_agent
                    active_agent.parent_function = active_agent.candidate_parent_functions[active_agent.most_recent_parent.name]
                    if localize_history:
                        if not parent_has_child_history:
                            prev_agent.history = remove_irrelevant_messages(prev_agent.history)
                        new_active_agent_history = get_current_turn_messages(global_history, only_user=True)
                        active_agent.history.extend(new_active_agent_history)
                else:
                    # Child to parent transfer. Explicit raise so the check
                    # is not stripped under ``python -O``.
                    if parent != active_agent:
                        raise AssertionError("Parent and active agent do not match when active agent is not a child of previous agent")
                    child = prev_agent
                    if localize_history:
                        child.history = remove_irrelevant_messages(child.history)

        return_messages = global_history[init_len:]
        error_msg = ""
        if len(global_history) - init_len >= max_messages_per_turn:
            error_msg = "Max messages per turn reached"
        return Response(
            messages=return_messages,
            agent=active_agent,
            context_variables=context_variables,
            error_msg=error_msg,
            tokens_used=tokens_used
        )