mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
refactor: pass auth context through automations
This commit is contained in:
parent
7ec6fa4d1f
commit
096dea45d4
7 changed files with 42 additions and 23 deletions
|
|
@ -34,6 +34,7 @@ from app.agents.chat.runtime.llm_config import AgentConfig
|
|||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.user_tool_allowlist import (
|
||||
|
|
@ -73,6 +74,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
anon_session_id: str | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
image_gen_model_id: int | None = None,
|
||||
auth_context: AuthContext | None = None,
|
||||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled.
|
||||
|
||||
|
|
@ -139,6 +141,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
"connector_service": connector_service,
|
||||
"firecrawl_api_key": firecrawl_api_key,
|
||||
"user_id": user_id,
|
||||
"auth_context": auth_context,
|
||||
"thread_id": thread_id,
|
||||
"thread_visibility": visibility,
|
||||
"available_connectors": available_connectors,
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from pydantic import ValidationError
|
|||
from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
|
||||
request_approval,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.automations.schemas.api import AutomationCreate
|
||||
from app.automations.services.automation import AutomationService
|
||||
from app.db import User, async_session_maker
|
||||
|
|
@ -47,6 +48,7 @@ def create_create_automation_tool(
|
|||
search_space_id: int,
|
||||
user_id: str | UUID,
|
||||
llm: Any,
|
||||
auth_context: AuthContext | None = None,
|
||||
):
|
||||
"""Factory for the ``create_automation`` tool.
|
||||
|
||||
|
|
@ -172,7 +174,8 @@ def create_create_automation_tool(
|
|||
"status": "error",
|
||||
"message": "user not found in this session",
|
||||
}
|
||||
service = AutomationService(session=session, user=user)
|
||||
auth = auth_context or AuthContext.system(user, source="agent")
|
||||
service = AutomationService(session=session, auth=auth)
|
||||
created = await service.create(final_validated)
|
||||
return {
|
||||
"status": "saved",
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
|
|||
return create_create_automation_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
auth_context=deps.get("auth_context"),
|
||||
llm=deps["llm"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ from app.agents.chat.runtime.mention_resolver import (
|
|||
substitute_in_text,
|
||||
)
|
||||
from app.agents.chat.shared.context import SurfSenseContextSchema
|
||||
from app.db import ChatVisibility, async_session_maker
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import ChatVisibility, User, async_session_maker
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
from ...types import ActionContext
|
||||
|
|
@ -147,6 +148,12 @@ async def run_agent_task(
|
|||
decision = "approve" if auto_approve_all else "reject"
|
||||
|
||||
async with async_session_maker() as agent_session:
|
||||
auth_context = None
|
||||
if ctx.creator_user_id:
|
||||
user = await agent_session.get(User, ctx.creator_user_id)
|
||||
if user is not None:
|
||||
auth_context = AuthContext.system(user, source="automation")
|
||||
|
||||
deps = await build_dependencies(
|
||||
session=agent_session,
|
||||
search_space_id=ctx.search_space_id,
|
||||
|
|
@ -168,6 +175,7 @@ async def run_agent_task(
|
|||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
image_gen_model_id=ctx.image_gen_model_id,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
|
||||
agent_query, runtime_context = await _resolve_mention_context(
|
||||
|
|
|
|||
|
|
@ -27,17 +27,19 @@ from app.automations.services.model_policy import (
|
|||
)
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.db import Permission, SearchSpace, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, SearchSpace, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class AutomationService:
|
||||
"""Lifecycle of the ``Automation`` resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.auth = auth
|
||||
self.user = auth.user
|
||||
|
||||
async def create(self, payload: AutomationCreate) -> Automation:
|
||||
"""Create an automation and its initial triggers in one transaction."""
|
||||
|
|
@ -235,7 +237,7 @@ class AutomationService:
|
|||
async def _authorize(self, search_space_id: int, permission: str) -> None:
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
self.auth,
|
||||
search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
|
|
@ -274,6 +276,6 @@ def _build_trigger(spec: TriggerCreate) -> AutomationTrigger:
|
|||
|
||||
def get_automation_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AutomationService:
|
||||
return AutomationService(session=session, user=user)
|
||||
return AutomationService(session=session, auth=auth)
|
||||
|
|
|
|||
|
|
@ -8,17 +8,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.automations.persistence.models.automation import Automation
|
||||
from app.automations.persistence.models.run import AutomationRun
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class RunService:
|
||||
"""Read-only access to ``AutomationRun`` history."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.auth = auth
|
||||
|
||||
async def list(
|
||||
self,
|
||||
|
|
@ -63,7 +64,7 @@ class RunService:
|
|||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
self.auth,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
|
|
@ -73,6 +74,6 @@ class RunService:
|
|||
|
||||
def get_run_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> RunService:
|
||||
return RunService(session=session, user=user)
|
||||
return RunService(session=session, auth=auth)
|
||||
|
|
|
|||
|
|
@ -14,17 +14,18 @@ from app.automations.persistence.models.trigger import AutomationTrigger
|
|||
from app.automations.schemas.api import TriggerCreate, TriggerUpdate
|
||||
from app.automations.triggers import get_trigger
|
||||
from app.automations.triggers.builtin.schedule import compute_next_fire_at
|
||||
from app.db import Permission, User, get_async_session
|
||||
from app.users import current_active_user
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import Permission, get_async_session
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
||||
class TriggerService:
|
||||
"""Lifecycle of the ``AutomationTrigger`` sub-resource."""
|
||||
|
||||
def __init__(self, *, session: AsyncSession, user: User) -> None:
|
||||
def __init__(self, *, session: AsyncSession, auth: AuthContext) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.auth = auth
|
||||
|
||||
async def add(
|
||||
self, *, automation_id: int, payload: TriggerCreate
|
||||
|
|
@ -101,7 +102,7 @@ class TriggerService:
|
|||
)
|
||||
await check_permission(
|
||||
self.session,
|
||||
self.user,
|
||||
self.auth,
|
||||
automation.search_space_id,
|
||||
permission,
|
||||
f"You don't have permission to {permission.split(':')[1]} automations in this search space",
|
||||
|
|
@ -144,6 +145,6 @@ def _initial_next_fire(
|
|||
|
||||
def get_trigger_service(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> TriggerService:
|
||||
return TriggerService(session=session, user=user)
|
||||
return TriggerService(session=session, auth=auth)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue