support streameable HTTP mcp

This commit is contained in:
Ramnique Singh 2025-06-08 16:23:51 +05:30
parent f25e3e2ed4
commit a79667b401
7 changed files with 100 additions and 85 deletions

View file

@ -23,6 +23,7 @@ from typing import Any
import asyncio
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel
from typing import List, Optional, Dict
@ -98,16 +99,34 @@ async def call_webhook(tool_name: str, args: str, webhook_url: str, signing_secr
async def call_mcp(tool_name: str, args: str, mcp_server_url: str) -> str:
try:
print(f"MCP tool called for: {tool_name} with args: {args} at url: {mcp_server_url}")
async with sse_client(url=mcp_server_url) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
jargs = json.loads(args)
response = await session.call_tool(tool_name, arguments=jargs)
json_output = json.dumps(response.content, default=lambda x: x.__dict__ if hasattr(x, '__dict__') else str(x), indent=2)
return json_output
# Try StreamableHTTP first
try:
print("Attempting to connect using StreamableHTTP...")
async with streamablehttp_client(mcp_server_url) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
jargs = json.loads(args)
response = await session.call_tool(tool_name, arguments=jargs)
json_output = json.dumps(response.content, default=lambda x: x.__dict__ if hasattr(x, '__dict__') else str(x), indent=2)
print("Successfully connected using StreamableHTTP")
return json_output
except Exception as streamable_error:
print(f"StreamableHTTP connection failed: {str(streamable_error)}")
print("Falling back to SSE...")
# Fallback to SSE
async with sse_client(url=mcp_server_url) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
jargs = json.loads(args)
response = await session.call_tool(tool_name, arguments=jargs)
json_output = json.dumps(response.content, default=lambda x: x.__dict__ if hasattr(x, '__dict__') else str(x), indent=2)
print("Successfully connected using SSE fallback")
return json_output
except Exception as e:
print(f"Error in call_mcp: {str(e)}")
print(f"Error in call_mcp (both StreamableHTTP and SSE failed): {str(e)}")
return f"Error: {str(e)}"
async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict, complete_request: dict) -> str: