diff --git a/apps/rowboat_agents/configs/default_config.json b/apps/rowboat_agents/configs/default_config.json index 8f6dba83..c0866886 100644 --- a/apps/rowboat_agents/configs/default_config.json +++ b/apps/rowboat_agents/configs/default_config.json @@ -7,7 +7,5 @@ "max_messages_per_turn": 20, "max_messages_per_error_escalation_turn": 15, "escalate_errors": true, - "max_overall_turns": 25, - "max_calls_per_child_agent": 1, - "enable_tracing": false + "max_overall_turns": 25 } \ No newline at end of file diff --git a/apps/rowboat_agents/src/app/main.py b/apps/rowboat_agents/src/app/main.py index f16553a5..5d467836 100644 --- a/apps/rowboat_agents/src/app/main.py +++ b/apps/rowboat_agents/src/app/main.py @@ -16,6 +16,10 @@ app = Quart(__name__) master_config = read_json_from_file("./configs/default_config.json") print("Master config:", master_config) +# Get environment variables with defaults +MAX_CALLS_PER_CHILD_AGENT = int(os.environ.get('MAX_CALLS_PER_CHILD_AGENT', '1')) +ENABLE_TRACING = os.environ.get('ENABLE_TRACING', 'false').lower() == 'true' + # filter out agent transfer messages using a function def is_agent_transfer_message(msg): if (msg.get("role") == "assistant" and @@ -63,9 +67,6 @@ async def chat(): request_data = await request.get_json() print("Request:", json.dumps(request_data)) - # Add enable_tracing from master_config to request_data - request_data["enable_tracing"] = master_config.get("enable_tracing", False) - # filter out agent transfer messages input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)] @@ -93,10 +94,11 @@ async def chat(): tool_configs=data.get("tools", []), prompt_configs=data.get("prompts", []), start_turn_with_start_agent=master_config.get("start_turn_with_start_agent", False), - max_calls_per_child_agent=master_config.get("max_calls_per_child_agent", 1), + max_calls_per_child_agent=MAX_CALLS_PER_CHILD_AGENT, state=data.get("state", {}), additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL], - complete_request=data + complete_request=data, + enable_tracing=ENABLE_TRACING ): if event_type == 'message': messages.append(event_data) @@ -135,9 +137,6 @@ async def chat_stream(): print("Request:", request_data.decode('utf-8')) request_data = json.loads(request_data) - # Add enable_tracing from master_config to request_data - request_data["enable_tracing"] = master_config.get("enable_tracing", False) - # filter out agent transfer messages input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)] @@ -164,10 +163,11 @@ async def chat_stream(): tool_configs=request_data.get("tools", []), prompt_configs=request_data.get("prompts", []), start_turn_with_start_agent=master_config.get("start_turn_with_start_agent", False), - max_calls_per_child_agent=master_config.get("max_calls_per_child_agent", 1), + max_calls_per_child_agent=MAX_CALLS_PER_CHILD_AGENT, state=request_data.get("state", {}), additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL], - complete_request=request_data + complete_request=request_data, + enable_tracing=ENABLE_TRACING ): if event_type == 'message': yield format_sse(event_data, "message") diff --git a/apps/rowboat_agents/src/graph/core.py b/apps/rowboat_agents/src/graph/core.py index 5997d8e0..8d4d8285 100644 --- a/apps/rowboat_agents/src/graph/core.py +++ b/apps/rowboat_agents/src/graph/core.py @@ -192,9 +192,6 @@ async def run_turn_streamed( async for event in stream_result.stream_events(): try: - print('-'*100) - print(f"Event: {event}") - print('-'*100) # Handle web search events if event.type == "raw_response_event": web_search_messages = handle_web_search_event(event, current_agent) @@ -211,7 +208,6 @@ async def run_turn_streamed( # Handle agent transfer elif event.type == "agent_updated_stream_event": - # print(f"\nAgent transfer attempt: {current_agent.name} -> {event.new_agent.name}") # Skip self-transfers if current_agent.name == event.new_agent.name: diff --git a/apps/rowboat_agents/src/graph/execute_turn.py b/apps/rowboat_agents/src/graph/execute_turn.py index ce044701..29d6a05c 100644 --- a/apps/rowboat_agents/src/graph/execute_turn.py +++ b/apps/rowboat_agents/src/graph/execute_turn.py @@ -281,7 +281,7 @@ async def run_streamed( messages, external_tools=None, tokens_used=None, - enable_tracing=False # Changed default to False + enable_tracing=False ): """ Wrapper function for initializing and running the Swarm client in streaming mode. @@ -319,25 +319,26 @@ async def run_streamed( add_trace_processor(trace_processor) trace_processor_added = True - # Create a trace context only if tracing is enabled - trace_ctx = None - if enable_tracing: - trace_ctx = trace(f"Agent turn: {agent.name}") - trace_ctx.__enter__() - - # Get the stream result + # Get the stream result without trace context first stream_result = Runner.run_streamed(agent, formatted_messages) - # Patch the stream_events method to ensure trace context is maintained if tracing is enabled + # If tracing is enabled, wrap the stream_events to handle tracing if enable_tracing: original_stream_events = stream_result.stream_events + async def wrapped_stream_events(): - try: - async for event in original_stream_events(): - yield event - finally: - if trace_ctx: - trace_ctx.__exit__(None, None, None) + # Create trace context inside the async function + with trace(f"Agent turn: {agent.name}") as trace_ctx: + try: + async for event in original_stream_events(): + yield event + except GeneratorExit: + # Handle generator exit gracefully + raise + except Exception as e: + print(f"Error in stream events: {str(e)}") + raise + stream_result.stream_events = wrapped_stream_events return stream_result diff --git a/docker-compose.yml b/docker-compose.yml index 4db22aee..a09d8bb7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -56,6 +56,8 @@ services: - PROVIDER_BASE_URL=${PROVIDER_BASE_URL} - PROVIDER_API_KEY=${PROVIDER_API_KEY} - PROVIDER_DEFAULT_MODEL=${PROVIDER_DEFAULT_MODEL} + - MAX_CALLS_PER_CHILD_AGENT=${MAX_CALLS_PER_CHILD_AGENT} + - ENABLE_TRACING=${ENABLE_TRACING} restart: unless-stopped copilot: