mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-27 17:36:25 +02:00
Move system message additions outside swarm
This commit is contained in:
parent
c1db2b7306
commit
27fd1a069a
3 changed files with 15 additions and 11 deletions
|
|
@ -11,7 +11,7 @@ from .types import AgentRole, PromptType, ErrorType
|
|||
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
|
||||
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
|
||||
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
|
||||
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message
|
||||
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message, add_universal_system_message_to_agent
|
||||
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
|
||||
|
||||
from src.utils.common import common_logger
|
||||
|
|
@ -50,7 +50,7 @@ def clear_agent_fields(agent):
|
|||
|
||||
return agent
|
||||
|
||||
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent):
|
||||
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent, universal_sys_msg):
|
||||
# Create Agent objects
|
||||
agents = []
|
||||
|
||||
|
|
@ -183,6 +183,9 @@ def get_agents(agent_configs, tool_configs, localize_history, available_tool_map
|
|||
if agent.most_recent_parent:
|
||||
assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}"
|
||||
agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name]
|
||||
|
||||
for agent in agents:
|
||||
agent = add_universal_system_message_to_agent(agent, universal_sys_msg)
|
||||
|
||||
return agents
|
||||
|
||||
|
|
@ -342,6 +345,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
|
|||
localize_history=localize_history,
|
||||
start_turn_with_start_agent=start_turn_with_start_agent,
|
||||
children_aware_of_parent=children_aware_of_parent,
|
||||
universal_sys_msg=universal_sys_msg
|
||||
)
|
||||
if not all_agents:
|
||||
logger.error("No agents initialized")
|
||||
|
|
@ -396,8 +400,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
|
|||
localize_history=localize_history,
|
||||
parent_has_child_history=parent_has_child_history,
|
||||
max_messages_per_turn=max_messages_per_turn,
|
||||
tokens_used=tokens_used,
|
||||
universal_sys_msg=universal_sys_msg
|
||||
tokens_used=tokens_used
|
||||
)
|
||||
tokens_used = response.tokens_used
|
||||
last_agent = response.agent
|
||||
|
|
|
|||
|
|
@ -32,4 +32,8 @@ def add_error_escalation_instructions(agent):
|
|||
def get_universal_system_message(messages):
|
||||
if messages and messages[0].get("role") == "system":
|
||||
return SYSTEM_MESSAGE.format(system_message=messages[0].get("content"))
|
||||
return ""
|
||||
return ""
|
||||
|
||||
def add_universal_system_message_to_agent(agent, universal_sys_msg):
|
||||
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + universal_sys_msg
|
||||
return agent
|
||||
|
|
@ -39,7 +39,6 @@ class Swarm:
|
|||
model_override: str,
|
||||
stream: bool,
|
||||
debug: bool,
|
||||
universal_sys_msg: str,
|
||||
) -> ChatCompletionMessage:
|
||||
context_variables = defaultdict(str, context_variables)
|
||||
instructions = (
|
||||
|
|
@ -47,7 +46,7 @@ class Swarm:
|
|||
if callable(agent.instructions)
|
||||
else agent.instructions
|
||||
)
|
||||
messages = [{"role": "system", "content": instructions + universal_sys_msg}] + history
|
||||
messages = [{"role": "system", "content": instructions}] + history
|
||||
debug_print(debug, "Getting chat completion for...:", messages)
|
||||
|
||||
all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else [])
|
||||
|
|
@ -156,8 +155,7 @@ class Swarm:
|
|||
external_tools: List[str] = [],
|
||||
localize_history: bool = True,
|
||||
parent_has_child_history: bool = True,
|
||||
tokens_used: dict = {},
|
||||
universal_sys_msg: str = '',
|
||||
tokens_used: dict = {}
|
||||
) -> Response:
|
||||
|
||||
active_agent = agent
|
||||
|
|
@ -182,8 +180,7 @@ class Swarm:
|
|||
context_variables=context_variables,
|
||||
model_override=model_override,
|
||||
stream=stream,
|
||||
debug=debug,
|
||||
universal_sys_msg=universal_sys_msg,
|
||||
debug=debug
|
||||
)
|
||||
tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue