mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-29 10:26:23 +02:00
Add agents with custom swarm implementation
This commit is contained in:
parent
24c4f6e552
commit
a19dedd59f
35 changed files with 3413 additions and 0 deletions
4
apps/agents/src/swarm/__init__.py
Normal file
4
apps/agents/src/swarm/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .core import Swarm
|
||||
from .types import Agent, Response
|
||||
|
||||
__all__ = ["Swarm", "Agent", "Response"]
|
||||
269
apps/agents/src/swarm/core.py
Normal file
269
apps/agents/src/swarm/core.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
# Standard library imports
|
||||
import copy
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import List, Callable, Union
|
||||
from datetime import datetime
|
||||
|
||||
# Package/library imports
|
||||
from openai import OpenAI
|
||||
import random
|
||||
|
||||
# Local imports
|
||||
from .util import *
|
||||
from .types import (
|
||||
Agent,
|
||||
AgentFunction,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
Response,
|
||||
Result,
|
||||
)
|
||||
|
||||
__CTX_VARS_NAME__ = "context_variables"
|
||||
|
||||
|
||||
class Swarm:
|
||||
def __init__(self, client=None):
|
||||
if not client:
|
||||
client = OpenAI(api_key=OPENAI_API_KEY)
|
||||
self.client = client
|
||||
self.history = defaultdict(lambda : [])
|
||||
|
||||
def get_chat_completion(
|
||||
self,
|
||||
agent: Agent,
|
||||
history: List,
|
||||
context_variables: dict,
|
||||
model_override: str,
|
||||
stream: bool,
|
||||
debug: bool,
|
||||
) -> ChatCompletionMessage:
|
||||
context_variables = defaultdict(str, context_variables)
|
||||
instructions = (
|
||||
agent.instructions(context_variables)
|
||||
if callable(agent.instructions)
|
||||
else agent.instructions
|
||||
)
|
||||
messages = [{"role": "system", "content": instructions}] + history
|
||||
debug_print(debug, "Getting chat completion for...:", messages)
|
||||
|
||||
all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else [])
|
||||
all_tools = agent.external_tools + agent.internal_tools
|
||||
funcs_and_tools = [function_to_json(f) for f in all_functions] + [t for t in all_tools]
|
||||
# hide context_variables from model
|
||||
for tool in funcs_and_tools:
|
||||
params = tool["function"]["parameters"]
|
||||
params["properties"].pop(__CTX_VARS_NAME__, None)
|
||||
if __CTX_VARS_NAME__ in params["required"]:
|
||||
params["required"].remove(__CTX_VARS_NAME__)
|
||||
|
||||
create_params = {
|
||||
"model": model_override or agent.model,
|
||||
"messages": messages,
|
||||
"tools": funcs_and_tools or None,
|
||||
"tool_choice": agent.tool_choice,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if funcs_and_tools:
|
||||
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
|
||||
|
||||
return self.client.chat.completions.create(**create_params)
|
||||
|
||||
def handle_function_result(self, result, debug) -> Result:
|
||||
match result:
|
||||
case Result() as result:
|
||||
return result
|
||||
|
||||
case Agent() as agent:
|
||||
return Result(
|
||||
value=json.dumps({"assistant": agent.name}),
|
||||
agent=agent,
|
||||
)
|
||||
case _:
|
||||
try:
|
||||
return Result(value=str(result))
|
||||
except Exception as e:
|
||||
error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
|
||||
debug_print(debug, error_message)
|
||||
raise TypeError(error_message)
|
||||
|
||||
def handle_function_calls(
|
||||
self,
|
||||
tool_calls: List[ChatCompletionMessageToolCall],
|
||||
functions: List[AgentFunction],
|
||||
context_variables: dict,
|
||||
debug: bool,
|
||||
) -> Response:
|
||||
function_map = {f.__name__: f for f in functions}
|
||||
partial_response = Response(
|
||||
messages=[], agent=None, context_variables={})
|
||||
|
||||
for tool_call in tool_calls:
|
||||
name = tool_call.function.name
|
||||
# handle missing tool case, skip to next tool
|
||||
if name not in function_map:
|
||||
debug_print(debug, f"Tool {name} not found in function map.")
|
||||
partial_response.messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"tool_name": name,
|
||||
"content": f"Error: Tool {name} not found.",
|
||||
}
|
||||
)
|
||||
continue
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
debug_print(
|
||||
debug, f"Processing tool call: {name} with arguments {args}")
|
||||
|
||||
func = function_map[name]
|
||||
# pass context_variables to agent functions
|
||||
if __CTX_VARS_NAME__ in func.__code__.co_varnames:
|
||||
args[__CTX_VARS_NAME__] = context_variables
|
||||
raw_result = function_map[name](**args)
|
||||
|
||||
result: Result = self.handle_function_result(raw_result, debug)
|
||||
partial_response.messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"tool_name": name,
|
||||
"content": result.value,
|
||||
}
|
||||
)
|
||||
partial_response.context_variables.update(result.context_variables)
|
||||
if result.agent:
|
||||
partial_response.agent = result.agent
|
||||
|
||||
return partial_response
|
||||
|
||||
def run(
|
||||
self,
|
||||
agent: Agent,
|
||||
messages: List,
|
||||
context_variables: dict = {},
|
||||
model_override: str = None,
|
||||
stream: bool = False,
|
||||
debug: bool = False,
|
||||
max_messages_per_turn: int = 10,
|
||||
execute_tools: bool = True,
|
||||
external_tools: List[str] = [],
|
||||
localize_history: bool = True,
|
||||
parent_has_child_history: bool = True,
|
||||
tokens_used: dict = {}
|
||||
) -> Response:
|
||||
|
||||
active_agent = agent
|
||||
context_variables = copy.deepcopy(context_variables)
|
||||
global_history = copy.deepcopy(messages)
|
||||
init_len = len(messages)
|
||||
|
||||
while len(global_history) - init_len < max_messages_per_turn and active_agent:
|
||||
history = active_agent.history if localize_history else global_history
|
||||
history = arrange_messages_keys_in_order(history)
|
||||
|
||||
parent = active_agent.most_recent_parent
|
||||
|
||||
children_names_backup, children_backup, child_functions_backup = copy.deepcopy(active_agent.children_names), copy.deepcopy(active_agent.children), copy.deepcopy(active_agent.child_functions)
|
||||
|
||||
active_agent = check_and_remove_repeat_tool_call_to_child(active_agent, history)
|
||||
|
||||
# get completion with current history, agent
|
||||
completion = self.get_chat_completion(
|
||||
agent=active_agent,
|
||||
history=history,
|
||||
context_variables=context_variables,
|
||||
model_override=model_override,
|
||||
stream=stream,
|
||||
debug=debug,
|
||||
)
|
||||
tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)
|
||||
|
||||
# Restore children and child functions
|
||||
active_agent.children_names, active_agent.children, active_agent.child_functions = children_names_backup, children_backup, child_functions_backup
|
||||
|
||||
message = completion.choices[0].message
|
||||
debug_print(debug, "Received completion:", message)
|
||||
message.sender = active_agent.name
|
||||
message_json = json.loads(message.model_dump_json())
|
||||
message_json = add_message_metadata(message_json, active_agent)
|
||||
|
||||
if localize_history:
|
||||
active_agent = update_histories(active_agent, message_json)
|
||||
if parent and parent_has_child_history:
|
||||
parent = update_histories(parent, message_json)
|
||||
global_history.append(message_json)
|
||||
|
||||
external_tool_calls = []
|
||||
internal_tool_calls = []
|
||||
|
||||
if message.tool_calls:
|
||||
message_json["response_type"] = "internal"
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
if tool_name in external_tools:
|
||||
external_tool_calls.append(tool_call)
|
||||
else:
|
||||
internal_tool_calls.append(tool_call)
|
||||
message.tool_calls = internal_tool_calls
|
||||
|
||||
if not message.tool_calls or not execute_tools:
|
||||
if external_tool_calls:
|
||||
message.tool_calls.extend(external_tool_calls)
|
||||
debug_print(debug, "Ending turn.")
|
||||
break
|
||||
|
||||
# handle function calls, updating context_variables, and switching agents
|
||||
all_functions = list(active_agent.child_functions.values()) + ([active_agent.parent_function] if active_agent.parent_function else [])
|
||||
partial_response = self.handle_function_calls(
|
||||
message.tool_calls, all_functions, context_variables, debug
|
||||
)
|
||||
for msg in partial_response.messages:
|
||||
msg = add_message_metadata(msg, active_agent)
|
||||
if localize_history:
|
||||
active_agent = update_histories(active_agent, msg)
|
||||
if parent and parent_has_child_history:
|
||||
parent = update_histories(parent, msg)
|
||||
|
||||
global_history.extend(partial_response.messages)
|
||||
context_variables.update(partial_response.context_variables)
|
||||
|
||||
# Parent to child transfer
|
||||
if partial_response.agent:
|
||||
prev_agent = active_agent
|
||||
active_agent = partial_response.agent
|
||||
|
||||
# Parent to child transfer
|
||||
if active_agent.name in prev_agent.children_names:
|
||||
active_agent.most_recent_parent = prev_agent
|
||||
active_agent.parent_function = active_agent.candidate_parent_functions[active_agent.most_recent_parent.name]
|
||||
if localize_history:
|
||||
if not parent_has_child_history:
|
||||
prev_agent.history = remove_irrelevant_messages(prev_agent.history)
|
||||
new_active_agent_history = get_current_turn_messages(global_history, only_user = True)
|
||||
active_agent.history.extend(new_active_agent_history)
|
||||
|
||||
# Child to parent transfer
|
||||
else:
|
||||
assert parent == active_agent, "Parent and active agent do not match when active agent is not a child of previous agent"
|
||||
child = prev_agent
|
||||
if localize_history:
|
||||
child.history = remove_irrelevant_messages(child.history)
|
||||
|
||||
|
||||
return_messages = global_history[init_len:]
|
||||
error_msg = ""
|
||||
|
||||
if len(global_history) - init_len >= max_messages_per_turn:
|
||||
error_msg = "Max messages per turn reached"
|
||||
|
||||
return Response(
|
||||
messages=return_messages,
|
||||
agent=active_agent,
|
||||
context_variables=context_variables,
|
||||
error_msg=error_msg,
|
||||
tokens_used=tokens_used
|
||||
)
|
||||
1
apps/agents/src/swarm/repl/__init__.py
Normal file
1
apps/agents/src/swarm/repl/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .repl import run_demo_loop
|
||||
87
apps/agents/src/swarm/repl/repl.py
Normal file
87
apps/agents/src/swarm/repl/repl.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import json
|
||||
|
||||
from swarm import Swarm
|
||||
|
||||
|
||||
def process_and_print_streaming_response(response):
|
||||
content = ""
|
||||
last_sender = ""
|
||||
|
||||
for chunk in response:
|
||||
if "sender" in chunk:
|
||||
last_sender = chunk["sender"]
|
||||
|
||||
if "content" in chunk and chunk["content"] is not None:
|
||||
if not content and last_sender:
|
||||
print(f"\033[94m{last_sender}:\033[0m", end=" ", flush=True)
|
||||
last_sender = ""
|
||||
print(chunk["content"], end="", flush=True)
|
||||
content += chunk["content"]
|
||||
|
||||
if "tool_calls" in chunk and chunk["tool_calls"] is not None:
|
||||
for tool_call in chunk["tool_calls"]:
|
||||
f = tool_call["function"]
|
||||
name = f["name"]
|
||||
if not name:
|
||||
continue
|
||||
print(f"\033[94m{last_sender}: \033[95m{name}\033[0m()")
|
||||
|
||||
if "delim" in chunk and chunk["delim"] == "end" and content:
|
||||
print() # End of response message
|
||||
content = ""
|
||||
|
||||
if "response" in chunk:
|
||||
return chunk["response"]
|
||||
|
||||
|
||||
def pretty_print_messages(messages) -> None:
|
||||
for message in messages:
|
||||
if message["role"] != "assistant":
|
||||
continue
|
||||
|
||||
# print agent name in blue
|
||||
print(f"\033[94m{message['sender']}\033[0m:", end=" ")
|
||||
|
||||
# print response, if any
|
||||
if message["content"]:
|
||||
print(message["content"])
|
||||
|
||||
# print tool calls in purple, if any
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if len(tool_calls) > 1:
|
||||
print()
|
||||
for tool_call in tool_calls:
|
||||
f = tool_call["function"]
|
||||
name, args = f["name"], f["arguments"]
|
||||
arg_str = json.dumps(json.loads(args)).replace(":", "=")
|
||||
print(f"\033[95m{name}\033[0m({arg_str[1:-1]})")
|
||||
|
||||
|
||||
def run_demo_loop(
|
||||
starting_agent, context_variables=None, stream=False, debug=False
|
||||
) -> None:
|
||||
client = Swarm()
|
||||
print("Starting Swarm CLI 🐝")
|
||||
|
||||
messages = []
|
||||
agent = starting_agent
|
||||
|
||||
while True:
|
||||
user_input = input("\033[90mUser\033[0m: ")
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
response = client.run(
|
||||
agent=agent,
|
||||
messages=messages,
|
||||
context_variables=context_variables or {},
|
||||
stream=stream,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
if stream:
|
||||
response = process_and_print_streaming_response(response)
|
||||
else:
|
||||
pretty_print_messages(response.messages)
|
||||
|
||||
messages.extend(response.messages)
|
||||
agent = response.agent
|
||||
54
apps/agents/src/swarm/types.py
Normal file
54
apps/agents/src/swarm/types.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from openai.types.chat import ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
from typing import List, Callable, Union, Optional, Dict
|
||||
|
||||
# Third-party imports
|
||||
from pydantic import BaseModel
|
||||
|
||||
AgentFunction = Callable[[], Union[str, "Agent", dict]]
|
||||
|
||||
class Agent(BaseModel):
|
||||
name: str = "Agent"
|
||||
model: str = "gpt-4o"
|
||||
type: str = ""
|
||||
instructions: Union[str, Callable[[], str]] = "You are a helpful agent.",
|
||||
description: str = "This is a helpful agent."
|
||||
candidate_parent_functions: Dict[str, AgentFunction] = {}
|
||||
parent_function: AgentFunction = None
|
||||
child_functions: Dict[str, AgentFunction] = {}
|
||||
internal_tools: List[Dict] = []
|
||||
external_tools: List[Dict] = []
|
||||
tool_choice: str = None
|
||||
parallel_tool_calls: bool = True
|
||||
respond_to_user: bool = True
|
||||
history: List[Dict] = []
|
||||
children_names: List[str] = []
|
||||
children: Dict[str, "Agent"] = {}
|
||||
most_recent_parent: Optional["Agent"] = None
|
||||
parent: "Agent" = None
|
||||
|
||||
class Response(BaseModel):
|
||||
messages: List = []
|
||||
agent: Optional[Agent] = None
|
||||
context_variables: dict = {}
|
||||
error_msg: Optional[str] = ""
|
||||
tokens_used: dict = {}
|
||||
|
||||
class Result(BaseModel):
|
||||
"""
|
||||
Encapsulates the possible return values for an agent function.
|
||||
|
||||
Attributes:
|
||||
value (str): The result value as a string.
|
||||
agent (Agent): The agent instance, if applicable.
|
||||
context_variables (dict): A dictionary of context variables.
|
||||
"""
|
||||
|
||||
value: str = ""
|
||||
agent: Optional[Agent] = None
|
||||
context_variables: dict = {}
|
||||
175
apps/agents/src/swarm/util.py
Normal file
175
apps/agents/src/swarm/util.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
import inspect
|
||||
import json
|
||||
from datetime import datetime
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from src.utils.common import read_json_from_file, get_api_key
|
||||
|
||||
load_dotenv()
|
||||
OPENAI_API_KEY = get_api_key("OPENAI_API_KEY")
|
||||
|
||||
def debug_print(debug: bool, *args: str) -> None:
|
||||
if not debug:
|
||||
return
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
message = " ".join(map(str, args))
|
||||
print(f"\033[97m[\033[90m{timestamp}\033[97m]\033[90m {message}\033[0m")
|
||||
|
||||
def merge_fields(target, source):
|
||||
for key, value in source.items():
|
||||
if isinstance(value, str):
|
||||
target[key] += value
|
||||
elif value is not None and isinstance(value, dict):
|
||||
merge_fields(target[key], value)
|
||||
|
||||
|
||||
def merge_chunk(final_response: dict, delta: dict) -> None:
|
||||
delta.pop("role", None)
|
||||
merge_fields(final_response, delta)
|
||||
|
||||
tool_calls = delta.get("tool_calls")
|
||||
if tool_calls and len(tool_calls) > 0:
|
||||
index = tool_calls[0].pop("index")
|
||||
merge_fields(final_response["tool_calls"][index], tool_calls[0])
|
||||
|
||||
|
||||
def function_to_json(func) -> dict:
|
||||
"""
|
||||
Converts a Python function into a JSON-serializable dictionary
|
||||
that describes the function's signature, including its name,
|
||||
description, and parameters.
|
||||
|
||||
Args:
|
||||
func: The function to be converted.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the function's signature in JSON format.
|
||||
"""
|
||||
type_map = {
|
||||
str: "string",
|
||||
int: "integer",
|
||||
float: "number",
|
||||
bool: "boolean",
|
||||
list: "array",
|
||||
dict: "object",
|
||||
type(None): "null",
|
||||
}
|
||||
|
||||
try:
|
||||
signature = inspect.signature(func)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Failed to get signature for function {func.__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
parameters = {}
|
||||
for param in signature.parameters.values():
|
||||
try:
|
||||
param_type = type_map.get(param.annotation, "string")
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
f"Unknown type annotation {param.annotation} for parameter {param.name}: {str(e)}"
|
||||
)
|
||||
parameters[param.name] = {"type": param_type}
|
||||
|
||||
required = [
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.default == inspect._empty
|
||||
]
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func.__name__,
|
||||
"description": func.__doc__ or "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_current_turn_messages(messages, only_user = False):
|
||||
if only_user:
|
||||
return [msg for msg in messages if msg.get("current_turn") and msg.get("role") == "user"]
|
||||
else:
|
||||
return [msg for msg in messages if msg.get("current_turn")]
|
||||
|
||||
def arrange_messages_keys_in_order(messages):
|
||||
"""Arranges message keys in a specific order: id, role, sender, relevant_agents, content, created_at, timestamp, followed by rest alphabetically"""
|
||||
key_order = ['role', 'sender', 'content', 'created_at']
|
||||
|
||||
def sort_keys(message):
|
||||
# Create new dict with specified key order
|
||||
ordered = {}
|
||||
# Add keys in specified order if they exist
|
||||
for key in key_order:
|
||||
if key in message:
|
||||
ordered[key] = message[key]
|
||||
# Add remaining keys in alphabetical order
|
||||
for key in sorted(message.keys()):
|
||||
if key not in key_order:
|
||||
ordered[key] = message[key]
|
||||
return ordered
|
||||
|
||||
return [sort_keys(message) for message in messages]
|
||||
|
||||
def remove_irrelevant_messages(messages):
|
||||
"""Removes all messages from and including the latest user message"""
|
||||
for i in range(len(messages)-1, -1, -1):
|
||||
if messages[i].get("role") == "user":
|
||||
return messages[:i]
|
||||
return messages
|
||||
|
||||
def update_histories(active_agent, message):
|
||||
active_agent.history.append(message)
|
||||
return active_agent
|
||||
|
||||
def remove_none_fields(message):
|
||||
return {k: v for k, v in message.items() if v is not None}
|
||||
|
||||
def add_message_metadata(message, active_agent):
|
||||
message = remove_none_fields(message)
|
||||
message["created_at"] = datetime.now().isoformat()
|
||||
message["current_turn"] = True
|
||||
|
||||
if active_agent.respond_to_user:
|
||||
message["response_type"] = "external"
|
||||
else:
|
||||
message["response_type"] = "internal"
|
||||
|
||||
return message
|
||||
|
||||
def check_and_remove_repeat_tool_call_to_child(agent, messages):
|
||||
# If in the current turn, the most recent assistant message (need not be the last message overall, just needs to be the last message with role as assistant) is a tool call from a child agent, which transfers control to the agent using its parent function, then remove the tool call to transfer to that child again from this agent. This is to prevent back and forth between this agent and the child agent.
|
||||
for message in reversed(messages):
|
||||
if message.get("role") == "assistant" and message.get("sender") in agent.children_names and message.get("tool_calls"):
|
||||
tool_call = message.get("tool_calls")[0]
|
||||
child_agent = agent.children.get(message.get("sender"), None)
|
||||
if not child_agent:
|
||||
continue
|
||||
child_agent_name = child_agent.name
|
||||
if tool_call.get("function").get("name") == child_agent.parent_function:
|
||||
agent.children_names.remove(child_agent_name)
|
||||
agent.children.pop(child_agent_name)
|
||||
agent.child_functions.pop(child_agent_name)
|
||||
break
|
||||
return agent
|
||||
|
||||
def update_tokens_used(provider, model, tokens_used, completion):
|
||||
provider_model = f"{provider}/{model}"
|
||||
input_tokens = completion.usage.prompt_tokens
|
||||
output_tokens = completion.usage.completion_tokens
|
||||
|
||||
if provider_model not in tokens_used:
|
||||
tokens_used[provider_model] = {
|
||||
'input_tokens': 0,
|
||||
'output_tokens': 0,
|
||||
}
|
||||
|
||||
tokens_used[provider_model]['input_tokens'] += input_tokens
|
||||
tokens_used[provider_model]['output_tokens'] += output_tokens
|
||||
|
||||
return tokens_used
|
||||
Loading…
Add table
Add a link
Reference in a new issue