chore: linting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-01-15 00:05:53 -08:00
parent 3375aeb9bc
commit 7ae68455b3
20 changed files with 128 additions and 103 deletions

View file

@ -24,9 +24,7 @@ def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
),
sa.text("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"),
{"enum_name": enum_name},
)
return result.scalar()

View file

@ -22,9 +22,7 @@ def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
),
sa.text("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"),
{"enum_name": enum_name},
)
return result.scalar()

View file

@ -197,9 +197,7 @@ def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
),
sa.text("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"),
{"enum_name": enum_name},
)
return result.scalar()

View file

@ -5,13 +5,14 @@ Revises: 61
Create Date: 2026-01-09 15:19:51.827647
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '62'
down_revision: str | None = '61'
revision: str = "62"
down_revision: str | None = "61"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

View file

@ -5,6 +5,7 @@ Revises: 62
Create Date: 2026-01-13 12:23:31.481643
"""
from collections.abc import Sequence
from sqlalchemy import text
@ -12,8 +13,8 @@ from sqlalchemy import text
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '63'
down_revision: str | None = '62'
revision: str = "63"
down_revision: str | None = "62"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@ -21,7 +22,7 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
connection = op.get_bind()
# Check if old constraint exists before trying to drop it
old_constraint_exists = connection.execute(
text("""
@ -31,14 +32,14 @@ def upgrade() -> None:
AND constraint_name='uq_searchspace_user_connector_type'
""")
).scalar()
if old_constraint_exists:
op.drop_constraint(
'uq_searchspace_user_connector_type',
'search_source_connectors',
type_='unique'
"uq_searchspace_user_connector_type",
"search_source_connectors",
type_="unique",
)
# Check if new constraint already exists before creating it
new_constraint_exists = connection.execute(
text("""
@ -48,19 +49,19 @@ def upgrade() -> None:
AND constraint_name='uq_searchspace_user_connector_type_name'
""")
).scalar()
if not new_constraint_exists:
op.create_unique_constraint(
'uq_searchspace_user_connector_type_name',
'search_source_connectors',
['search_space_id', 'user_id', 'connector_type', 'name']
"uq_searchspace_user_connector_type_name",
"search_source_connectors",
["search_space_id", "user_id", "connector_type", "name"],
)
def downgrade() -> None:
"""Downgrade schema."""
connection = op.get_bind()
# Check if new constraint exists before trying to drop it
new_constraint_exists = connection.execute(
text("""
@ -70,14 +71,14 @@ def downgrade() -> None:
AND constraint_name='uq_searchspace_user_connector_type_name'
""")
).scalar()
if new_constraint_exists:
op.drop_constraint(
'uq_searchspace_user_connector_type_name',
'search_source_connectors',
type_='unique'
"uq_searchspace_user_connector_type_name",
"search_source_connectors",
type_="unique",
)
# Check if old constraint already exists before creating it
old_constraint_exists = connection.execute(
text("""
@ -87,10 +88,10 @@ def downgrade() -> None:
AND constraint_name='uq_searchspace_user_connector_type'
""")
).scalar()
if not old_constraint_exists:
op.create_unique_constraint(
'uq_searchspace_user_connector_type',
'search_source_connectors',
['search_space_id', 'user_id', 'connector_type']
"uq_searchspace_user_connector_type",
"search_source_connectors",
["search_space_id", "user_id", "connector_type"],
)

View file

@ -44,4 +44,3 @@ def downgrade() -> None:
DROP COLUMN IF EXISTS author_id;
"""
)

View file

@ -18,7 +18,9 @@ logger = logging.getLogger(__name__)
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
def __init__(
self, command: str, args: list[str], env: dict[str, str] | None = None
):
"""Initialize MCP client.
Args:
@ -44,18 +46,16 @@ class MCPClient:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command,
args=self.args,
env=server_env
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write):
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
@ -85,7 +85,9 @@ class MCPClient:
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
# Call tools/list RPC method
@ -93,11 +95,15 @@ class MCPClient:
tools = []
for tool in response.tools:
tools.append({
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
})
tools.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
logger.info("Listed %d tools from MCP server", len(tools))
return tools
@ -121,10 +127,14 @@ class MCPClient:
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)
logger.info(
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
@ -147,12 +157,17 @@ class MCPClient:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning("MCP server returned data not matching its schema, but continuing: %s", e)
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
logger.error(
"Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
)
return f"Error calling tool: {e!s}"

View file

@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
def _create_dynamic_input_model_from_schema(
tool_name: str, input_schema: dict[str, Any],
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema.
@ -41,15 +42,18 @@ def _create_dynamic_input_model_from_schema(
for param_name, param_schema in properties.items():
param_description = param_schema.get("description", "")
is_required = param_name in required_fields
# Use Any type for complex schemas to preserve structure
# This allows the MCP server to do its own validation
from typing import Any as AnyType
from pydantic import Field
if is_required:
field_definitions[param_name] = (AnyType, Field(..., description=param_description))
field_definitions[param_name] = (
AnyType,
Field(..., description=param_description),
)
else:
field_definitions[param_name] = (
AnyType | None,
@ -88,7 +92,7 @@ async def _create_mcp_tool_from_definition(
async def mcp_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via the client."""
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
try:
# Connect to server and call tool
async with mcp_client.connect():
@ -114,7 +118,8 @@ async def _create_mcp_tool_from_definition(
async def load_mcp_tools(
session: AsyncSession, search_space_id: int,
session: AsyncSession,
search_space_id: int,
) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors.
@ -150,7 +155,9 @@ async def load_mcp_tools(
env = server_config.get("env", {})
if not command:
logger.warning(f"MCP connector {connector.id} missing command, skipping")
logger.warning(
f"MCP connector {connector.id} missing command, skipping"
)
continue
# Create MCP client
@ -168,7 +175,9 @@ async def load_mcp_tools(
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition(tool_def, mcp_client)
tool = await _create_mcp_tool_from_definition(
tool_def, mcp_client
)
tools.append(tool)
except Exception as e:
logger.exception(

View file

@ -283,7 +283,8 @@ async def build_tools_async(
):
try:
mcp_tools = await load_mcp_tools(
dependencies["db_session"], dependencies["search_space_id"],
dependencies["db_session"],
dependencies["search_space_id"],
)
tools.extend(mcp_tools)
logging.info(

View file

@ -23,7 +23,9 @@ class SearchSourceConnectorBase(BaseModel):
@field_validator("config")
@classmethod
def validate_config_for_connector_type(
cls, config: dict[str, Any], values: dict[str, Any],
cls,
config: dict[str, Any],
values: dict[str, Any],
) -> dict[str, Any]:
connector_type = values.data.get("connector_type")
return validate_connector_config(connector_type, config)

View file

@ -34,7 +34,6 @@ from .base import (
)
from .markdown_processor import add_received_markdown_file_document
# Constants for LlamaCloud retry configuration
LLAMACLOUD_MAX_RETRIES = 3
LLAMACLOUD_BASE_DELAY = 5 # Base delay in seconds for exponential backoff