diff --git a/apps/agents/src/graph/core.py b/apps/agents/src/graph/core.py index 35db9fa9..f19c14ae 100644 --- a/apps/agents/src/graph/core.py +++ b/apps/agents/src/graph/core.py @@ -11,7 +11,7 @@ from .types import AgentRole, PromptType, ErrorType from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history -from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message +from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message, add_universal_system_message_to_agent from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name from src.utils.common import common_logger @@ -50,7 +50,7 @@ def clear_agent_fields(agent): return agent -def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent): +def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent, universal_sys_msg): # Create Agent objects agents = [] @@ -183,6 +183,9 @@ def get_agents(agent_configs, tool_configs, localize_history, available_tool_map if agent.most_recent_parent: assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}" agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name] + + for agent in agents: + agent = add_universal_system_message_to_agent(agent, universal_sys_msg) return agents @@ -342,6 +345,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_ localize_history=localize_history, start_turn_with_start_agent=start_turn_with_start_agent, children_aware_of_parent=children_aware_of_parent, + universal_sys_msg=universal_sys_msg ) if not all_agents: logger.error("No agents initialized") @@ -396,8 +400,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_ localize_history=localize_history, parent_has_child_history=parent_has_child_history, max_messages_per_turn=max_messages_per_turn, - tokens_used=tokens_used, - universal_sys_msg=universal_sys_msg + tokens_used=tokens_used ) tokens_used = response.tokens_used last_agent = response.agent diff --git a/apps/agents/src/graph/helpers/instructions.py b/apps/agents/src/graph/helpers/instructions.py index 8fc33959..a6e1d6ff 100644 --- a/apps/agents/src/graph/helpers/instructions.py +++ b/apps/agents/src/graph/helpers/instructions.py @@ -32,4 +32,8 @@ def add_error_escalation_instructions(agent): def get_universal_system_message(messages): if messages and messages[0].get("role") == "system": return SYSTEM_MESSAGE.format(system_message=messages[0].get("content")) - return "" \ No newline at end of file + return "" + +def add_universal_system_message_to_agent(agent, universal_sys_msg): + agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + universal_sys_msg + return agent \ No newline at end of file diff --git a/apps/agents/src/swarm/core.py b/apps/agents/src/swarm/core.py index c93334d1..9150d9c1 100644 --- a/apps/agents/src/swarm/core.py +++ b/apps/agents/src/swarm/core.py @@ -39,7 +39,6 @@ class Swarm: model_override: str, stream: bool, debug: bool, - universal_sys_msg: str, ) -> ChatCompletionMessage: context_variables = defaultdict(str, context_variables) instructions = ( @@ -47,7 +46,7 @@ class Swarm: if callable(agent.instructions) else agent.instructions ) - messages = [{"role": "system", "content": instructions + universal_sys_msg}] + history + messages = [{"role": "system", "content": instructions}] + history debug_print(debug, "Getting chat completion for...:", messages) all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else []) @@ -156,8 +155,7 @@ class Swarm: external_tools: List[str] = [], localize_history: bool = True, parent_has_child_history: bool = True, - tokens_used: dict = {}, - universal_sys_msg: str = '', + tokens_used: dict = {} ) -> Response: active_agent = agent @@ -182,8 +180,7 @@ class Swarm: context_variables=context_variables, model_override=model_override, stream=stream, - debug=debug, - universal_sys_msg=universal_sys_msg, + debug=debug ) tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)