Merge remote-tracking branch 'origin/main' into pr-381

This commit is contained in:
Abhishek Kumar 2026-06-02 12:11:57 +05:30
commit 858c474139
119 changed files with 5057 additions and 1018 deletions

View file

@ -16,6 +16,8 @@ TYPE_MAP = {
"string": "string",
"number": "number",
"boolean": "boolean",
"object": "object",
"array": "array",
}
@ -45,10 +47,24 @@ def tool_to_function_schema(tool: Any) -> Dict[str, Any]:
if not param_name:
continue
properties[param_name] = {
"type": TYPE_MAP.get(param_type, "string"),
"description": param_desc,
}
schema_type = TYPE_MAP.get(param_type, "string")
if schema_type == "object":
properties[param_name] = {
"type": "object",
"additionalProperties": True,
"description": param_desc,
}
elif schema_type == "array":
properties[param_name] = {
"type": "array",
"items": {},
"description": param_desc,
}
else:
properties[param_name] = {
"type": schema_type,
"description": param_desc,
}
if param_required:
required.append(param_name)
@ -127,6 +143,26 @@ def _coerce_parameter_value(value: Any, param_type: str) -> Any:
raise ValueError(f"Cannot convert '{value}' to boolean")
if param_type == "object":
if isinstance(value, str):
try:
value = json.loads(value)
except json.JSONDecodeError as exc:
raise ValueError(f"Cannot convert '{value}' to object") from exc
if isinstance(value, dict):
return value
raise ValueError(f"Cannot convert '{value}' to object")
if param_type == "array":
if isinstance(value, str):
try:
value = json.loads(value)
except json.JSONDecodeError as exc:
raise ValueError(f"Cannot convert '{value}' to array") from exc
if isinstance(value, list):
return value
raise ValueError(f"Cannot convert '{value}' to array")
return value

View file

@ -4,70 +4,27 @@ LLM-function-name namespacing. No I/O, no MCP protocol here."""
from __future__ import annotations
import re
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict
from pydantic import BaseModel, Field, ValidationError, field_validator
from pydantic import ValidationError
DEFAULT_TIMEOUT_SECS = 30
DEFAULT_SSE_READ_TIMEOUT_SECS = 300
from api.schemas.tool import (
DEFAULT_MCP_SSE_READ_TIMEOUT_SECS,
DEFAULT_MCP_TIMEOUT_SECS,
McpToolDefinition,
)
from api.schemas.tool import (
McpToolConfig as McpToolConfig,
)
DEFAULT_TIMEOUT_SECS = DEFAULT_MCP_TIMEOUT_SECS
DEFAULT_SSE_READ_TIMEOUT_SECS = DEFAULT_MCP_SSE_READ_TIMEOUT_SECS
class McpDefinitionError(ValueError):
"""Raised when an MCP tool definition is structurally invalid."""
class McpToolConfig(BaseModel):
"""Configuration for an MCP tool definition."""
transport: Literal["streamable_http"] = Field(
default="streamable_http", description="MCP transport protocol"
)
url: str = Field(description="MCP server URL (must be http:// or https://)")
credential_uuid: Optional[str] = Field(
default=None, description="Reference to ExternalCredentialModel for auth"
)
tools_filter: list[str] = Field(
default_factory=list,
description="Allowlist of MCP tool names to expose (empty = all tools)",
)
timeout_secs: int = Field(
default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds"
)
sse_read_timeout_secs: int = Field(
default=DEFAULT_SSE_READ_TIMEOUT_SECS,
description="SSE read timeout in seconds",
)
discovered_tools: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"Server-managed cache of the MCP server's tool catalog "
"[{name, description}]. Populated best-effort by the backend."
),
)
@field_validator("url")
@classmethod
def validate_url(cls, v: str) -> str:
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
raise ValueError("config.url must be an http(s) URL")
return v
@field_validator("tools_filter")
@classmethod
def validate_tools_filter(cls, v: list[str]) -> list[str]:
if not all(isinstance(tool_name, str) for tool_name in v):
raise ValueError("config.tools_filter must be a list of strings")
return v
class McpToolDefinition(BaseModel):
"""Persisted MCP tool definition."""
schema_version: int = Field(default=1, description="Schema version")
type: Literal["mcp"] = Field(description="Tool type")
config: McpToolConfig = Field(description="MCP server configuration")
def _format_validation_error(error: ValidationError) -> str:
parts: list[str] = []
for item in error.errors():