refactor: pass auth context through automations

This commit is contained in:
Anish Sarkar 2026-06-19 20:28:35 +05:30
parent 7ec6fa4d1f
commit 096dea45d4
7 changed files with 42 additions and 23 deletions

View file

@ -34,6 +34,7 @@ from app.agents.chat.runtime.llm_config import AgentConfig
from app.agents.chat.runtime.prompt_caching import (
apply_litellm_prompt_caching,
)
from app.auth.context import AuthContext
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.services.user_tool_allowlist import (
@ -73,6 +74,7 @@ async def create_multi_agent_chat_deep_agent(
anon_session_id: str | None = None,
filesystem_selection: FilesystemSelection | None = None,
image_gen_model_id: int | None = None,
auth_context: AuthContext | None = None,
):
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
@ -139,6 +141,7 @@ async def create_multi_agent_chat_deep_agent(
"connector_service": connector_service,
"firecrawl_api_key": firecrawl_api_key,
"user_id": user_id,
"auth_context": auth_context,
"thread_id": thread_id,
"thread_visibility": visibility,
"available_connectors": available_connectors,

View file

@ -30,6 +30,7 @@ from pydantic import ValidationError
from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
from app.auth.context import AuthContext
from app.automations.schemas.api import AutomationCreate
from app.automations.services.automation import AutomationService
from app.db import User, async_session_maker
@ -47,6 +48,7 @@ def create_create_automation_tool(
search_space_id: int,
user_id: str | UUID,
llm: Any,
auth_context: AuthContext | None = None,
):
"""Factory for the ``create_automation`` tool.
@ -172,7 +174,8 @@ def create_create_automation_tool(
"status": "error",
"message": "user not found in this session",
}
service = AutomationService(session=session, user=user)
auth = auth_context or AuthContext.system(user, source="agent")
service = AutomationService(session=session, auth=auth)
created = await service.create(final_validated)
return {
"status": "saved",

View file

@ -60,6 +60,7 @@ def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
return create_create_automation_tool(
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
auth_context=deps.get("auth_context"),
llm=deps["llm"],
)

View file

@ -16,7 +16,8 @@ from app.agents.chat.runtime.mention_resolver import (
substitute_in_text,
)
from app.agents.chat.shared.context import SurfSenseContextSchema
from app.db import ChatVisibility, async_session_maker
from app.auth.context import AuthContext
from app.db import ChatVisibility, User, async_session_maker
from app.schemas.new_chat import MentionedDocumentInfo
from ...types import ActionContext
@ -147,6 +148,12 @@ async def run_agent_task(
decision = "approve" if auto_approve_all else "reject"
async with async_session_maker() as agent_session:
auth_context = None
if ctx.creator_user_id:
user = await agent_session.get(User, ctx.creator_user_id)
if user is not None:
auth_context = AuthContext.system(user, source="automation")
deps = await build_dependencies(
session=agent_session,
search_space_id=ctx.search_space_id,
@ -168,6 +175,7 @@ async def run_agent_task(
thread_visibility=ChatVisibility.PRIVATE,
mentioned_document_ids=mentioned_document_ids,
image_gen_model_id=ctx.image_gen_model_id,
auth_context=auth_context,
)
agent_query, runtime_context = await _resolve_mention_context(

View file

@ -27,17 +27,19 @@ from app.automations.services.model_policy import (
)
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.db import Permission, SearchSpace, User, get_async_session
from app.users import current_active_user
from app.auth.context import AuthContext
from app.db import Permission, SearchSpace, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
class AutomationService:
"""Lifecycle of the ``Automation`` resource."""
def __init__(self, *, session: AsyncSession, user: User) -> None:
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
self.session = session
self.user = user
self.auth = auth
self.user = auth.user
async def create(self, payload: AutomationCreate) -> Automation:
"""Create an automation and its initial triggers in one transaction."""
@ -235,7 +237,7 @@ class AutomationService:
async def _authorize(self, search_space_id: int, permission: str) -> None:
await check_permission(
self.session,
self.user,
self.auth,
search_space_id,
permission,
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
@ -274,6 +276,6 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
def get_automation_service(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> AutomationService:
return AutomationService(session=session, user=user)
return AutomationService(session=session, auth=auth)

View file

@ -8,17 +8,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.automations.persistence.models.automation import Automation
from app.automations.persistence.models.run import AutomationRun
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.auth.context import AuthContext
from app.db import Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
class RunService:
"""Read-only access to ``AutomationRun`` history."""
def __init__(self, *, session: AsyncSession, user: User) -> None:
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
self.session = session
self.user = user
self.auth = auth
async def list(
self,
@ -63,7 +64,7 @@ class RunService:
)
await check_permission(
self.session,
self.user,
self.auth,
automation.search_space_id,
permission,
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
@ -73,6 +74,6 @@ class RunService:
def get_run_service(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> RunService:
return RunService(session=session, user=user)
return RunService(session=session, auth=auth)

View file

@ -14,17 +14,18 @@ from app.automations.persistence.models.trigger import AutomationTrigger
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
from app.automations.triggers import get_trigger
from app.automations.triggers.builtin.schedule import compute_next_fire_at
from app.db import Permission, User, get_async_session
from app.users import current_active_user
from app.auth.context import AuthContext
from app.db import Permission, get_async_session
from app.users import get_auth_context
from app.utils.rbac import check_permission
class TriggerService:
"""Lifecycle of the ``AutomationTrigger`` sub-resource."""
def __init__(self, *, session: AsyncSession, user: User) -> None:
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
self.session = session
self.user = user
self.auth = auth
async def add(
self, *, automation_id: int, payload: TriggerCreate
@ -101,7 +102,7 @@ class TriggerService:
)
await check_permission(
self.session,
self.user,
self.auth,
automation.search_space_id,
permission,
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
@ -144,6 +145,6 @@ def _initial_next_fire(
def get_trigger_service(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
) -> TriggerService:
return TriggerService(session=session, user=user)
return TriggerService(session=session, auth=auth)