Add agents with custom swarm implementation

This commit is contained in:
akhisud3195 2025-01-13 18:20:38 +05:30
parent 24c4f6e552
commit a19dedd59f
35 changed files with 3413 additions and 0 deletions

View file

@ -0,0 +1,4 @@
from .core import Swarm
from .types import Agent, Response
__all__ = ["Swarm", "Agent", "Response"]

View file

@ -0,0 +1,269 @@
# Standard library imports
import copy
import json
from collections import defaultdict
from typing import List, Callable, Union
from datetime import datetime
# Package/library imports
from openai import OpenAI
import random
# Local imports
from .util import *
from .types import (
Agent,
AgentFunction,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
Function,
Response,
Result,
)
__CTX_VARS_NAME__ = "context_variables"
class Swarm:
def __init__(self, client=None):
if not client:
client = OpenAI(api_key=OPENAI_API_KEY)
self.client = client
self.history = defaultdict(lambda : [])
def get_chat_completion(
self,
agent: Agent,
history: List,
context_variables: dict,
model_override: str,
stream: bool,
debug: bool,
) -> ChatCompletionMessage:
context_variables = defaultdict(str, context_variables)
instructions = (
agent.instructions(context_variables)
if callable(agent.instructions)
else agent.instructions
)
messages = [{"role": "system", "content": instructions}] + history
debug_print(debug, "Getting chat completion for...:", messages)
all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else [])
all_tools = agent.external_tools + agent.internal_tools
funcs_and_tools = [function_to_json(f) for f in all_functions] + [t for t in all_tools]
# hide context_variables from model
for tool in funcs_and_tools:
params = tool["function"]["parameters"]
params["properties"].pop(__CTX_VARS_NAME__, None)
if __CTX_VARS_NAME__ in params["required"]:
params["required"].remove(__CTX_VARS_NAME__)
create_params = {
"model": model_override or agent.model,
"messages": messages,
"tools": funcs_and_tools or None,
"tool_choice": agent.tool_choice,
"stream": stream,
}
if funcs_and_tools:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
return self.client.chat.completions.create(**create_params)
def handle_function_result(self, result, debug) -> Result:
match result:
case Result() as result:
return result
case Agent() as agent:
return Result(
value=json.dumps({"assistant": agent.name}),
agent=agent,
)
case _:
try:
return Result(value=str(result))
except Exception as e:
error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
debug_print(debug, error_message)
raise TypeError(error_message)
def handle_function_calls(
self,
tool_calls: List[ChatCompletionMessageToolCall],
functions: List[AgentFunction],
context_variables: dict,
debug: bool,
) -> Response:
function_map = {f.__name__: f for f in functions}
partial_response = Response(
messages=[], agent=None, context_variables={})
for tool_call in tool_calls:
name = tool_call.function.name
# handle missing tool case, skip to next tool
if name not in function_map:
debug_print(debug, f"Tool {name} not found in function map.")
partial_response.messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": f"Error: Tool {name} not found.",
}
)
continue
args = json.loads(tool_call.function.arguments)
debug_print(
debug, f"Processing tool call: {name} with arguments {args}")
func = function_map[name]
# pass context_variables to agent functions
if __CTX_VARS_NAME__ in func.__code__.co_varnames:
args[__CTX_VARS_NAME__] = context_variables
raw_result = function_map[name](**args)
result: Result = self.handle_function_result(raw_result, debug)
partial_response.messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": result.value,
}
)
partial_response.context_variables.update(result.context_variables)
if result.agent:
partial_response.agent = result.agent
return partial_response
def run(
self,
agent: Agent,
messages: List,
context_variables: dict = {},
model_override: str = None,
stream: bool = False,
debug: bool = False,
max_messages_per_turn: int = 10,
execute_tools: bool = True,
external_tools: List[str] = [],
localize_history: bool = True,
parent_has_child_history: bool = True,
tokens_used: dict = {}
) -> Response:
active_agent = agent
context_variables = copy.deepcopy(context_variables)
global_history = copy.deepcopy(messages)
init_len = len(messages)
while len(global_history) - init_len < max_messages_per_turn and active_agent:
history = active_agent.history if localize_history else global_history
history = arrange_messages_keys_in_order(history)
parent = active_agent.most_recent_parent
children_names_backup, children_backup, child_functions_backup = copy.deepcopy(active_agent.children_names), copy.deepcopy(active_agent.children), copy.deepcopy(active_agent.child_functions)
active_agent = check_and_remove_repeat_tool_call_to_child(active_agent, history)
# get completion with current history, agent
completion = self.get_chat_completion(
agent=active_agent,
history=history,
context_variables=context_variables,
model_override=model_override,
stream=stream,
debug=debug,
)
tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)
# Restore children and child functions
active_agent.children_names, active_agent.children, active_agent.child_functions = children_names_backup, children_backup, child_functions_backup
message = completion.choices[0].message
debug_print(debug, "Received completion:", message)
message.sender = active_agent.name
message_json = json.loads(message.model_dump_json())
message_json = add_message_metadata(message_json, active_agent)
if localize_history:
active_agent = update_histories(active_agent, message_json)
if parent and parent_has_child_history:
parent = update_histories(parent, message_json)
global_history.append(message_json)
external_tool_calls = []
internal_tool_calls = []
if message.tool_calls:
message_json["response_type"] = "internal"
for tool_call in message.tool_calls:
tool_name = tool_call.function.name
if tool_name in external_tools:
external_tool_calls.append(tool_call)
else:
internal_tool_calls.append(tool_call)
message.tool_calls = internal_tool_calls
if not message.tool_calls or not execute_tools:
if external_tool_calls:
message.tool_calls.extend(external_tool_calls)
debug_print(debug, "Ending turn.")
break
# handle function calls, updating context_variables, and switching agents
all_functions = list(active_agent.child_functions.values()) + ([active_agent.parent_function] if active_agent.parent_function else [])
partial_response = self.handle_function_calls(
message.tool_calls, all_functions, context_variables, debug
)
for msg in partial_response.messages:
msg = add_message_metadata(msg, active_agent)
if localize_history:
active_agent = update_histories(active_agent, msg)
if parent and parent_has_child_history:
parent = update_histories(parent, msg)
global_history.extend(partial_response.messages)
context_variables.update(partial_response.context_variables)
# Parent to child transfer
if partial_response.agent:
prev_agent = active_agent
active_agent = partial_response.agent
# Parent to child transfer
if active_agent.name in prev_agent.children_names:
active_agent.most_recent_parent = prev_agent
active_agent.parent_function = active_agent.candidate_parent_functions[active_agent.most_recent_parent.name]
if localize_history:
if not parent_has_child_history:
prev_agent.history = remove_irrelevant_messages(prev_agent.history)
new_active_agent_history = get_current_turn_messages(global_history, only_user = True)
active_agent.history.extend(new_active_agent_history)
# Child to parent transfer
else:
assert parent == active_agent, "Parent and active agent do not match when active agent is not a child of previous agent"
child = prev_agent
if localize_history:
child.history = remove_irrelevant_messages(child.history)
return_messages = global_history[init_len:]
error_msg = ""
if len(global_history) - init_len >= max_messages_per_turn:
error_msg = "Max messages per turn reached"
return Response(
messages=return_messages,
agent=active_agent,
context_variables=context_variables,
error_msg=error_msg,
tokens_used=tokens_used
)

View file

@ -0,0 +1 @@
from .repl import run_demo_loop

View file

@ -0,0 +1,87 @@
import json
from swarm import Swarm
def process_and_print_streaming_response(response):
    """Print a streaming Swarm response chunk-by-chunk.

    Emits assistant text as it arrives (prefixing the sender name in blue),
    announces each named tool call, and returns the final response object
    as soon as a chunk carrying the "response" key appears.
    """
    buffered = ""
    sender = ""
    for chunk in response:
        if "sender" in chunk:
            sender = chunk["sender"]
        if "content" in chunk and chunk["content"] is not None:
            # Prefix the sender name once, at the start of the message.
            if not buffered and sender:
                print(f"\033[94m{sender}:\033[0m", end=" ", flush=True)
                sender = ""
            print(chunk["content"], end="", flush=True)
            buffered += chunk["content"]
        if "tool_calls" in chunk and chunk["tool_calls"] is not None:
            for call in chunk["tool_calls"]:
                fn_name = call["function"]["name"]
                if fn_name:
                    print(f"\033[94m{sender}: \033[95m{fn_name}\033[0m()")
        if "delim" in chunk and chunk["delim"] == "end" and buffered:
            print()  # End of response message
            buffered = ""
        if "response" in chunk:
            return chunk["response"]
def pretty_print_messages(messages) -> None:
    """Render assistant messages to the terminal.

    Sender names print in blue and tool-call names in purple; messages
    whose role is not "assistant" are skipped entirely.
    """
    for msg in messages:
        if msg["role"] != "assistant":
            continue
        # agent name in blue
        print(f"\033[94m{msg['sender']}\033[0m:", end=" ")
        # textual reply, when present
        if msg["content"]:
            print(msg["content"])
        # tool calls in purple, one per line
        calls = msg.get("tool_calls") or []
        if len(calls) > 1:
            print()
        for call in calls:
            fn = call["function"]
            arg_str = json.dumps(json.loads(fn["arguments"])).replace(":", "=")
            print(f"\033[95m{fn['name']}\033[0m({arg_str[1:-1]})")
def run_demo_loop(
    starting_agent, context_variables=None, stream=False, debug=False
) -> None:
    """Interactive REPL: read a user line, run the swarm, print the reply.

    Loops forever; the active agent may change between turns when a run
    hands control to a different agent.
    """
    swarm = Swarm()
    print("Starting Swarm CLI 🐝")
    conversation = []
    active_agent = starting_agent
    while True:
        conversation.append(
            {"role": "user", "content": input("\033[90mUser\033[0m: ")}
        )
        result = swarm.run(
            agent=active_agent,
            messages=conversation,
            context_variables=context_variables or {},
            stream=stream,
            debug=debug,
        )
        if stream:
            result = process_and_print_streaming_response(result)
        else:
            pretty_print_messages(result.messages)
        conversation.extend(result.messages)
        active_agent = result.agent

View file

@ -0,0 +1,54 @@
from __future__ import annotations
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from typing import List, Callable, Union, Optional, Dict
# Third-party imports
from pydantic import BaseModel
AgentFunction = Callable[[], Union[str, "Agent", dict]]
class Agent(BaseModel):
    """One agent node in the swarm's parent/child hierarchy.

    The child/parent wiring fields are manipulated by Swarm.run as control
    is transferred between agents.
    """
    name: str = "Agent"
    model: str = "gpt-4o"
    type: str = ""
    # BUG FIX: the original default ended with a stray trailing comma, which
    # made the default the tuple ("You are a helpful agent.",) — not a string.
    instructions: Union[str, Callable[[], str]] = "You are a helpful agent."
    description: str = "This is a helpful agent."
    # Transfer-to-parent functions keyed by candidate parent name.
    candidate_parent_functions: Dict[str, AgentFunction] = {}
    # The active transfer-to-parent function (set when a parent hands over).
    parent_function: Optional[AgentFunction] = None
    # Transfer-to-child functions keyed by child name.
    child_functions: Dict[str, AgentFunction] = {}
    internal_tools: List[Dict] = []
    external_tools: List[Dict] = []
    tool_choice: Optional[str] = None
    parallel_tool_calls: bool = True
    # Whether this agent's replies are shown to the end user ("external").
    respond_to_user: bool = True
    # Local message history (used when Swarm.run localizes histories).
    history: List[Dict] = []
    children_names: List[str] = []
    children: Dict[str, "Agent"] = {}
    most_recent_parent: Optional["Agent"] = None
    parent: Optional["Agent"] = None
class Response(BaseModel):
    """Aggregated outcome of a Swarm.run turn (or a batch of tool calls)."""
    # Messages produced during the turn (assistant + tool messages).
    messages: List = []
    # Agent left in control when the turn ended, if any.
    agent: Optional[Agent] = None
    # Context variables after merging all tool-result updates.
    context_variables: dict = {}
    # Non-empty when the turn ended abnormally (e.g. message budget hit).
    error_msg: Optional[str] = ""
    # Token usage accumulated per "provider/model" key.
    tokens_used: dict = {}
class Result(BaseModel):
    """
    Encapsulates the possible return values for an agent function.
    Attributes:
        value (str): The result value as a string.
        agent (Agent): The agent to transfer control to, if applicable.
        context_variables (dict): Context-variable updates to merge in.
    """

    value: str = ""
    agent: Optional[Agent] = None
    context_variables: dict = {}

View file

@ -0,0 +1,175 @@
import inspect
import json
from datetime import datetime
import os
from dotenv import load_dotenv
from src.utils.common import read_json_from_file, get_api_key
load_dotenv()
OPENAI_API_KEY = get_api_key("OPENAI_API_KEY")
def debug_print(debug: bool, *args: str) -> None:
    """Print *args* with a grey timestamp prefix, but only when debug is on."""
    if debug:
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        body = " ".join(str(a) for a in args)
        print(f"\033[97m[\033[90m{stamp}\033[97m]\033[90m {body}\033[0m")
def merge_fields(target, source):
    """Recursively fold *source* into *target* in place.

    String values are concatenated onto the existing value; nested dicts
    are merged recursively; other values are ignored.
    """
    for key in source:
        value = source[key]
        if isinstance(value, str):
            target[key] += value
        elif isinstance(value, dict) and value:
            merge_fields(target[key], value)
        elif isinstance(value, dict):
            merge_fields(target[key], value)
def merge_chunk(final_response: dict, delta: dict) -> None:
    """Apply one streaming *delta* onto the accumulated *final_response*.

    Drops the role field, merges plain fields, then routes any tool-call
    fragment into the accumulator slot named by its index.
    """
    delta.pop("role", None)
    merge_fields(final_response, delta)
    fragments = delta.get("tool_calls")
    if fragments:
        slot = fragments[0].pop("index")
        merge_fields(final_response["tool_calls"][slot], fragments[0])
def function_to_json(func) -> dict:
    """
    Converts a Python function into a JSON-serializable dictionary
    that describes the function's signature, including its name,
    description, and parameters.

    Unannotated (or unrecognized) parameter types default to "string";
    parameters without a default value are listed as required.

    Args:
        func: The function to be converted.

    Returns:
        A dictionary representing the function's signature in JSON format.

    Raises:
        ValueError: If the function's signature cannot be introspected.
    """
    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
        type(None): "null",
    }
    try:
        signature = inspect.signature(func)
    except ValueError as e:
        raise ValueError(
            f"Failed to get signature for function {func.__name__}: {str(e)}"
        )
    # dict.get never raises KeyError, so the original try/except around the
    # lookup was dead code; unknown annotations simply fall back to "string".
    parameters = {}
    for param in signature.parameters.values():
        parameters[param.name] = {"type": type_map.get(param.annotation, "string")}
    required = [
        param.name
        for param in signature.parameters.values()
        # identity check against the public sentinel, not `== inspect._empty`
        if param.default is inspect.Parameter.empty
    ]
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func.__doc__ or "",
            "parameters": {
                "type": "object",
                "properties": parameters,
                "required": required,
            },
        },
    }
def get_current_turn_messages(messages, only_user=False):
    """Return the messages flagged as belonging to the current turn.

    With only_user=True, the result is restricted to user-role messages.
    """
    current = (m for m in messages if m.get("current_turn"))
    if only_user:
        return [m for m in current if m.get("role") == "user"]
    return list(current)
def arrange_messages_keys_in_order(messages):
    """Return copies of *messages* with their keys reordered for readability.

    Keys appear in the fixed order role, sender, content, created_at (when
    present), followed by all remaining keys alphabetically. (The previous
    docstring listed keys — id, relevant_agents, timestamp — that were never
    actually special-cased.)
    """
    key_order = ['role', 'sender', 'content', 'created_at']

    def sort_keys(message):
        # Fixed-priority keys first, in declared order.
        ordered = {key: message[key] for key in key_order if key in message}
        # Then everything else, alphabetically.
        for key in sorted(message):
            if key not in key_order:
                ordered[key] = message[key]
        return ordered

    return [sort_keys(message) for message in messages]
def remove_irrelevant_messages(messages):
    """Drop the latest user message and everything after it.

    Scans backwards for the most recent user-role message and returns the
    prefix strictly before it; if there is no user message, the list is
    returned unchanged.
    """
    idx = len(messages) - 1
    while idx >= 0:
        if messages[idx].get("role") == "user":
            return messages[:idx]
        idx -= 1
    return messages
def update_histories(active_agent, message):
    """Append *message* to the agent's local history and return the agent."""
    history = active_agent.history
    history.append(message)
    return active_agent
def remove_none_fields(message):
    """Return a copy of *message* without the keys whose value is None."""
    cleaned = {}
    for key, value in message.items():
        if value is not None:
            cleaned[key] = value
    return cleaned
def add_message_metadata(message, active_agent):
    """Stamp a message with bookkeeping fields before it enters a history.

    Strips None-valued fields, records an ISO-format creation time, marks
    the message as part of the current turn, and labels it "external" or
    "internal" depending on whether the agent replies to the user directly.
    """
    message = remove_none_fields(message)
    message["created_at"] = datetime.now().isoformat()
    message["current_turn"] = True
    kind = "external" if active_agent.respond_to_user else "internal"
    message["response_type"] = kind
    return message
def check_and_remove_repeat_tool_call_to_child(agent, messages):
    """Strip a child hand-off option when that child just handed control back.

    If the most recent assistant message sent by one of this agent's children
    carried a tool call invoking the child's parent-transfer function, that
    child is removed from this agent's children/children_names/child_functions
    so the model cannot immediately bounce control back to it. Returns the
    (possibly trimmed) agent; Swarm.run backs up and restores the child wiring
    around this call.
    """
    # If in the current turn, the most recent assistant message (need not be the last message overall, just needs to be the last message with role as assistant) is a tool call from a child agent, which transfers control to the agent using its parent function, then remove the tool call to transfer to that child again from this agent. This is to prevent back and forth between this agent and the child agent.
    for message in reversed(messages):
        if message.get("role") == "assistant" and message.get("sender") in agent.children_names and message.get("tool_calls"):
            # Only the first tool call of the message is inspected.
            tool_call = message.get("tool_calls")[0]
            child_agent = agent.children.get(message.get("sender"), None)
            if not child_agent:
                continue
            child_agent_name = child_agent.name
            # NOTE(review): types.py declares parent_function as a callable,
            # so comparing the tool-call *name* (a string) against it looks
            # like it can never be true unless parent_function is actually
            # stored as a name string here — confirm; a `.__name__` may be
            # missing on the right-hand side.
            if tool_call.get("function").get("name") == child_agent.parent_function:
                agent.children_names.remove(child_agent_name)
                agent.children.pop(child_agent_name)
                agent.child_functions.pop(child_agent_name)
                break
    return agent
def update_tokens_used(provider, model, tokens_used, completion):
    """Accumulate prompt/completion token counts under a provider/model key.

    Mutates and returns *tokens_used*, keyed by "provider/model", keeping
    running input_tokens / output_tokens totals taken from the completion's
    usage block.
    """
    key = f"{provider}/{model}"
    bucket = tokens_used.setdefault(key, {'input_tokens': 0, 'output_tokens': 0})
    bucket['input_tokens'] += completion.usage.prompt_tokens
    bucket['output_tokens'] += completion.usage.completion_tokens
    return tokens_used