Merge pull request #1463 from CREDO23/feat/blob-storage

[Feat] Add Document Blob Storage, Segregate Connectors, and Jira Cleanup
This commit is contained in:
Rohan Verma 2026-06-02 12:15:48 -07:00 committed by GitHub
commit 309bd9a2dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 5354 additions and 7404 deletions

View file

@ -128,6 +128,23 @@ services:
timeout: 5s timeout: 5s
retries: 5 retries: 5
# OPTIONAL — Azurite emulates Azure Blob Storage for testing the Azure
# original-file backend. The default filesystem backend needs none of this.
# To exercise it, set in surfsense_backend/.env:
# FILE_STORAGE_BACKEND=azure
# AZURE_STORAGE_CONTAINER=surfsense-documents
# AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://localhost:${AZURITE_BLOB_PORT:-10000}/devstoreaccount1;
# The backend creates blobs on upload but expects the container to exist already:
# create it once (Azure CLI / Storage Explorer) before uploading the first document.
azurite:
image: mcr.microsoft.com/azure-storage/azurite:3.33.0
command: azurite-blob --blobHost 0.0.0.0 --blobPort 10000
ports:
- "${AZURITE_BLOB_PORT:-10000}:10000"
volumes:
- azurite_data:/data
restart: unless-stopped
volumes: volumes:
postgres_data: postgres_data:
name: surfsense-deps-postgres name: surfsense-deps-postgres
@ -137,3 +154,5 @@ volumes:
name: surfsense-deps-redis name: surfsense-deps-redis
zero_cache_data: zero_cache_data:
name: surfsense-deps-zero-cache name: surfsense-deps-zero-cache
azurite_data:
name: surfsense-deps-azurite

View file

@ -293,6 +293,16 @@ LLAMA_CLOUD_API_KEY=llx-nnn
# AZURE_DI_ENDPOINT=https://your-resource.cognitiveservices.azure.com/ # AZURE_DI_ENDPOINT=https://your-resource.cognitiveservices.azure.com/
# AZURE_DI_KEY=your-key # AZURE_DI_KEY=your-key
# Original File Storage
# Where to persist the original bytes of uploaded documents (used for downloads today;
# redaction and form-filling are planned). "local" needs no cloud credentials and is the dev default.
FILE_STORAGE_BACKEND=local
# Local backend: directory for stored files (defaults to surfsense_backend/.local_object_store)
# FILE_STORAGE_LOCAL_PATH=/var/lib/surfsense/object-store
# Azure Blob backend (set FILE_STORAGE_BACKEND=azure):
# AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net
# AZURE_STORAGE_CONTAINER=surfsense-documents
# Daytona Sandbox (isolated code execution) # Daytona Sandbox (isolated code execution)
# DAYTONA_SANDBOX_ENABLED=FALSE # DAYTONA_SANDBOX_ENABLED=FALSE
# DAYTONA_API_KEY=your-daytona-api-key # DAYTONA_API_KEY=your-daytona-api-key

View file

@ -2,6 +2,7 @@
.venv .venv
venv/ venv/
data/ data/
.local_object_store/
__pycache__/ __pycache__/
.flashrank_cache .flashrank_cache
surf_new_backend.egg-info/ surf_new_backend.egg-info/

View file

@ -0,0 +1,86 @@
"""add document_files table for stored original uploads
Revision ID: 152
Revises: 151
"""
from collections.abc import Sequence
from alembic import op
# Alembic revision identifiers: this migration is revision 152 and is applied
# on top of revision 151. No branch labels or cross-revision dependencies.
revision: str = "152"
down_revision: str | None = "151"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Create the ``document_file_kind`` enum, the ``document_files`` table and its indexes.

    Every statement is guarded (``IF NOT EXISTS`` / a ``pg_type`` lookup), so
    objects that already exist are left untouched.
    """
    # The enum type must precede the table that references it.
    create_enum_sql = """
        DO $$
        BEGIN
        IF NOT EXISTS (
        SELECT 1 FROM pg_type WHERE typname = 'document_file_kind'
        ) THEN
        CREATE TYPE document_file_kind AS ENUM (
        'ORIGINAL', 'REDACTED', 'FILLED_FORM'
        );
        END IF;
        END
        $$;
        """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS document_files (
        id SERIAL PRIMARY KEY,
        document_id INTEGER NOT NULL
        REFERENCES documents(id) ON DELETE CASCADE,
        search_space_id INTEGER NOT NULL
        REFERENCES searchspaces(id) ON DELETE CASCADE,
        kind document_file_kind NOT NULL DEFAULT 'ORIGINAL',
        storage_backend VARCHAR(32) NOT NULL,
        storage_key TEXT NOT NULL,
        original_filename TEXT NOT NULL,
        mime_type TEXT,
        size_bytes BIGINT NOT NULL,
        checksum_sha256 VARCHAR(64),
        created_by_id UUID
        REFERENCES "user"(id) ON DELETE SET NULL,
        created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
        );
        """
    op.execute(create_enum_sql)
    op.execute(create_table_sql)

    # One single-column index per lookup column, created in a fixed order.
    indexed_columns = (
        "document_id",
        "search_space_id",
        "kind",
        "created_by_id",
        "created_at",
    )
    for column in indexed_columns:
        op.execute(
            f"CREATE INDEX IF NOT EXISTS ix_document_files_{column} "
            f"ON document_files({column});"
        )
def downgrade() -> None:
    """Drop the indexes, then the ``document_files`` table, then the enum type.

    This is the reverse of ``upgrade``; every statement uses ``IF EXISTS``.
    """
    # Indexes are removed in the reverse of their creation order.
    for column in (
        "created_at",
        "created_by_id",
        "kind",
        "search_space_id",
        "document_id",
    ):
        op.execute(f"DROP INDEX IF EXISTS ix_document_files_{column};")
    # The table goes before the enum type its ``kind`` column references.
    op.execute("DROP TABLE IF EXISTS document_files;")
    op.execute("DROP TYPE IF EXISTS document_file_kind;")

View file

@ -0,0 +1,109 @@
"""restore automation_runs to zero_publication
Migration 149's ``SET TABLE`` dropped ``automation_runs`` (added in 148),
breaking the dashboard live run ticker with a SchemaVersionNotSupported
reload loop. Re-emit the publication with ``automation_runs`` using the
``COMMENT`` bookend pattern so zero-cache fires its schema-change hook.
Revision ID: 153
Revises: 152
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# Alembic revision identifiers: revision 153, applied on top of revision 152.
revision: str = "153"
down_revision: str | None = "152"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Name of the Postgres publication that this migration re-emits.
PUBLICATION_NAME = "zero_publication"
# Columns of ``documents`` listed in the publication's column list.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Columns of ``"user"`` listed in the publication's column list.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_credit_micros_limit",
"premium_credit_micros_used",
]
# Columns of ``automation_runs`` listed in the publication when that table is included.
AUTOMATION_RUN_COLS = [
"id",
"automation_id",
"trigger_id",
"status",
"step_results",
"started_at",
"finished_at",
"created_at",
]
def _has_zero_version(conn, table: str) -> bool:
    """Report whether ``table`` has a ``_0_version`` column.

    The lookup goes through ``information_schema.columns`` and matches on the
    table name only (no schema filter).

    Args:
        conn: Connection used to run the lookup query.
        table: Unquoted table name to inspect.

    Returns:
        True when at least one matching column row exists, False otherwise.
    """
    lookup = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = '_0_version'"
    )
    row = conn.execute(lookup, {"tbl": table}).fetchone()
    return row is not None
def _set_table_ddl(*, with_automation_runs: bool, conn) -> str:
    """Build the ``ALTER PUBLICATION ... SET TABLE`` statement for the publication.

    ``documents`` and ``"user"`` are published with explicit column lists; the
    quoted ``"_0_version"`` column is appended to a list only when the table
    currently has that column.

    Args:
        with_automation_runs: Whether to append ``automation_runs`` (with its
            column list) as the last published table.
        conn: Connection used to probe for the ``_0_version`` columns.

    Returns:
        The DDL statement as a single string.
    """

    def _column_list(base_cols: list[str], table: str) -> str:
        # Copy so the module-level constant is never mutated.
        cols = list(base_cols)
        if _has_zero_version(conn, table):
            cols.append('"_0_version"')
        return ", ".join(cols)

    # Probe ``documents`` first, then ``user``.
    document_columns = _column_list(DOCUMENT_COLS, "documents")
    user_columns = _column_list(USER_COLS, "user")

    published = [
        "notifications",
        "documents (" + document_columns + ")",
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
        '"user" (' + user_columns + ")",
    ]
    if with_automation_runs:
        published.append("automation_runs (" + ", ".join(AUTOMATION_RUN_COLS) + ")")

    return "ALTER PUBLICATION " + PUBLICATION_NAME + " SET TABLE " + ", ".join(published)
def _resync(*, with_automation_runs: bool, tag: str) -> None:
    """Re-emit the publication's table list, bracketed by two ``COMMENT`` statements.

    Does nothing when the publication does not exist. Otherwise, inside one
    transaction (a savepoint if the connection is already in a transaction),
    it sets the comment to ``pre-<tag>``, applies the ``SET TABLE`` DDL, then
    sets the comment to ``post-<tag>``.

    Args:
        with_automation_runs: Forwarded to ``_set_table_ddl``.
        tag: Suffix used in the ``pre-``/``post-`` publication comments.
    """
    conn = op.get_bind()

    publication_row = conn.execute(
        sa.text("SELECT 1 FROM pg_publication WHERE pubname = :name"),
        {"name": PUBLICATION_NAME},
    ).fetchone()
    if not publication_row:
        return

    def _stamp(phase: str) -> None:
        # Sets the publication comment to "<phase>-<tag>".
        conn.execute(
            sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS '{phase}-{tag}'")
        )

    if conn.in_transaction():
        tx = conn.begin_nested()
    else:
        tx = conn.begin()

    with tx:
        _stamp("pre")
        # The DDL is built here, after the "pre" comment, as its column probes run queries.
        ddl = _set_table_ddl(with_automation_runs=with_automation_runs, conn=conn)
        conn.execute(sa.text(ddl))
        _stamp("post")
def upgrade() -> None:
    """Re-emit the publication with ``automation_runs`` included."""
    _resync(tag="153-resync", with_automation_runs=True)
def downgrade() -> None:
    """Re-emit the publication without ``automation_runs``."""
    _resync(tag="153-downgrade", with_automation_runs=False)

View file

@ -1,11 +0,0 @@
"""Jira tools for creating, updating, and deleting issues."""
from .create_issue import create_create_jira_issue_tool
from .delete_issue import create_delete_jira_issue_tool
from .update_issue import create_update_jira_issue_tool
# Public API of this package: the three Jira tool factory functions imported above.
__all__ = [
"create_create_jira_issue_tool",
"create_delete_jira_issue_tool",
"create_update_jira_issue_tool",
]

View file

@ -1,248 +0,0 @@
import asyncio
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
def create_create_jira_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the create_jira_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits. Per-call sessions also
    keep the request's outer transaction free of long-running Jira API
    blocking.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Jira connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured create_jira_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_jira_issue(
        project_key: str,
        summary: str,
        issue_type: str = "Task",
        description: str | None = None,
        priority: str | None = None,
    ) -> dict[str, Any]:
        """Create a new issue in Jira.

        Use this tool when the user explicitly asks to create a new Jira issue/ticket.

        Args:
            project_key: The Jira project key (e.g. "PROJ", "ENG").
            summary: Short, descriptive issue title.
            issue_type: Issue type (default "Task"). Others: "Bug", "Story", "Epic".
            description: Optional description body for the issue.
            priority: Optional priority name (e.g. "High", "Medium", "Low").

        Returns:
            Dictionary with status, issue_key, and message.

        IMPORTANT:
            - If status is "rejected", the user declined. Do NOT retry.
            - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            "create_jira_issue called: project_key='%s', summary='%s'",
            project_key,
            summary,
        )
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                metadata_service = JiraToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )
                if "error" in context:
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected Jira accounts need re-authentication.",
                        "connector_type": "jira",
                    }

                # Human-in-the-loop approval; the approver may edit the parameters.
                approval = request_approval(
                    action_type="jira_issue_creation",
                    tool_name="create_jira_issue",
                    params={
                        "project_key": project_key,
                        "summary": summary,
                        "issue_type": issue_type,
                        "description": description,
                        "priority": priority,
                        "connector_id": connector_id,
                    },
                    context=context,
                )
                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                # Prefer approved values, falling back to the tool arguments.
                final_project_key = approval.params.get("project_key", project_key)
                final_summary = approval.params.get("summary", summary)
                final_issue_type = approval.params.get("issue_type", issue_type)
                final_description = approval.params.get("description", description)
                final_priority = approval.params.get("priority", priority)
                final_connector_id = approval.params.get("connector_id", connector_id)

                if not final_summary or not final_summary.strip():
                    return {
                        "status": "error",
                        "message": "Issue summary cannot be empty.",
                    }
                if not final_project_key:
                    return {"status": "error", "message": "A project must be selected."}

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                actual_connector_id = final_connector_id
                if actual_connector_id is None:
                    # No connector chosen: use the first Jira connector of this
                    # user in this search space.
                    connector_result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.JIRA_CONNECTOR,
                        )
                    )
                    connector = connector_result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "No Jira connector found.",
                        }
                    actual_connector_id = connector.id
                else:
                    # Verify the chosen connector belongs to this user/search space.
                    connector_result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == actual_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.JIRA_CONNECTOR,
                        )
                    )
                    connector = connector_result.scalars().first()
                    if not connector:
                        return {
                            "status": "error",
                            "message": "Selected Jira connector is invalid.",
                        }

                try:
                    jira_history = JiraHistoryConnector(
                        session=db_session, connector_id=actual_connector_id
                    )
                    jira_client = await jira_history._get_jira_client()
                    # The Jira client call is run in a worker thread.
                    api_result = await asyncio.to_thread(
                        jira_client.create_issue,
                        project_key=final_project_key,
                        summary=final_summary,
                        issue_type=final_issue_type,
                        description=final_description,
                        priority=final_priority,
                    )
                except Exception as api_err:
                    if "status code 403" in str(api_err).lower():
                        # Best effort: flag the connector as needing re-auth.
                        # A failure here must not mask the permission result,
                        # but it is logged instead of silently discarded.
                        try:
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to mark Jira connector %s as auth_expired",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                issue_key = api_result.get("key", "")
                issue_url = (
                    f"{jira_history._base_url}/browse/{issue_key}"
                    if jira_history._base_url and issue_key
                    else ""
                )

                # KB sync is best effort; its outcome only changes the message suffix.
                kb_message_suffix = ""
                try:
                    from app.services.jira import JiraKBSyncService

                    kb_service = JiraKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        issue_id=issue_key,
                        issue_identifier=issue_key,
                        issue_title=final_summary,
                        description=final_description,
                        state="To Do",
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning("KB sync after create failed: %s", kb_err)
                    kb_message_suffix = " This issue will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "issue_key": issue_key,
                    "issue_url": issue_url,
                    "message": f"Jira issue {issue_key} created successfully.{kb_message_suffix}",
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # GraphInterrupt is re-raised rather than reported as a tool error.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error creating Jira issue: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the issue.",
            }

    return create_jira_issue

View file

@ -1,210 +0,0 @@
import asyncio
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_jira_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the delete_jira_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Jira connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured delete_jira_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_jira_issue(
        issue_title_or_key: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Jira issue.

        Use this tool when the user asks to delete or remove a Jira issue.

        Args:
            issue_title_or_key: The issue key (e.g. "PROJ-42") or title.
            delete_from_kb: Whether to also remove from the knowledge base.

        Returns:
            Dictionary with status, message, and deleted_from_kb.

        IMPORTANT:
            - If status is "rejected", do NOT retry.
            - If status is "not_found", relay the message to the user.
            - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            "delete_jira_issue called: issue_title_or_key='%s'", issue_title_or_key
        )
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                metadata_service = JiraToolMetadataService(db_session)
                context = await metadata_service.get_deletion_context(
                    search_space_id, user_id, issue_title_or_key
                )
                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "jira",
                        }
                    if "not found" in error_msg.lower():
                        return {"status": "not_found", "message": error_msg}
                    return {"status": "error", "message": error_msg}

                issue_data = context["issue"]
                issue_key = issue_data["issue_id"]
                # ``.get`` matches the update tool: a missing document id only
                # skips the KB cleanup instead of failing the whole call.
                document_id = issue_data.get("document_id")
                connector_id_from_context = context.get("account", {}).get("id")

                # Human-in-the-loop approval; the approver may edit the parameters.
                approval = request_approval(
                    action_type="jira_issue_deletion",
                    tool_name="delete_jira_issue",
                    params={
                        "issue_key": issue_key,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )
                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_issue_key = approval.params.get("issue_key", issue_key)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = approval.params.get(
                    "delete_from_kb", delete_from_kb
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this issue.",
                    }

                # Verify the connector belongs to this user/search space.
                connector_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.JIRA_CONNECTOR,
                    )
                )
                connector = connector_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Jira connector is invalid.",
                    }

                try:
                    jira_history = JiraHistoryConnector(
                        session=db_session, connector_id=final_connector_id
                    )
                    jira_client = await jira_history._get_jira_client()
                    # The Jira client call is run in a worker thread.
                    await asyncio.to_thread(jira_client.delete_issue, final_issue_key)
                except Exception as api_err:
                    if "status code 403" in str(api_err).lower():
                        # Best effort: flag the connector as needing re-auth.
                        # A failure here must not mask the permission result,
                        # but it is logged instead of silently discarded.
                        try:
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to mark Jira connector %s as auth_expired",
                                final_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": final_connector_id,
                            "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                # Optional KB cleanup; a failure leaves deleted_from_kb False.
                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                    except Exception as kb_err:
                        logger.error("Failed to delete document from KB: %s", kb_err)
                        await db_session.rollback()

                message = f"Jira issue {final_issue_key} deleted successfully."
                if deleted_from_kb:
                    message += " Also removed from the knowledge base."
                return {
                    "status": "success",
                    "issue_key": final_issue_key,
                    "deleted_from_kb": deleted_from_kb,
                    "message": message,
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # GraphInterrupt is re-raised rather than reported as a tool error.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error deleting Jira issue: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the issue.",
            }

    return delete_jira_issue

View file

@ -1,255 +0,0 @@
import asyncio
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.db import async_session_maker
from app.services.jira import JiraToolMetadataService
logger = logging.getLogger(__name__)
def create_update_jira_issue_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """Factory function to create the update_jira_issue tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker`. This is critical for the compiled-agent
    cache: the compiled graph (and therefore this closure) is reused
    across HTTP requests, so capturing a per-request session here would
    surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Jira connector
        user_id: User ID for fetching user-specific context
        connector_id: Optional specific connector ID (if known)

    Returns:
        Configured update_jira_issue tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_jira_issue(
        issue_title_or_key: str,
        new_summary: str | None = None,
        new_description: str | None = None,
        new_priority: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Jira issue.

        Use this tool when the user asks to modify, edit, or update a Jira issue.

        Args:
            issue_title_or_key: The issue key (e.g. "PROJ-42") or title to identify the issue.
            new_summary: Optional new title/summary for the issue.
            new_description: Optional new description.
            new_priority: Optional new priority name.

        Returns:
            Dictionary with status and message.

        IMPORTANT:
            - If status is "rejected", do NOT retry.
            - If status is "not_found", relay the message and ask user to verify.
            - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            "update_jira_issue called: issue_title_or_key='%s'", issue_title_or_key
        )
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Jira tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                metadata_service = JiraToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, issue_title_or_key
                )
                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "jira",
                        }
                    if "not found" in error_msg.lower():
                        return {"status": "not_found", "message": error_msg}
                    return {"status": "error", "message": error_msg}

                issue_data = context["issue"]
                issue_key = issue_data["issue_id"]
                document_id = issue_data.get("document_id")
                connector_id_from_context = context.get("account", {}).get("id")

                # Human-in-the-loop approval; the approver may edit the parameters.
                approval = request_approval(
                    action_type="jira_issue_update",
                    tool_name="update_jira_issue",
                    params={
                        "issue_key": issue_key,
                        "document_id": document_id,
                        "new_summary": new_summary,
                        "new_description": new_description,
                        "new_priority": new_priority,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )
                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                # Prefer approved values, falling back to the tool arguments.
                final_issue_key = approval.params.get("issue_key", issue_key)
                final_summary = approval.params.get("new_summary", new_summary)
                final_description = approval.params.get(
                    "new_description", new_description
                )
                final_priority = approval.params.get("new_priority", new_priority)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_document_id = approval.params.get("document_id", document_id)

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this issue.",
                    }

                # Verify the connector belongs to this user/search space.
                connector_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.JIRA_CONNECTOR,
                    )
                )
                connector = connector_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Jira connector is invalid.",
                    }

                # Only fields the caller actually supplied are sent to Jira.
                fields: dict[str, Any] = {}
                if final_summary:
                    fields["summary"] = final_summary
                if final_description is not None:
                    # The description is wrapped in a single-paragraph document structure.
                    # NOTE(review): an empty string produces an empty text
                    # node here — confirm Jira accepts that payload.
                    fields["description"] = {
                        "type": "doc",
                        "version": 1,
                        "content": [
                            {
                                "type": "paragraph",
                                "content": [
                                    {"type": "text", "text": final_description}
                                ],
                            }
                        ],
                    }
                if final_priority:
                    fields["priority"] = {"name": final_priority}
                if not fields:
                    return {"status": "error", "message": "No changes specified."}

                try:
                    jira_history = JiraHistoryConnector(
                        session=db_session, connector_id=final_connector_id
                    )
                    jira_client = await jira_history._get_jira_client()
                    # The Jira client call is run in a worker thread.
                    await asyncio.to_thread(
                        jira_client.update_issue, final_issue_key, fields
                    )
                except Exception as api_err:
                    if "status code 403" in str(api_err).lower():
                        # Best effort: flag the connector as needing re-auth.
                        # A failure here must not mask the permission result,
                        # but it is logged instead of silently discarded.
                        try:
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to mark Jira connector %s as auth_expired",
                                final_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": final_connector_id,
                            "message": "This Jira account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                issue_url = (
                    f"{jira_history._base_url}/browse/{final_issue_key}"
                    if jira_history._base_url and final_issue_key
                    else ""
                )

                # KB sync is best effort; its outcome only changes the message suffix.
                kb_message_suffix = ""
                if final_document_id:
                    try:
                        from app.services.jira import JiraKBSyncService

                        kb_service = JiraKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_update(
                            document_id=final_document_id,
                            issue_id=final_issue_key,
                            user_id=user_id,
                            search_space_id=search_space_id,
                        )
                        if kb_result["status"] == "success":
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                        else:
                            kb_message_suffix = (
                                " The knowledge base will be updated in the next sync."
                            )
                    except Exception as kb_err:
                        logger.warning("KB sync after update failed: %s", kb_err)
                        kb_message_suffix = (
                            " The knowledge base will be updated in the next sync."
                        )

                return {
                    "status": "success",
                    "issue_key": final_issue_key,
                    "issue_url": issue_url,
                    "message": f"Jira issue {final_issue_key} updated successfully.{kb_message_suffix}",
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # GraphInterrupt is re-raised rather than reported as a tool error.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error updating Jira issue: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the issue.",
            }

    return update_jira_issue

View file

@ -1,648 +0,0 @@
"""
Jira Connector Module
A module for retrieving data from Jira.
Allows fetching issue lists and their comments, projects and more.
Supports both OAuth 2.0 (preferred) and legacy API token authentication.
"""
import base64
from datetime import datetime
from typing import Any
import requests
class JiraConnector:
"""Class for retrieving data from Jira."""
def __init__(
self,
base_url: str | None = None,
access_token: str | None = None,
cloud_id: str | None = None,
email: str | None = None,
api_token: str | None = None,
):
"""
Initialize the JiraConnector class.
Args:
base_url: Jira instance base URL (e.g., 'https://yourcompany.atlassian.net')
access_token: OAuth 2.0 access token (preferred method)
cloud_id: Atlassian cloud ID (used with OAuth for API URL construction)
email: Jira account email address (legacy method, used with api_token)
api_token: Jira API token (legacy method, used with email)
"""
self.base_url = base_url.rstrip("/") if base_url else None
self.access_token = access_token
self.cloud_id = cloud_id
self.email = email
self.api_token = api_token
self.api_version = "3" # Jira Cloud API version
self._use_oauth = access_token is not None
def set_oauth_credentials(
self, base_url: str, access_token: str, cloud_id: str | None = None
) -> None:
"""
Set OAuth 2.0 credentials (preferred method).
Args:
base_url: Jira instance base URL
access_token: OAuth 2.0 access token
cloud_id: Atlassian cloud ID (optional, used for API URL construction)
"""
self.base_url = base_url.rstrip("/")
self.access_token = access_token
self.cloud_id = cloud_id
self._use_oauth = True
def set_credentials(self, base_url: str, email: str, api_token: str) -> None:
    """
    Switch the connector to legacy API-token authentication and store its credentials.

    Args:
        base_url: Jira instance base URL (trailing slashes are stripped)
        email: Jira account email address
        api_token: Jira API token
    """
    self._use_oauth = False
    self.api_token = api_token
    self.email = email
    self.base_url = base_url.rstrip("/")
def set_email(self, email: str) -> None:
    """
    Store the account email and switch the connector to legacy (non-OAuth) mode.

    Args:
        email: Jira account email address
    """
    self._use_oauth = False
    self.email = email
def set_api_token(self, api_token: str) -> None:
    """
    Store the API token and switch the connector to legacy (non-OAuth) mode.

    Args:
        api_token: Jira API token
    """
    self._use_oauth = False
    self.api_token = api_token
def get_headers(self) -> dict[str, str]:
    """
    Build the HTTP headers for a Jira API request.

    In OAuth mode the ``Authorization`` header is a Bearer token; otherwise
    it is HTTP Basic auth built from ``email:api_token``.

    Returns:
        Dictionary with ``Content-Type``, ``Authorization`` and ``Accept``

    Raises:
        ValueError: If the credentials for the active mode have not been set
    """
    if self._use_oauth:
        # OAuth 2.0 authentication
        if not (self.base_url and self.access_token):
            raise ValueError(
                "Jira OAuth credentials not initialized. Call set_oauth_credentials() first."
            )
        authorization = f"Bearer {self.access_token}"
    else:
        # Legacy Basic Auth
        if not (self.base_url and self.email and self.api_token):
            raise ValueError(
                "Jira credentials not initialized. Call set_credentials() first."
            )
        # Basic auth value is base64("email:api_token")
        raw_credentials = f"{self.email}:{self.api_token}".encode("utf-8")
        authorization = "Basic " + base64.b64encode(raw_credentials).decode("ascii")

    return {
        "Content-Type": "application/json",
        "Authorization": authorization,
        "Accept": "application/json",
    }
def make_api_request(
    self,
    endpoint: str,
    params: dict[str, Any] | None = None,
    method: str = "GET",
    json_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Send one request to the Jira REST API and return the decoded response.

    Args:
        endpoint: API endpoint (without base URL)
        params: Query parameters; sent for GET and DELETE only (ignored for POST/PUT)
        method: HTTP method, case-insensitive. POST, PUT and DELETE are handled
            explicitly; any other value is sent as GET
        json_payload: JSON body; sent for POST and PUT only

    Returns:
        Parsed JSON body, or ``{"status": "success"}`` for a 204 / empty body

    Raises:
        ValueError: If credentials have not been set
        Exception: If the response status is not 200, 201 or 204
    """
    headers = self.get_headers()

    # OAuth with a known cloud ID goes through the Atlassian API gateway;
    # everything else talks to the instance's own base URL.
    if self._use_oauth and self.cloud_id:
        url = f"https://api.atlassian.com/ex/jira/{self.cloud_id}/rest/api/{self.api_version}/{endpoint}"
    else:
        url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}"

    verb = method.upper()
    request_kwargs: dict[str, Any] = {"headers": headers, "timeout": 500}
    if verb == "POST":
        send = requests.post
        request_kwargs["json"] = json_payload
    elif verb == "PUT":
        send = requests.put
        request_kwargs["json"] = json_payload
    elif verb == "DELETE":
        send = requests.delete
        request_kwargs["params"] = params
    else:
        send = requests.get
        request_kwargs["params"] = params
    response = send(url, **request_kwargs)

    if response.status_code not in (200, 201, 204):
        raise Exception(
            f"API request failed with status code {response.status_code}: {response.text}"
        )
    # 204 and empty bodies carry no JSON to decode.
    if response.status_code == 204 or not response.text:
        return {"status": "success"}
    return response.json()
def get_all_projects(self) -> dict[str, Any]:
    """
    Fetch projects from Jira via the ``project/search`` endpoint.

    Returns:
        The response of ``make_api_request("project/search")``, unmodified

    Raises:
        ValueError: If credentials have not been set
        Exception: If the API request fails
    """
    response = self.make_api_request("project/search")
    return response
def get_all_issues(self, project_key: str | None = None) -> list[dict[str, Any]]:
"""
Fetch all issues from Jira.
Args:
project_key: Optional project key to filter issues (e.g., 'PROJ')
Returns:
List of issue objects
Raises:
ValueError: If credentials have not been set
Exception: If the API request fails
"""
jql = "ORDER BY created DESC"
if project_key:
jql = f'project = "{project_key}" ' + jql
fields = [
"summary",
"description",
"status",
"assignee",
"reporter",
"created",
"updated",
"priority",
"issuetype",
"project",
]
all_issues = []
start_at = 0
max_results = 100
all_issues = []
start_at = 0
while True:
json_payload = {
"jql": jql,
"fields": fields, # API accepts list
"maxResults": max_results,
"startAt": start_at,
}
result = self.make_api_request(
"search/jql", json_payload=json_payload, method="POST"
)
if not isinstance(result, dict) or "issues" not in result:
raise Exception("Invalid response from Jira API")
issues = result["issues"]
all_issues.extend(issues)
print(f"Fetched {len(issues)} issues (startAt={start_at})")
total = result.get("total", 0)
if start_at + len(issues) >= total:
break
start_at += len(issues)
return all_issues
def get_issues_by_date_range(
self,
start_date: str,
end_date: str,
include_comments: bool = True,
project_key: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
"""
Fetch issues within a date range.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format (inclusive)
include_comments: Whether to include comments in the response
project_key: Optional project key to filter issues
Returns:
Tuple containing (issues list, error message or None)
"""
try:
# Build JQL query for date range
# Query issues that were either created OR updated within the date range
# Use end_date + 1 day with < operator to include the full end date
from datetime import datetime, timedelta
# Parse end_date and add 1 day for inclusive end date
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
end_date_next = (end_date_obj + timedelta(days=1)).strftime("%Y-%m-%d")
# Check both created and updated dates to catch all relevant issues
# Use 'created' and 'updated' (standard JQL field names)
date_filter = (
f"(created >= '{start_date}' AND created < '{end_date_next}') "
f"OR (updated >= '{start_date}' AND updated < '{end_date_next}')"
)
jql = f"{date_filter} ORDER BY created DESC"
if project_key:
jql = f'project = "{project_key}" AND ({date_filter}) ORDER BY created DESC'
# Define fields to retrieve
fields = [
"summary",
"description",
"status",
"assignee",
"reporter",
"created",
"updated",
"priority",
"issuetype",
"project",
]
if include_comments:
fields.append("comment")
params = {
"jql": jql,
"fields": ",".join(fields),
"maxResults": 100,
"startAt": 0,
}
all_issues = []
start_at = 0
while True:
params["startAt"] = start_at
result = self.make_api_request("search/jql", params)
if not isinstance(result, dict) or "issues" not in result:
return [], "Invalid response from Jira API"
issues = result["issues"]
all_issues.extend(issues)
# Check if there are more issues to fetch
total = result.get("total", 0)
if start_at + len(issues) >= total:
break
start_at += len(issues)
if not all_issues:
return [], "No issues found in the specified date range."
return all_issues, None
except Exception as e:
return [], f"Error fetching issues: {e!s}"
def get_myself(self) -> dict[str, Any]:
    """Return the response of the ``myself`` endpoint (used as a health check)."""
    profile = self.make_api_request("myself")
    return profile
def get_projects(self) -> list[dict[str, Any]]:
    """Return the ``values`` list of the ``project/search`` response (empty if absent)."""
    return self.make_api_request("project/search").get("values", [])
def get_issue_types(self) -> list[dict[str, Any]]:
    """Return the response of the ``issuetype`` endpoint, unmodified."""
    issue_types = self.make_api_request("issuetype")
    return issue_types
def get_priorities(self) -> list[dict[str, Any]]:
    """Return the response of the ``priority`` endpoint, unmodified."""
    priorities = self.make_api_request("priority")
    return priorities
def get_issue(self, issue_id_or_key: str) -> dict[str, Any]:
    """Fetch one issue, addressed by its ID or key, from ``issue/<id-or-key>``."""
    endpoint = f"issue/{issue_id_or_key}"
    return self.make_api_request(endpoint)
def create_issue(
self,
project_key: str,
summary: str,
issue_type: str = "Task",
description: str | None = None,
priority: str | None = None,
assignee_id: str | None = None,
) -> dict[str, Any]:
"""Create a new Jira issue."""
fields: dict[str, Any] = {
"project": {"key": project_key},
"summary": summary,
"issuetype": {"name": issue_type},
}
if description:
fields["description"] = {
"type": "doc",
"version": 1,
"content": [
{
"type": "paragraph",
"content": [{"type": "text", "text": description}],
}
],
}
if priority:
fields["priority"] = {"name": priority}
if assignee_id:
fields["assignee"] = {"accountId": assignee_id}
return self.make_api_request(
"issue", method="POST", json_payload={"fields": fields}
)
def update_issue(
    self, issue_id_or_key: str, fields: dict[str, Any]
) -> dict[str, Any]:
    """
    Update fields of an existing issue with a PUT to ``issue/<id-or-key>``.

    Args:
        issue_id_or_key: ID or key of the issue to update
        fields: Field values, sent under the ``fields`` key of the payload

    Returns:
        Response data from the API
    """
    endpoint = f"issue/{issue_id_or_key}"
    payload = {"fields": fields}
    return self.make_api_request(endpoint, method="PUT", json_payload=payload)
def delete_issue(self, issue_id_or_key: str) -> dict[str, Any]:
    """Delete an issue, addressed by ID or key, with a DELETE to ``issue/<id-or-key>``."""
    endpoint = f"issue/{issue_id_or_key}"
    return self.make_api_request(endpoint, method="DELETE")
def get_transitions(self, issue_id_or_key: str) -> list[dict[str, Any]]:
    """Return the ``transitions`` list available for an issue (empty if absent)."""
    endpoint = f"issue/{issue_id_or_key}/transitions"
    return self.make_api_request(endpoint).get("transitions", [])
def transition_issue(
    self, issue_id_or_key: str, transition_id: str
) -> dict[str, Any]:
    """
    Move an issue to a new status by POSTing a transition.

    Args:
        issue_id_or_key: ID or key of the issue
        transition_id: ID of the transition to apply

    Returns:
        Response data from the API
    """
    endpoint = f"issue/{issue_id_or_key}/transitions"
    payload = {"transition": {"id": transition_id}}
    return self.make_api_request(endpoint, method="POST", json_payload=payload)
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
    """
    Flatten a raw Jira issue into a simpler dictionary.

    Missing or null sub-objects (``fields``, ``status``, ``statusCategory``,
    ``priority``, ``issuetype``, ``project``, ``reporter``, ``assignee``,
    ``comment``, a comment's ``author``) are tolerated and replaced by
    defaults instead of raising.

    Args:
        issue: The issue object from Jira API

    Returns:
        Formatted issue dictionary with the keys ``id``, ``key``, ``title``,
        ``description``, ``status``, ``status_category``, ``priority``,
        ``issue_type``, ``project``, ``created_at``, ``updated_at``,
        ``reporter``, ``assignee`` (``None`` when unassigned) and ``comments``
    """

    def _lookup(container: Any, key: str) -> Any:
        # Value of `key` in an optional sub-object, "Unknown" when unavailable.
        return container.get(key, "Unknown") if container else "Unknown"

    def _person(raw: Any) -> dict[str, Any]:
        # Normalised view of a Jira user object; defaults when it is missing.
        if not raw:
            return {"account_id": "", "display_name": "Unknown", "email": ""}
        return {
            "account_id": raw.get("accountId", ""),
            "display_name": raw.get("displayName", "Unknown"),
            "email": raw.get("emailAddress", ""),
        }

    fields = issue.get("fields") or {}
    status = fields.get("status")
    assignee = fields.get("assignee")

    formatted: dict[str, Any] = {
        "id": issue.get("id", ""),
        "key": issue.get("key", ""),
        "title": fields.get("summary", ""),
        "description": fields.get("description", ""),
        "status": _lookup(status, "name"),
        "status_category": (
            _lookup(status.get("statusCategory"), "name") if status else "Unknown"
        ),
        "priority": _lookup(fields.get("priority"), "name"),
        "issue_type": _lookup(fields.get("issuetype"), "name"),
        "project": _lookup(fields.get("project"), "key"),
        "created_at": fields.get("created", ""),
        "updated_at": fields.get("updated", ""),
        "reporter": _person(fields.get("reporter")),
        # Unlike the reporter, an absent assignee is reported as None.
        "assignee": _person(assignee) if assignee else None,
        "comments": [],
    }

    # Comments are only present when the "comment" field was requested.
    comment_block = fields.get("comment") or {}
    for comment in comment_block.get("comments") or []:
        formatted["comments"].append(
            {
                "id": comment.get("id", ""),
                "body": comment.get("body", ""),
                "created_at": comment.get("created", ""),
                "updated_at": comment.get("updated", ""),
                "author": _person(comment.get("author")),
            }
        )

    return formatted
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
    """
    Convert an issue to markdown format.

    Raw issues are run through ``format_issue`` first. An issue is treated as
    raw when it has no ``key`` or still carries the API's ``fields`` object;
    the latter matters because raw issues also have a top-level ``key`` and
    would otherwise be rendered without title or metadata.

    Args:
        issue: The issue object (either raw or formatted)

    Returns:
        Markdown string representation of the issue
    """
    if "key" not in issue or "fields" in issue:
        issue = self.format_issue(issue)

    parts = [
        f"# {issue.get('key', 'No Key')}: {issue.get('title', 'No Title')}\n\n"
    ]

    if issue.get("status"):
        parts.append(f"**Status:** {issue['status']}\n")
    if issue.get("priority"):
        parts.append(f"**Priority:** {issue['priority']}\n")
    if issue.get("issue_type"):
        parts.append(f"**Type:** {issue['issue_type']}\n")
    if issue.get("project"):
        parts.append(f"**Project:** {issue['project']}\n\n")

    if issue.get("assignee") and issue["assignee"].get("display_name"):
        parts.append(f"**Assignee:** {issue['assignee']['display_name']}\n")
    if issue.get("reporter") and issue["reporter"].get("display_name"):
        parts.append(f"**Reporter:** {issue['reporter']['display_name']}\n")

    if issue.get("created_at"):
        created_date = self.format_date(issue["created_at"])
        parts.append(f"**Created:** {created_date}\n")
    if issue.get("updated_at"):
        updated_date = self.format_date(issue["updated_at"])
        parts.append(f"**Updated:** {updated_date}\n\n")

    if issue.get("description"):
        parts.append(f"## Description\n\n{issue['description']}\n\n")

    if issue.get("comments"):
        parts.append(f"## Comments ({len(issue['comments'])})\n\n")
        for comment in issue["comments"]:
            author_name = "Unknown"
            if comment.get("author") and comment["author"].get("display_name"):
                author_name = comment["author"]["display_name"]
            comment_date = "Unknown date"
            if comment.get("created_at"):
                comment_date = self.format_date(comment["created_at"])
            parts.append(
                f"### {author_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
            )

    return "".join(parts)
@staticmethod
def format_date(iso_date: str) -> str:
"""
Format an ISO date string to a more readable format.
Args:
iso_date: ISO format date string
Returns:
Formatted date string
"""
if not iso_date or not isinstance(iso_date, str):
return "Unknown date"
try:
# Jira dates are typically in format: 2023-01-01T12:00:00.000+0000
dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M:%S")
except ValueError:
return iso_date

View file

@ -1,350 +0,0 @@
"""
Jira OAuth Connector.
Handles OAuth-based authentication and token refresh for Jira API access.
Supports both OAuth 2.0 (preferred) and legacy API token authentication.
"""
import logging
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.connectors.jira_connector import JiraConnector
from app.db import SearchSourceConnector
from app.schemas.atlassian_auth_credentials import AtlassianAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption
logger = logging.getLogger(__name__)
class JiraHistoryConnector:
    """
    Jira connector with OAuth support and automatic token refresh.

    This connector uses OAuth 2.0 access tokens to authenticate with the
    Jira API. It automatically refreshes expired tokens when needed.
    Also supports legacy API token authentication for backward compatibility.

    The authentication mode is discovered lazily: the instance starts out
    assuming OAuth and switches to legacy mode the first time
    ``_get_valid_token()`` loads a connector config without OAuth tokens.
    """

    def __init__(
        self,
        session: AsyncSession,
        connector_id: int,
        credentials: AtlassianAuthCredentialsBase | None = None,
    ):
        """
        Initialize the JiraHistoryConnector with auto-refresh capability.

        Args:
            session: Database session for updating connector
            connector_id: Connector ID for direct updates
            credentials: Jira OAuth credentials (optional, will be loaded from DB if not provided)
        """
        self._session = session
        self._connector_id = connector_id
        self._credentials = credentials
        self._cloud_id: str | None = None
        self._base_url: str | None = None
        self._jira_client: JiraConnector | None = None
        # Assumed until the stored config proves otherwise.
        self._use_oauth = True
        self._legacy_email: str | None = None
        self._legacy_api_token: str | None = None

    async def _get_valid_token(self) -> str:
        """
        Get valid Jira access token, refreshing if needed.

        When no credentials were supplied, the connector row is loaded and
        either OAuth credentials (decrypted) or legacy email/API-token values
        are read from its config; ``_use_oauth`` is set accordingly.

        Returns:
            Valid access token, or an empty string in legacy mode

        Raises:
            ValueError: If credentials are missing or invalid
            Exception: If token refresh fails
        """
        # Load credentials from DB if not provided
        if self._credentials is None:
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()

            if not connector:
                raise ValueError(f"Connector {self._connector_id} not found")

            config_data = connector.config.copy()

            # Check if using OAuth or legacy API token
            is_oauth = config_data.get("_token_encrypted", False) or config_data.get(
                "access_token"
            )

            if is_oauth:
                # OAuth 2.0 authentication
                # Check if access_token exists before processing
                raw_access_token = config_data.get("access_token")
                if not raw_access_token:
                    raise ValueError(
                        "Jira access token not found. "
                        "Please reconnect your Jira account."
                    )

                if not config.SECRET_KEY:
                    raise ValueError(
                        "SECRET_KEY not configured but tokens are marked as encrypted"
                    )

                try:
                    token_encryption = TokenEncryption(config.SECRET_KEY)

                    # Decrypt access_token
                    if config_data.get("access_token"):
                        config_data["access_token"] = token_encryption.decrypt_token(
                            config_data["access_token"]
                        )
                        logger.info(
                            f"Decrypted Jira access token for connector {self._connector_id}"
                        )

                    # Decrypt refresh_token if present
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                        logger.info(
                            f"Decrypted Jira refresh token for connector {self._connector_id}"
                        )
                except Exception as e:
                    logger.error(
                        f"Failed to decrypt Jira credentials for connector {self._connector_id}: {e!s}"
                    )
                    raise ValueError(
                        f"Failed to decrypt Jira credentials: {e!s}"
                    ) from e

                # Final validation after decryption
                final_token = config_data.get("access_token")
                if not final_token or (
                    isinstance(final_token, str) and not final_token.strip()
                ):
                    raise ValueError(
                        "Jira access token is invalid or empty. "
                        "Please reconnect your Jira account."
                    )

                try:
                    self._credentials = AtlassianAuthCredentialsBase.from_dict(
                        config_data
                    )
                    self._cloud_id = config_data.get("cloud_id")
                    self._base_url = config_data.get("base_url")
                    self._use_oauth = True
                except Exception as e:
                    raise ValueError(f"Invalid Jira OAuth credentials: {e!s}") from e
            else:
                # Legacy API token authentication
                self._legacy_email = config_data.get("JIRA_EMAIL")
                self._legacy_api_token = config_data.get("JIRA_API_TOKEN")
                self._base_url = config_data.get("JIRA_BASE_URL")
                self._use_oauth = False

                if (
                    not self._legacy_email
                    or not self._legacy_api_token
                    or not self._base_url
                ):
                    raise ValueError("Jira credentials not found in connector config")

        # Check if token is expired and refreshable (only for OAuth)
        if (
            self._use_oauth
            and self._credentials.is_expired
            and self._credentials.is_refreshable
        ):
            try:
                logger.info(
                    f"Jira token expired for connector {self._connector_id}, refreshing..."
                )

                # Get connector for refresh
                result = await self._session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == self._connector_id
                    )
                )
                connector = result.scalars().first()

                if not connector:
                    raise RuntimeError(
                        f"Connector {self._connector_id} not found; cannot refresh token."
                    )

                # Lazy import to avoid circular dependency
                from app.routes.jira_add_connector_route import refresh_jira_token

                connector = await refresh_jira_token(self._session, connector)

                # Reload credentials after refresh
                config_data = connector.config.copy()
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("access_token"):
                        config_data["access_token"] = token_encryption.decrypt_token(
                            config_data["access_token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )

                self._credentials = AtlassianAuthCredentialsBase.from_dict(config_data)
                self._cloud_id = config_data.get("cloud_id")
                self._base_url = config_data.get("base_url")

                # Invalidate cached client so it's recreated with new token
                self._jira_client = None

                logger.info(
                    f"Successfully refreshed Jira token for connector {self._connector_id}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to refresh Jira token for connector {self._connector_id}: {e!s}"
                )
                raise Exception(
                    f"Failed to refresh Jira OAuth credentials: {e!s}"
                ) from e

        if self._use_oauth:
            return self._credentials.access_token
        else:
            # For legacy auth, return empty string (not used for token-based auth)
            return ""

    async def _get_jira_client(self) -> JiraConnector:
        """
        Get or create JiraConnector with valid credentials.

        ``_get_valid_token()`` runs first because it can change state this
        method depends on: it may switch the instance to legacy mode, and a
        token refresh resets the cached client to ``None``. The client is
        therefore created or updated only after that call has settled.

        Returns:
            JiraConnector instance
        """
        if self._use_oauth:
            # Ensure we have valid token (will refresh if needed)
            await self._get_valid_token()

        if self._jira_client is None:
            if self._use_oauth:
                self._jira_client = JiraConnector(
                    base_url=self._base_url,
                    access_token=self._credentials.access_token,
                    cloud_id=self._cloud_id,
                )
            else:
                # Legacy API token authentication
                self._jira_client = JiraConnector(
                    base_url=self._base_url,
                    email=self._legacy_email,
                    api_token=self._legacy_api_token,
                )
        elif self._use_oauth:
            # Keep an existing client in sync with the current token.
            if self._credentials:
                self._jira_client.set_oauth_credentials(
                    base_url=self._base_url or "",
                    access_token=self._credentials.access_token,
                    cloud_id=self._cloud_id,
                )
        elif self._base_url and self._legacy_email and self._legacy_api_token:
            # The cached client may be the credential-less one created by the
            # format_* helpers; give it the legacy credentials before use.
            self._jira_client.set_credentials(
                base_url=self._base_url,
                email=self._legacy_email,
                api_token=self._legacy_api_token,
            )

        return self._jira_client

    async def get_issues_by_date_range(
        self,
        start_date: str,
        end_date: str,
        include_comments: bool = True,
        project_key: str | None = None,
    ) -> tuple[list[dict[str, Any]], str | None]:
        """
        Fetch issues within a date range.

        This method wraps JiraConnector.get_issues_by_date_range() with automatic token refresh.

        Args:
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format (inclusive)
            include_comments: Whether to include comments in the response
            project_key: Optional project key to filter issues

        Returns:
            Tuple containing (issues list, error message or None)
        """
        # Ensure token is valid (will refresh if needed)
        if self._use_oauth:
            await self._get_valid_token()

        # Get client with valid credentials
        client = await self._get_jira_client()

        # JiraConnector methods are synchronous, so we call them directly
        # Token refresh has already been handled above
        return client.get_issues_by_date_range(
            start_date=start_date,
            end_date=end_date,
            include_comments=include_comments,
            project_key=project_key,
        )

    def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
        """
        Format an issue for easier consumption.

        Wraps JiraConnector.format_issue().

        Args:
            issue: The issue object from Jira API

        Returns:
            Formatted issue dictionary
        """
        # This is a synchronous method that doesn't need token refresh
        # since it just formats data that's already been fetched
        if self._jira_client is None:
            # Create a minimal client just for formatting (doesn't need credentials)
            self._jira_client = JiraConnector()
        return self._jira_client.format_issue(issue)

    def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
        """
        Convert an issue to markdown format.

        Wraps JiraConnector.format_issue_to_markdown().

        Args:
            issue: The issue object (either raw or formatted)

        Returns:
            Markdown string representation of the issue
        """
        # This is a synchronous method that doesn't need token refresh
        # since it just formats data that's already been fetched
        if self._jira_client is None:
            # Create a minimal client just for formatting (doesn't need credentials)
            self._jira_client = JiraConnector()
        return self._jira_client.format_issue_to_markdown(issue)

    async def close(self):
        """Close any resources (currently no-op for JiraConnector)."""
        # JiraConnector doesn't maintain persistent connections, so nothing to close
        self._jira_client = None

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

View file

@ -1478,6 +1478,11 @@ class Document(BaseModel, TimestampMixin):
chunks = relationship( chunks = relationship(
"Chunk", back_populates="document", cascade="all, delete-orphan" "Chunk", back_populates="document", cascade="all, delete-orphan"
) )
# Original upload + future derived artifacts (redacted, filled-form).
# Model lives in app.file_storage.persistence to keep that feature cohesive.
files = relationship(
"DocumentFile", back_populates="document", cascade="all, delete-orphan"
)
class DocumentVersion(BaseModel, TimestampMixin): class DocumentVersion(BaseModel, TimestampMixin):
@ -2931,6 +2936,7 @@ from app.automations.persistence import ( # noqa: E402, F401
AutomationRun, AutomationRun,
AutomationTrigger, AutomationTrigger,
) )
from app.file_storage.persistence import DocumentFile # noqa: E402, F401
engine = create_async_engine( engine = create_async_engine(
DATABASE_URL, DATABASE_URL,

View file

@ -0,0 +1,15 @@
"""Durable storage for original uploaded files (and future derived artifacts).
Public surface: resolve the configured backend via :func:`get_storage_backend`
and persist/retrieve a document's files via :mod:`app.file_storage.service`.
"""
from __future__ import annotations
from app.file_storage.backends.base import StorageBackend
from app.file_storage.factory import get_storage_backend
__all__ = [
"StorageBackend",
"get_storage_backend",
]

View file

@ -0,0 +1,89 @@
"""HTTP routes for document file storage (metadata listing + original download)."""
from __future__ import annotations
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, Permission, User, get_async_session
from app.file_storage.persistence.enums import DocumentFileKind
from app.file_storage.schemas import DocumentFileRead
from app.file_storage.service import (
get_document_file,
list_document_files,
open_document_file_stream,
)
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
async def _load_readable_document(
    *, document_id: int, session: AsyncSession, user: User
) -> Document:
    """Fetch a document by ID and verify the user may read it.

    Raises a 404 ``HTTPException`` when no such document exists. The read
    permission check on the document's search space is delegated to
    ``check_permission``.
    """
    query = select(Document).where(Document.id == document_id)
    result = await session.execute(query)
    document = result.scalar_one_or_none()
    if document is None:
        raise HTTPException(status_code=404, detail="Document not found")

    await check_permission(
        session,
        user,
        document.search_space_id,
        Permission.DOCUMENTS_READ.value,
        "You don't have permission to read documents in this search space",
    )
    return document
def _content_disposition(filename: str) -> str:
"""Build an attachment header safe for arbitrary filenames (RFC 5987)."""
fallback = filename.encode("ascii", "ignore").decode("ascii") or "download"
fallback = fallback.replace('"', "")
return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{quote(filename)}"
@router.get(
    "/documents/{document_id}/files",
    response_model=list[DocumentFileRead],
)
async def read_document_files(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> list[DocumentFileRead]:
    """List metadata for all stored files of a document (gates the UI).

    The document must exist and be readable by the current user; otherwise
    the access helper raises before any file records are queried.
    """
    await _load_readable_document(document_id=document_id, session=session, user=user)

    stored_files = await list_document_files(session, document_id=document_id)
    payload: list[DocumentFileRead] = []
    for stored in stored_files:
        payload.append(DocumentFileRead.model_validate(stored))
    return payload
@router.get("/documents/{document_id}/download-original")
async def download_original_document_file(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> StreamingResponse:
    """Stream the original uploaded file of a document as an attachment.

    Responds with 404 when the document has no stored ``ORIGINAL`` file.
    The media type falls back to ``application/octet-stream`` when the record
    has no MIME type.
    """
    await _load_readable_document(document_id=document_id, session=session, user=user)

    original = await get_document_file(
        session, document_id=document_id, kind=DocumentFileKind.ORIGINAL
    )
    if original is None:
        raise HTTPException(
            status_code=404, detail="No original file stored for this document"
        )

    media_type = original.mime_type or "application/octet-stream"
    response_headers = {
        "Content-Disposition": _content_disposition(original.original_filename)
    }
    return StreamingResponse(
        open_document_file_stream(original),
        media_type=media_type,
        headers=response_headers,
    )

View file

@ -0,0 +1,7 @@
"""Storage backend implementations behind the shared :class:`StorageBackend`."""
from __future__ import annotations
from app.file_storage.backends.base import StorageBackend
__all__ = ["StorageBackend"]

View file

@ -0,0 +1,54 @@
"""Azure Blob Storage backend (the first production target)."""
from __future__ import annotations
from collections.abc import AsyncIterator
from app.file_storage.backends.base import StorageBackend
class AzureBlobBackend(StorageBackend):
    """Storage backend that keeps each object as a blob in one Azure container.

    A new ``BlobServiceClient`` is opened for every operation and closed when
    the operation finishes. The Azure SDK is imported lazily inside the
    methods that need it.
    """

    backend_name = "azure"

    def __init__(self, *, connection_string: str, container: str) -> None:
        self._connection_string = connection_string
        self._container = container

    def _service(self):
        """Create a fresh async service client from the connection string."""
        from azure.storage.blob.aio import BlobServiceClient

        return BlobServiceClient.from_connection_string(self._connection_string)

    async def put(
        self, key: str, data: bytes, *, content_type: str | None = None
    ) -> None:
        """Upload ``data`` to blob ``key``, overwriting any existing blob."""
        from azure.storage.blob import ContentSettings

        # Content settings are only attached when a content type was given.
        settings = None
        if content_type:
            settings = ContentSettings(content_type=content_type)

        async with self._service() as client:
            blob_client = client.get_blob_client(self._container, key)
            await blob_client.upload_blob(
                data, overwrite=True, content_settings=settings
            )

    async def open_stream(self, key: str) -> AsyncIterator[bytes]:
        """Yield the blob's bytes chunk by chunk as delivered by the SDK."""
        async with self._service() as client:
            blob_client = client.get_blob_client(self._container, key)
            download = await blob_client.download_blob()
            async for piece in download.chunks():
                yield piece

    async def delete(self, key: str) -> None:
        """Delete blob ``key``; a blob that does not exist is ignored."""
        from azure.core.exceptions import ResourceNotFoundError

        async with self._service() as client:
            blob_client = client.get_blob_client(self._container, key)
            try:
                await blob_client.delete_blob()
            except ResourceNotFoundError:
                # Missing blob: deletion is treated as already done.
                pass

    async def exists(self, key: str) -> bool:
        """Return whether a blob is stored at ``key``."""
        async with self._service() as client:
            blob_client = client.get_blob_client(self._container, key)
            return await blob_client.exists()

View file

@ -0,0 +1,31 @@
"""The storage backend contract: the minimal object-store surface we depend on."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
class StorageBackend(ABC):
    """Abstract object store: maps an opaque object key to durable bytes.

    Concrete backends implement the four operations below and set
    ``backend_name``.
    """

    #: Identifier stored on each row to record which backend holds the bytes.
    backend_name: str

    @abstractmethod
    async def put(
        self, key: str, data: bytes, *, content_type: str | None = None
    ) -> None:
        """Write ``data`` under ``key``, replacing any object already there."""

    @abstractmethod
    def open_stream(self, key: str) -> AsyncIterator[bytes]:
        """Return an async iterator over the object's bytes, in chunks.

        Raises if nothing is stored at ``key``.
        """

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Delete the object under ``key``; deleting a missing key is not an error."""

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Report whether an object is currently stored under ``key``."""

View file

@ -0,0 +1,64 @@
"""Local filesystem backend for development (no cloud credentials required)."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from pathlib import Path
from app.file_storage.backends.base import StorageBackend
_CHUNK_SIZE = 1024 * 1024
class LocalFileBackend(StorageBackend):
    """Storage backend that keeps each object as a file under one root directory.

    All blocking filesystem calls are pushed to a worker thread with
    ``asyncio.to_thread``.
    """

    backend_name = "local"

    def __init__(self, root: str) -> None:
        self._root = Path(root).resolve()

    def _path_for(self, key: str) -> Path:
        """Map ``key`` to a path under the root, rejecting path traversal."""
        candidate = (self._root / key).resolve()
        # The resolved path must be the root itself or lie beneath it.
        if not candidate.is_relative_to(self._root):
            raise ValueError("Resolved storage key escapes the storage root")
        return candidate

    async def put(
        self, key: str, data: bytes, *, content_type: str | None = None
    ) -> None:
        """Write ``data`` to the file for ``key``, creating parent directories."""
        destination = self._path_for(key)

        def _store() -> None:
            destination.parent.mkdir(parents=True, exist_ok=True)
            destination.write_bytes(data)

        await asyncio.to_thread(_store)

    async def open_stream(self, key: str) -> AsyncIterator[bytes]:
        """Yield the file's bytes in chunks of at most ``_CHUNK_SIZE``."""
        source = self._path_for(key)
        stream = await asyncio.to_thread(source.open, "rb")
        try:
            while chunk := await asyncio.to_thread(stream.read, _CHUNK_SIZE):
                yield chunk
        finally:
            # Close the handle even if the consumer stops iterating early.
            await asyncio.to_thread(stream.close)

    async def delete(self, key: str) -> None:
        """Remove the file for ``key``; a missing file is ignored."""
        victim = self._path_for(key)

        def _remove() -> None:
            try:
                victim.unlink()
            except FileNotFoundError:
                pass

        await asyncio.to_thread(_remove)

    async def exists(self, key: str) -> bool:
        """Return whether a file exists for ``key``."""
        location = self._path_for(key)
        return await asyncio.to_thread(location.exists)

View file

@ -0,0 +1,38 @@
"""Resolve the configured :class:`StorageBackend` as a process-wide singleton."""
from __future__ import annotations
from functools import lru_cache
from app.file_storage.backends.base import StorageBackend
from app.file_storage.settings import (
AZURE_BACKEND,
LOCAL_BACKEND,
load_storage_settings,
)
@lru_cache(maxsize=1)
def get_storage_backend() -> StorageBackend:
    """Build the backend selected by ``FILE_STORAGE_BACKEND`` (lazy-imported).

    The result is cached, so the backend is constructed once per process.
    Raises ``ValueError`` for an unknown backend name or for an Azure
    configuration missing its connection string or container.
    """
    settings = load_storage_settings()
    selected = settings.backend

    if selected == AZURE_BACKEND:
        azure_ready = settings.azure_connection_string and settings.azure_container
        if not azure_ready:
            raise ValueError(
                "Azure storage requires AZURE_STORAGE_CONNECTION_STRING and "
                "AZURE_STORAGE_CONTAINER."
            )
        # Imported here so the backend module is only loaded when selected.
        from app.file_storage.backends.azure import AzureBlobBackend

        return AzureBlobBackend(
            connection_string=settings.azure_connection_string,
            container=settings.azure_container,
        )

    if selected == LOCAL_BACKEND:
        from app.file_storage.backends.local import LocalFileBackend

        return LocalFileBackend(settings.local_root)

    raise ValueError(f"Unknown FILE_STORAGE_BACKEND: {settings.backend!r}")

View file

@ -0,0 +1,27 @@
"""Object-key construction for stored document files."""
from __future__ import annotations
import os
import uuid
from app.file_storage.persistence.enums import DocumentFileKind
def build_document_file_key(
    *,
    search_space_id: int,
    document_id: int,
    kind: DocumentFileKind,
    filename: str,
) -> str:
    """Build the storage key for one document file.

    Shape: ``documents/{search_space_id}/{document_id}/{kind}/{uuid}{ext}``.

    ``filename`` comes from the uploader, so its extension is treated as
    untrusted. It is kept only when it is a short, ASCII, alphanumeric suffix
    (e.g. ``.pdf``, ``.docx``, ``.7z``). Anything else (NUL bytes, backslashes,
    whitespace, control characters, over-long suffixes) is dropped so the key
    stays valid for every backend. The random UUID component keeps the key
    unique either way.

    Args:
        search_space_id: Search space that owns the document.
        document_id: Document the file belongs to.
        kind: File kind; its lower-cased value becomes a path segment.
        filename: Original filename; only its extension is used.

    Returns:
        The object key as a ``/``-separated string.
    """
    max_suffix_length = 16

    extension = os.path.splitext(filename)[1].lower()
    suffix = extension[1:]  # extension without the leading dot
    is_safe = (
        0 < len(suffix) <= max_suffix_length
        and suffix.isascii()
        and suffix.isalnum()
    )
    if not is_safe:
        extension = ""

    unique = uuid.uuid4().hex
    return (
        f"documents/{search_space_id}/{document_id}/"
        f"{kind.value.lower()}/{unique}{extension}"
    )

View file

@ -0,0 +1,11 @@
"""Models and enums for the document file-storage tables."""
from __future__ import annotations
from .enums import DocumentFileKind
from .models import DocumentFile
__all__ = [
"DocumentFile",
"DocumentFileKind",
]

View file

@ -0,0 +1,11 @@
"""DocumentFile kinds: the original upload plus future derived artifacts."""
from __future__ import annotations
from enum import StrEnum
class DocumentFileKind(StrEnum):
ORIGINAL = "ORIGINAL"
REDACTED = "REDACTED"
FILLED_FORM = "FILLED_FORM"

View file

@ -0,0 +1,66 @@
"""``document_files`` table: durable blobs associated with a document."""
from __future__ import annotations
from sqlalchemy import (
BigInteger,
Column,
Enum as SQLAlchemyEnum,
ForeignKey,
Integer,
String,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import BaseModel, TimestampMixin
from .enums import DocumentFileKind
class DocumentFile(BaseModel, TimestampMixin):
    """A single stored file belonging to a document.

    Holds the metadata for the original upload or a derived copy; the bytes
    themselves live in the storage backend under ``storage_key``.
    """

    __tablename__ = "document_files"

    # Owning document; the row is removed when the document is deleted.
    document_id = Column(
        Integer,
        ForeignKey("documents.id", ondelete="CASCADE"),
        index=True,
        nullable=False,
    )
    # Owning search space; the row is removed when the search space is deleted.
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        index=True,
        nullable=False,
    )
    # Kind of file, persisted using the enum *values*; defaults to ORIGINAL
    # both in Python and at the database level.
    kind = Column(
        SQLAlchemyEnum(
            DocumentFileKind,
            name="document_file_kind",
            values_callable=lambda members: [member.value for member in members],
        ),
        default=DocumentFileKind.ORIGINAL,
        server_default=DocumentFileKind.ORIGINAL.value,
        index=True,
        nullable=False,
    )

    # Location of the bytes: which backend stored them, and under what key.
    storage_backend = Column(String(32), nullable=False)
    storage_key = Column(String, nullable=False)

    # Descriptive metadata captured at store time.
    original_filename = Column(String, nullable=False)
    mime_type = Column(String, nullable=True)
    size_bytes = Column(BigInteger, nullable=False)
    checksum_sha256 = Column(String(64), nullable=True)

    # Uploading user; set to NULL if that user is later deleted.
    created_by_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="SET NULL"),
        index=True,
        nullable=True,
    )

    document = relationship("Document", back_populates="files")

View file

@ -0,0 +1,23 @@
"""API shapes for document file metadata."""
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from app.file_storage.persistence.enums import DocumentFileKind
class DocumentFileRead(BaseModel):
    """Read-only metadata for a stored document file; carries no file bytes."""

    # Allow construction from attribute-bearing objects (e.g. ORM rows).
    model_config = ConfigDict(from_attributes=True)

    id: int
    document_id: int
    kind: DocumentFileKind
    original_filename: str
    mime_type: str | None = None
    size_bytes: int
    created_at: datetime

View file

@ -0,0 +1,129 @@
"""Application service: persist, locate, and remove a document's stored files.
Coordinates the storage backend (bytes) with the ``document_files`` table
(metadata). Callers own the surrounding DB transaction/commit.
"""
from __future__ import annotations
import hashlib
import logging
from collections.abc import AsyncIterator, Sequence
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.file_storage.backends.base import StorageBackend
from app.file_storage.factory import get_storage_backend
from app.file_storage.keys import build_document_file_key
from app.file_storage.persistence.enums import DocumentFileKind
from app.file_storage.persistence.models import DocumentFile
logger = logging.getLogger(__name__)
async def store_document_file(
    session: AsyncSession,
    *,
    document_id: int,
    search_space_id: int,
    data: bytes,
    filename: str,
    mime_type: str | None = None,
    kind: DocumentFileKind = DocumentFileKind.ORIGINAL,
    created_by_id: str | UUID | None = None,
    backend: StorageBackend | None = None,
) -> DocumentFile:
    """Store ``data`` in the backend and register a matching ``DocumentFile``.

    The bytes are written first, then the metadata row (size and SHA-256
    included) is added to ``session``. Nothing is flushed or committed here;
    the caller owns the transaction.

    Args:
        session: Session the new row is added to.
        document_id: Document the file belongs to.
        search_space_id: Search space that owns the document.
        data: Raw file content.
        filename: Original filename; recorded and used for the key extension.
        mime_type: Optional content type, passed to the backend and recorded.
        kind: File kind; defaults to the original upload.
        created_by_id: Optional id of the uploading user.
        backend: Storage backend override; defaults to the configured one.

    Returns:
        The pending ``DocumentFile`` row.
    """
    storage = backend if backend else get_storage_backend()
    storage_key = build_document_file_key(
        search_space_id=search_space_id,
        document_id=document_id,
        kind=kind,
        filename=filename,
    )
    await storage.put(storage_key, data, content_type=mime_type)

    row_fields = {
        "document_id": document_id,
        "search_space_id": search_space_id,
        "kind": kind,
        "storage_backend": storage.backend_name,
        "storage_key": storage_key,
        "original_filename": filename,
        "mime_type": mime_type,
        "size_bytes": len(data),
        "checksum_sha256": hashlib.sha256(data).hexdigest(),
        "created_by_id": created_by_id,
    }
    stored_file = DocumentFile(**row_fields)
    session.add(stored_file)
    return stored_file
async def list_document_files(
    session: AsyncSession, *, document_id: int
) -> list[DocumentFile]:
    """Fetch every stored file of a document, ordered by ``created_at`` descending."""
    newest_first = DocumentFile.created_at.desc()
    query = (
        select(DocumentFile)
        .where(DocumentFile.document_id == document_id)
        .order_by(newest_first)
    )
    rows = await session.execute(query)
    return list(rows.scalars().all())
async def get_document_file(
    session: AsyncSession,
    *,
    document_id: int,
    kind: DocumentFileKind = DocumentFileKind.ORIGINAL,
) -> DocumentFile | None:
    """Return the most recent stored file of ``kind`` for a document.

    Args:
        session: Session used to run the query.
        document_id: Document whose file is wanted.
        kind: File kind to look up; defaults to the original upload.

    Returns:
        The row with the latest ``created_at`` among the matches, or ``None``
        when the document has no file of that kind.
    """
    result = await session.execute(
        select(DocumentFile)
        .where(
            DocumentFile.document_id == document_id,
            DocumentFile.kind == kind,
        )
        .order_by(DocumentFile.created_at.desc())
        # Only one row is consumed, so let the database stop after the first
        # match instead of returning every file of this kind.
        .limit(1)
    )
    return result.scalars().first()
def open_document_file_stream(
    record: DocumentFile, *, backend: StorageBackend | None = None
) -> AsyncIterator[bytes]:
    """Return the backend's byte stream for ``record.storage_key``.

    Uses ``backend`` when one is supplied, otherwise the configured backend.
    """
    storage = backend if backend else get_storage_backend()
    stream = storage.open_stream(record.storage_key)
    return stream
async def purge_document_blobs(
    session: AsyncSession,
    *,
    document_ids: Sequence[int],
    backend: StorageBackend | None = None,
) -> None:
    """Remove the stored blobs that belong to the given documents.

    Must run while the ``document_files`` rows still exist (they cascade away
    with the document), because the storage keys are read from those rows.
    Deletion is best-effort: a blob that cannot be deleted is logged as a
    warning and skipped, so an orphaned blob never blocks document deletion.
    An empty ``document_ids`` is a no-op.
    """
    if not document_ids:
        return

    storage = backend or get_storage_backend()
    key_query = select(DocumentFile.storage_key).where(
        DocumentFile.document_id.in_(document_ids)
    )
    key_rows = await session.execute(key_query)

    for blob_key in key_rows.scalars().all():
        try:
            await storage.delete(blob_key)
        except Exception as delete_error:
            logger.warning(
                "Failed to delete stored blob %s: %s", blob_key, delete_error
            )

View file

@ -0,0 +1,37 @@
"""Environment-driven configuration for the file-storage module."""
from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
LOCAL_BACKEND = "local"
AZURE_BACKEND = "azure"
# surfsense_backend/ — two levels up from app/file_storage/settings.py
_BACKEND_ROOT = Path(__file__).resolve().parents[2]
_DEFAULT_LOCAL_ROOT = str(_BACKEND_ROOT / ".local_object_store")
@dataclass(frozen=True)
class StorageSettings:
"""Resolved storage configuration for the current process."""
backend: str
azure_connection_string: str | None
azure_container: str | None
local_root: str
def load_storage_settings() -> StorageSettings:
    """Read storage settings from the environment.

    Defaults to the ``local`` backend so development needs no cloud creds.

    A variable that is set but empty or whitespace-only (for example a bare
    ``FILE_STORAGE_BACKEND=`` line in a ``.env`` file) is treated as unset.
    The backend and local path then fall back to their defaults, and the
    Azure values become ``None``.

    Returns:
        A ``StorageSettings`` with the backend name lower-cased.
    """

    def _env(name: str) -> str | None:
        # Normalise "unset", "" and "   " to None so fallbacks apply uniformly.
        raw = os.getenv(name)
        if raw is None:
            return None
        return raw.strip() or None

    return StorageSettings(
        backend=(_env("FILE_STORAGE_BACKEND") or LOCAL_BACKEND).lower(),
        azure_connection_string=_env("AZURE_STORAGE_CONNECTION_STRING"),
        azure_container=_env("AZURE_STORAGE_CONTAINER"),
        local_root=_env("FILE_STORAGE_LOCAL_PATH") or _DEFAULT_LOCAL_ROOT,
    )

View file

@ -1,6 +1,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.automations.api import router as automations_router from app.automations.api import router as automations_router
from app.file_storage.api import router as file_storage_router
from .agent_action_log_route import router as agent_action_log_router from .agent_action_log_route import router as agent_action_log_router
from .agent_flags_route import router as agent_flags_router from .agent_flags_route import router as agent_flags_router
@ -126,3 +127,4 @@ router.include_router(prompts_router)
router.include_router(memory_router) # User personal memory (memory.md style) router.include_router(memory_router) # User personal memory (memory.md style)
router.include_router(team_memory_router) # Search-space team memory router.include_router(team_memory_router) # Search-space team memory
router.include_router(automations_router) # Automations CRUD + run history router.include_router(automations_router) # Automations CRUD + run history
router.include_router(file_storage_router) # Original file metadata + download

View file

@ -44,10 +44,12 @@ except RuntimeError as e:
print("Error setting event loop policy", e) print("Error setting event loop policy", e)
pass pass
import logging
import os import os
os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1" os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -142,9 +144,11 @@ async def create_documents_file_upload(
import os import os
import tempfile import tempfile
from datetime import datetime from datetime import datetime
from pathlib import Path
from app.db import DocumentStatus from app.db import DocumentStatus
from app.etl_pipeline.etl_document import ProcessingMode from app.etl_pipeline.etl_document import ProcessingMode
from app.file_storage.service import store_document_file
from app.tasks.document_processors.base import ( from app.tasks.document_processors.base import (
check_document_by_unique_identifier, check_document_by_unique_identifier,
get_current_timestamp, get_current_timestamp,
@ -175,11 +179,12 @@ async def create_documents_file_upload(
) )
# ===== Read all files concurrently to avoid blocking the event loop ===== # ===== Read all files concurrently to avoid blocking the event loop =====
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]: async def _read_and_save(file: UploadFile) -> tuple[str, str, int, str | None]:
"""Read upload content and write to temp file off the event loop.""" """Read upload content and write to temp file off the event loop."""
content = await file.read() content = await file.read()
file_size = len(content) file_size = len(content)
filename = file.filename or "unknown" filename = file.filename or "unknown"
content_type = file.content_type
if file_size > MAX_FILE_SIZE_BYTES: if file_size > MAX_FILE_SIZE_BYTES:
raise HTTPException( raise HTTPException(
@ -196,17 +201,18 @@ async def create_documents_file_upload(
return tmp.name return tmp.name
temp_path = await asyncio.to_thread(_write_temp) temp_path = await asyncio.to_thread(_write_temp)
return temp_path, filename, file_size return temp_path, filename, file_size, content_type
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files)) saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
# ===== PHASE 1: Create pending documents for all files ===== # ===== PHASE 1: Create pending documents for all files =====
created_documents: list[Document] = [] created_documents: list[Document] = []
files_to_process: list[tuple[Document, str, str]] = [] # (document, temp_path, filename, content_type)
files_to_process: list[tuple[Document, str, str, str | None]] = []
skipped_duplicates = 0 skipped_duplicates = 0
duplicate_document_ids: list[int] = [] duplicate_document_ids: list[int] = []
for temp_path, filename, file_size in saved_files: for temp_path, filename, file_size, content_type in saved_files:
try: try:
unique_identifier_hash = generate_unique_identifier_hash( unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, filename, search_space_id DocumentType.FILE, filename, search_space_id
@ -231,7 +237,9 @@ async def create_documents_file_upload(
} }
existing.updated_at = get_current_timestamp() existing.updated_at = get_current_timestamp()
created_documents.append(existing) created_documents.append(existing)
files_to_process.append((existing, temp_path, filename)) files_to_process.append(
(existing, temp_path, filename, content_type)
)
continue continue
document = Document( document = Document(
@ -253,7 +261,7 @@ async def create_documents_file_upload(
) )
session.add(document) session.add(document)
created_documents.append(document) created_documents.append(document)
files_to_process.append((document, temp_path, filename)) files_to_process.append((document, temp_path, filename, content_type))
except HTTPException: except HTTPException:
raise raise
@ -269,8 +277,32 @@ async def create_documents_file_upload(
for doc in created_documents: for doc in created_documents:
await session.refresh(doc) await session.refresh(doc)
# ===== PHASE 1.5: Persist the original uploads to durable storage =====
# Best-effort: a storage failure must not block parsing or the response.
for document, temp_path, filename, content_type in files_to_process:
try:
original_bytes = await asyncio.to_thread(
lambda p=temp_path: Path(p).read_bytes()
)
await store_document_file(
session,
document_id=document.id,
search_space_id=search_space_id,
data=original_bytes,
filename=filename,
mime_type=content_type,
created_by_id=str(user.id),
)
except Exception as storage_error:
logger.warning(
"Failed to store original upload for document %s: %s",
document.id,
storage_error,
)
await session.commit()
# ===== PHASE 2: Dispatch tasks for each file ===== # ===== PHASE 2: Dispatch tasks for each file =====
for document, temp_path, filename in files_to_process: for document, temp_path, filename, _content_type in files_to_process:
await dispatcher.dispatch_file_processing( await dispatcher.dispatch_file_processing(
document_id=document.id, document_id=document.id,
temp_path=temp_path, temp_path=temp_path,

View file

@ -1,13 +0,0 @@
from app.services.jira.kb_sync_service import JiraKBSyncService
from app.services.jira.tool_metadata_service import (
JiraIssue,
JiraToolMetadataService,
JiraWorkspace,
)
__all__ = [
"JiraIssue",
"JiraKBSyncService",
"JiraToolMetadataService",
"JiraWorkspace",
]

View file

@ -1,257 +0,0 @@
import asyncio
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentType
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
class JiraKBSyncService:
    """Syncs Jira issue documents to the knowledge base after HITL actions.

    Both public methods report their outcome as a dict with a ``"status"`` key
    instead of raising: exceptions are caught, the session is rolled back, and
    an ``{"status": "error", "message": ...}`` dict is returned.
    """

    def __init__(self, db_session: AsyncSession):
        # Session used for every query, flush, commit and rollback below.
        self.db_session = db_session

    async def sync_after_create(
        self,
        issue_id: str,
        issue_identifier: str,
        issue_title: str,
        description: str | None,
        state: str | None,
        connector_id: int,
        search_space_id: int,
        user_id: str,
    ) -> dict:
        """Create and commit a knowledge-base ``Document`` for a Jira issue.

        Returns ``{"status": "success"}`` when the document is created, and
        also when a document with the same unique-identifier hash already
        exists (nothing is written in that case). On failure the session is
        rolled back and ``{"status": "error", "message": ...}`` is returned.
        """
        from app.tasks.connector_indexers.base import (
            check_document_by_unique_identifier,
            check_duplicate_document_by_hash,
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            unique_hash = generate_unique_identifier_hash(
                DocumentType.JIRA_CONNECTOR, issue_id, search_space_id
            )
            existing = await check_document_by_unique_identifier(
                self.db_session, unique_hash
            )
            if existing:
                logger.info(
                    "Document for Jira issue %s already exists (doc_id=%s), skipping",
                    issue_identifier,
                    existing.id,
                )
                return {"status": "success"}

            # Fall back to a one-line description when the issue body is
            # empty or whitespace-only.
            indexable_content = (description or "").strip()
            if not indexable_content:
                indexable_content = f"Jira Issue {issue_identifier}: {issue_title}"
            issue_content = (
                f"# {issue_identifier}: {issue_title}\n\n{indexable_content}"
            )
            content_hash = generate_content_hash(issue_content, search_space_id)
            with self.db_session.no_autoflush:
                dup = await check_duplicate_document_by_hash(
                    self.db_session, content_hash
                )
            if dup:
                # Another document already has this content hash; store the
                # unique-identifier hash instead (presumably to avoid a
                # uniqueness conflict on content_hash — TODO confirm).
                content_hash = unique_hash

            from app.services.llm_service import get_user_long_context_llm

            user_llm = await get_user_long_context_llm(
                self.db_session,
                user_id,
                search_space_id,
                disable_streaming=True,
            )
            doc_metadata_for_summary = {
                "issue_id": issue_identifier,
                "issue_title": issue_title,
                "document_type": "Jira Issue",
                "connector_type": "Jira",
            }
            if user_llm:
                summary_content, summary_embedding = await generate_document_summary(
                    issue_content, user_llm, doc_metadata_for_summary
                )
            else:
                # No LLM available: use the raw content as the summary and
                # embed it in a worker thread.
                summary_content = (
                    f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
                )
                summary_embedding = await asyncio.to_thread(embed_text, summary_content)

            chunks = await create_document_chunks(issue_content)
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            document = Document(
                title=f"{issue_identifier}: {issue_title}",
                document_type=DocumentType.JIRA_CONNECTOR,
                document_metadata={
                    "issue_id": issue_id,
                    "issue_identifier": issue_identifier,
                    "issue_title": issue_title,
                    "state": state or "Unknown",
                    "indexed_at": now_str,
                    "connector_id": connector_id,
                },
                content=summary_content,
                content_hash=content_hash,
                unique_identifier_hash=unique_hash,
                embedding=summary_embedding,
                search_space_id=search_space_id,
                connector_id=connector_id,
                updated_at=get_current_timestamp(),
                created_by_id=user_id,
            )
            self.db_session.add(document)
            # Flush before attaching chunks so the document is sent to the
            # database first.
            await self.db_session.flush()
            await safe_set_chunks(self.db_session, document, chunks)
            await self.db_session.commit()
            logger.info(
                "KB sync after create succeeded: doc_id=%s, issue=%s",
                document.id,
                issue_identifier,
            )
            return {"status": "success"}
        except Exception as e:
            # Unique-constraint violations are detected by message text and
            # reported as a duplicate without logging an error.
            error_str = str(e).lower()
            if (
                "duplicate key value violates unique constraint" in error_str
                or "uniqueviolationerror" in error_str
            ):
                await self.db_session.rollback()
                return {"status": "error", "message": "Duplicate document detected"}
            logger.error(
                "KB sync after create failed for issue %s: %s",
                issue_identifier,
                e,
                exc_info=True,
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

    async def sync_after_update(
        self,
        document_id: int,
        issue_id: str,
        user_id: str,
        search_space_id: int,
    ) -> dict:
        """Re-fetch a Jira issue and refresh its existing ``Document`` in place.

        Returns ``{"status": "not_indexed"}`` when no document with
        ``document_id`` exists, ``{"status": "success"}`` after a committed
        refresh, and ``{"status": "error", "message": ...}`` when the document
        has no connector, the issue renders to empty content, or any exception
        occurs (the session is rolled back in the exception case).
        """
        from app.tasks.connector_indexers.base import (
            get_current_timestamp,
            safe_set_chunks,
        )

        try:
            document = await self.db_session.get(Document, document_id)
            if not document:
                return {"status": "not_indexed"}
            connector_id = document.connector_id
            if not connector_id:
                return {"status": "error", "message": "Document has no connector_id"}

            jira_history = JiraHistoryConnector(
                session=self.db_session, connector_id=connector_id
            )
            jira_client = await jira_history._get_jira_client()
            # The Jira fetch is run in a worker thread.
            issue_raw = await asyncio.to_thread(jira_client.get_issue, issue_id)
            formatted = jira_client.format_issue(issue_raw)
            issue_content = jira_client.format_issue_to_markdown(formatted)
            if not issue_content:
                return {"status": "error", "message": "Issue produced empty content"}

            issue_identifier = formatted.get("key", "")
            issue_title = formatted.get("title", "")
            state = formatted.get("status", "Unknown")
            comment_count = len(formatted.get("comments", []))

            from app.services.llm_service import get_user_long_context_llm

            user_llm = await get_user_long_context_llm(
                self.db_session, user_id, search_space_id, disable_streaming=True
            )
            if user_llm:
                doc_meta = {
                    "issue_key": issue_identifier,
                    "issue_title": issue_title,
                    "status": state,
                    "document_type": "Jira Issue",
                    "connector_type": "Jira",
                }
                summary_content, summary_embedding = await generate_document_summary(
                    issue_content, user_llm, doc_meta
                )
            else:
                # No LLM available: use the raw content as the summary.
                summary_content = (
                    f"Jira Issue {issue_identifier}: {issue_title}\n\n{issue_content}"
                )
                summary_embedding = await asyncio.to_thread(embed_text, summary_content)

            chunks = await create_document_chunks(issue_content)
            document.title = f"{issue_identifier}: {issue_title}"
            document.content = summary_content
            document.content_hash = generate_content_hash(
                issue_content, search_space_id
            )
            document.embedding = summary_embedding

            from sqlalchemy.orm.attributes import flag_modified

            # Existing metadata keys are preserved; the keys below overwrite
            # any previous values.
            document.document_metadata = {
                **(document.document_metadata or {}),
                "issue_id": issue_id,
                "issue_identifier": issue_identifier,
                "issue_title": issue_title,
                "state": state,
                "comment_count": comment_count,
                "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "connector_id": connector_id,
            }
            # Explicitly mark the metadata attribute as changed.
            flag_modified(document, "document_metadata")
            await safe_set_chunks(self.db_session, document, chunks)
            document.updated_at = get_current_timestamp()
            await self.db_session.commit()
            logger.info(
                "KB sync successful for document %s (%s: %s)",
                document_id,
                issue_identifier,
                issue_title,
            )
            return {"status": "success"}
        except Exception as e:
            logger.error(
                "KB sync failed for document %s: %s", document_id, e, exc_info=True
            )
            await self.db_session.rollback()
            return {"status": "error", "message": str(e)}

View file

@ -1,332 +0,0 @@
import asyncio
import logging
from dataclasses import dataclass
from sqlalchemy import and_, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.connectors.jira_history import JiraHistoryConnector
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
logger = logging.getLogger(__name__)
@dataclass
class JiraWorkspace:
    """A Jira connector presented as a workspace in tool context."""

    id: int
    name: str
    base_url: str

    @classmethod
    def from_connector(cls, connector: SearchSourceConnector) -> "JiraWorkspace":
        """Build a workspace from a connector; ``base_url`` defaults to ``""``."""
        workspace = cls(
            id=connector.id,
            name=connector.name,
            base_url=connector.config.get("base_url", ""),
        )
        return workspace

    def to_dict(self) -> dict:
        """Return the three fields as a plain dict, in declaration order."""
        field_names = ("id", "name", "base_url")
        return {field_name: getattr(self, field_name) for field_name in field_names}
@dataclass
class JiraIssue:
    """An indexed Jira issue, reconstructed from a knowledge-base document."""

    issue_id: str
    issue_identifier: str
    issue_title: str
    state: str
    connector_id: int
    document_id: int
    indexed_at: str | None

    @classmethod
    def from_document(cls, document: Document) -> "JiraIssue":
        """Build an issue from a document's metadata.

        Missing string fields default to ``""``, except ``issue_title`` which
        falls back to the document title and ``indexed_at`` which stays
        ``None``.
        """
        metadata = document.document_metadata or {}
        lookup = metadata.get
        return cls(
            issue_id=lookup("issue_id", ""),
            issue_identifier=lookup("issue_identifier", ""),
            issue_title=lookup("issue_title", document.title),
            state=lookup("state", ""),
            connector_id=document.connector_id,
            document_id=document.id,
            indexed_at=lookup("indexed_at"),
        )

    def to_dict(self) -> dict:
        """Return all fields as a plain dict, in declaration order."""
        field_names = (
            "issue_id",
            "issue_identifier",
            "issue_title",
            "state",
            "connector_id",
            "document_id",
            "indexed_at",
        )
        return {field_name: getattr(self, field_name) for field_name in field_names}
class JiraToolMetadataService:
    """Builds interrupt context for Jira HITL tools.

    The public ``get_*_context`` methods return plain dicts; failures are
    reported through an ``"error"`` key rather than raised.
    """

    def __init__(self, db_session: AsyncSession):
        # Session used for connector/document lookups and for persisting the
        # ``auth_expired`` flag.
        self._db_session = db_session

    async def _check_account_health(self, connector: SearchSourceConnector) -> bool:
        """Check if the Jira connector auth is still valid.

        Returns True if auth is expired/invalid, False if healthy.

        Any exception from building the client or calling ``get_myself`` is
        treated as unhealthy. In that case ``auth_expired: True`` is written
        to the connector config and committed; a failure to persist the flag
        is logged and does not change the return value.
        """
        try:
            jira_history = JiraHistoryConnector(
                session=self._db_session, connector_id=connector.id
            )
            jira_client = await jira_history._get_jira_client()
            # The probe call is run in a worker thread.
            await asyncio.to_thread(jira_client.get_myself)
            return False
        except Exception as e:
            logger.warning("Jira connector %s health check failed: %s", connector.id, e)
            try:
                connector.config = {**connector.config, "auth_expired": True}
                # Explicitly mark the config attribute as changed.
                flag_modified(connector, "config")
                await self._db_session.commit()
                await self._db_session.refresh(connector)
            except Exception:
                logger.warning(
                    "Failed to persist auth_expired for connector %s",
                    connector.id,
                    exc_info=True,
                )
            return True

    async def get_creation_context(self, search_space_id: int, user_id: str) -> dict:
        """Return context needed to create a new Jira issue.

        Fetches all connected Jira accounts, and for the first healthy one
        fetches projects, issue types, and priorities.

        Returns ``{"error": ...}`` when no Jira connector is found. Otherwise
        returns ``accounts`` (each with its ``auth_expired`` flag) plus
        ``projects``, ``issue_types`` and ``priorities``, which stay empty if
        no healthy account could be queried.
        """
        connectors = await self._get_all_jira_connectors(search_space_id, user_id)
        if not connectors:
            return {"error": "No Jira account connected"}

        accounts = []
        projects = []
        issue_types = []
        priorities = []
        # Set once one healthy account has supplied the lists, so later
        # accounts are only health-checked.
        fetched_context = False

        for connector in connectors:
            auth_expired = await self._check_account_health(connector)
            workspace = JiraWorkspace.from_connector(connector)
            account_info = {
                **workspace.to_dict(),
                "auth_expired": auth_expired,
            }
            accounts.append(account_info)

            if not auth_expired and not fetched_context:
                try:
                    jira_history = JiraHistoryConnector(
                        session=self._db_session, connector_id=connector.id
                    )
                    jira_client = await jira_history._get_jira_client()
                    raw_projects = await asyncio.to_thread(jira_client.get_projects)
                    projects = [
                        {"id": p.get("id"), "key": p.get("key"), "name": p.get("name")}
                        for p in raw_projects
                    ]
                    raw_types = await asyncio.to_thread(jira_client.get_issue_types)
                    # Skip subtask types and keep only the first entry seen
                    # for each type name.
                    seen_type_names: set[str] = set()
                    issue_types = []
                    for t in raw_types:
                        if t.get("subtask", False):
                            continue
                        name = t.get("name")
                        if name not in seen_type_names:
                            seen_type_names.add(name)
                            issue_types.append({"id": t.get("id"), "name": name})
                    raw_priorities = await asyncio.to_thread(jira_client.get_priorities)
                    priorities = [
                        {"id": p.get("id"), "name": p.get("name")}
                        for p in raw_priorities
                    ]
                    fetched_context = True
                except Exception as e:
                    # A fetch failure leaves fetched_context False, so the
                    # next healthy account is tried.
                    logger.warning(
                        "Failed to fetch Jira context for connector %s: %s",
                        connector.id,
                        e,
                    )

        return {
            "accounts": accounts,
            "projects": projects,
            "issue_types": issue_types,
            "priorities": priorities,
        }

    async def get_update_context(
        self, search_space_id: int, user_id: str, issue_ref: str
    ) -> dict:
        """Return context needed to update an indexed Jira issue.

        Resolves the issue from the KB, then fetches current details from the Jira API.

        Returns an ``{"error": ...}`` dict when the issue is not indexed, the
        connector is missing or not owned by the user, authentication has
        expired, or the Jira fetch fails. Auth-related errors additionally
        carry ``auth_expired`` and ``connector_id``.
        """
        document = await self._resolve_issue(search_space_id, user_id, issue_ref)
        if not document:
            return {
                "error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
                "Please make sure the issue is indexed in your knowledge base."
            }

        connector = await self._get_connector_for_document(document, user_id)
        if not connector:
            return {"error": "Connector not found or access denied"}

        auth_expired = await self._check_account_health(connector)
        if auth_expired:
            return {
                "error": "Jira authentication has expired. Please re-authenticate.",
                "auth_expired": True,
                "connector_id": connector.id,
            }

        workspace = JiraWorkspace.from_connector(connector)
        issue = JiraIssue.from_document(document)

        try:
            jira_history = JiraHistoryConnector(
                session=self._db_session, connector_id=connector.id
            )
            jira_client = await jira_history._get_jira_client()
            issue_data = await asyncio.to_thread(jira_client.get_issue, issue.issue_id)
            formatted = jira_client.format_issue(issue_data)
        except Exception as e:
            # Auth failures are recognised by substring match on the error text.
            error_str = str(e).lower()
            if (
                "401" in error_str
                or "403" in error_str
                or "authentication" in error_str
            ):
                return {
                    "error": f"Failed to fetch Jira issue: {e!s}",
                    "auth_expired": True,
                    "connector_id": connector.id,
                }
            return {"error": f"Failed to fetch Jira issue: {e!s}"}

        # Live values from Jira take precedence; indexed values are fallbacks.
        return {
            "account": {**workspace.to_dict(), "auth_expired": False},
            "issue": {
                "issue_id": formatted.get("key", issue.issue_id),
                "issue_identifier": formatted.get("key", issue.issue_identifier),
                "issue_title": formatted.get("title", issue.issue_title),
                "state": formatted.get("status", "Unknown"),
                "priority": formatted.get("priority", "Unknown"),
                "issue_type": formatted.get("issue_type", "Unknown"),
                "assignee": formatted.get("assignee"),
                "description": formatted.get("description"),
                "project": formatted.get("project", ""),
                "document_id": issue.document_id,
                "indexed_at": issue.indexed_at,
            },
        }

    async def get_deletion_context(
        self, search_space_id: int, user_id: str, issue_ref: str
    ) -> dict:
        """Return context needed to delete a Jira issue (KB metadata only, no API call).

        ``auth_expired`` is read from the stored connector config rather than
        from a live health check.
        """
        document = await self._resolve_issue(search_space_id, user_id, issue_ref)
        if not document:
            return {
                "error": f"Issue '{issue_ref}' not found in your synced Jira issues. "
                "Please make sure the issue is indexed in your knowledge base."
            }

        connector = await self._get_connector_for_document(document, user_id)
        if not connector:
            return {"error": "Connector not found or access denied"}

        auth_expired = connector.config.get("auth_expired", False)
        workspace = JiraWorkspace.from_connector(connector)
        issue = JiraIssue.from_document(document)
        return {
            "account": {**workspace.to_dict(), "auth_expired": auth_expired},
            "issue": issue.to_dict(),
        }

    async def _resolve_issue(
        self, search_space_id: int, user_id: str, issue_ref: str
    ) -> Document | None:
        """Resolve an issue from KB: issue_identifier -> issue_title -> document.title.

        Matching is case-insensitive, and a document matches if ANY of the
        three fields equals ``issue_ref``. The most recently updated match is
        returned, or ``None`` when nothing matches. Only Jira documents in the
        given search space whose connector belongs to ``user_id`` are
        considered.
        """
        ref_lower = issue_ref.lower()
        result = await self._db_session.execute(
            select(Document)
            .join(
                SearchSourceConnector, Document.connector_id == SearchSourceConnector.id
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type == DocumentType.JIRA_CONNECTOR,
                    SearchSourceConnector.user_id == user_id,
                    or_(
                        func.lower(
                            Document.document_metadata.op("->>")("issue_identifier")
                        )
                        == ref_lower,
                        func.lower(Document.document_metadata.op("->>")("issue_title"))
                        == ref_lower,
                        func.lower(Document.title) == ref_lower,
                    ),
                )
            )
            # Newest first; rows with no updated_at sort last.
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        return result.scalars().first()

    async def _get_all_jira_connectors(
        self, search_space_id: int, user_id: str
    ) -> list[SearchSourceConnector]:
        """Return the user's Jira connectors in the given search space."""
        result = await self._db_session.execute(
            select(SearchSourceConnector).filter(
                and_(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.JIRA_CONNECTOR,
                )
            )
        )
        return result.scalars().all()

    async def _get_connector_for_document(
        self, document: Document, user_id: str
    ) -> SearchSourceConnector | None:
        """Return the document's connector if it belongs to ``user_id``.

        Returns ``None`` when the document has no ``connector_id`` or when no
        connector with that id is owned by the user.
        """
        if not document.connector_id:
            return None
        result = await self._db_session.execute(
            select(SearchSourceConnector).filter(
                and_(
                    SearchSourceConnector.id == document.connector_id,
                    SearchSourceConnector.user_id == user_id,
                )
            )
        )
        return result.scalars().first()

View file

@ -118,6 +118,7 @@ async def _delete_document_background(document_id: int) -> None:
from sqlalchemy import delete as sa_delete, select from sqlalchemy import delete as sa_delete, select
from app.db import Chunk, Document from app.db import Chunk, Document
from app.file_storage.service import purge_document_blobs
async with get_celery_session_maker()() as session: async with get_celery_session_maker()() as session:
batch_size = 500 batch_size = 500
@ -133,6 +134,9 @@ async def _delete_document_background(document_id: int) -> None:
await session.execute(sa_delete(Chunk).where(Chunk.id.in_(chunk_ids))) await session.execute(sa_delete(Chunk).where(Chunk.id.in_(chunk_ids)))
await session.commit() await session.commit()
# Remove stored blobs before the document_files rows cascade away.
await purge_document_blobs(session, document_ids=[document_id])
doc = await session.get(Document, document_id) doc = await session.get(Document, document_id)
if doc: if doc:
await session.delete(doc) await session.delete(doc)
@ -166,6 +170,7 @@ async def _delete_folder_documents(
from sqlalchemy import delete as sa_delete, select from sqlalchemy import delete as sa_delete, select
from app.db import Chunk, Document, Folder from app.db import Chunk, Document, Folder
from app.file_storage.service import purge_document_blobs
async with get_celery_session_maker()() as session: async with get_celery_session_maker()() as session:
batch_size = 500 batch_size = 500
@ -182,6 +187,9 @@ async def _delete_folder_documents(
await session.execute(sa_delete(Chunk).where(Chunk.id.in_(chunk_ids))) await session.execute(sa_delete(Chunk).where(Chunk.id.in_(chunk_ids)))
await session.commit() await session.commit()
# Remove stored blobs before the document_files rows cascade away.
await purge_document_blobs(session, document_ids=[doc_id])
doc = await session.get(Document, doc_id) doc = await session.get(Document, doc_id)
if doc: if doc:
await session.delete(doc) await session.delete(doc)
@ -214,6 +222,7 @@ async def _delete_search_space_background(search_space_id: int) -> None:
from sqlalchemy import delete as sa_delete, select from sqlalchemy import delete as sa_delete, select
from app.db import Chunk, Document, SearchSpace from app.db import Chunk, Document, SearchSpace
from app.file_storage.service import purge_document_blobs
async with get_celery_session_maker()() as session: async with get_celery_session_maker()() as session:
batch_size = 500 batch_size = 500
@ -240,6 +249,8 @@ async def _delete_search_space_background(search_space_id: int) -> None:
doc_ids = doc_ids_result.scalars().all() doc_ids = doc_ids_result.scalars().all()
if not doc_ids: if not doc_ids:
break break
# Remove stored blobs before the document_files rows cascade away.
await purge_document_blobs(session, document_ids=list(doc_ids))
await session.execute(sa_delete(Document).where(Document.id.in_(doc_ids))) await session.execute(sa_delete(Document).where(Document.id.in_(doc_ids)))
await session.commit() await session.commit()

View file

@ -1,364 +0,0 @@
"""Jira connector indexer using the unified parallel indexing pipeline."""
import contextlib
from collections.abc import Awaitable, Callable
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from .base import (
calculate_date_range,
check_duplicate_document_by_hash,
get_connector_by_id,
logger,
update_connector_last_indexed,
)
# Signature of the optional heartbeat hook: an async callable that receives a
# single int (presumably a progress count — TODO confirm against the pipeline).
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Passed to the indexing pipeline as ``heartbeat_interval``.
HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
    issue: dict,
    formatted_issue: dict,
    issue_content: str,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Build the ``ConnectorDocument`` for a single Jira issue.

    Args:
        issue: Raw issue dict; only its ``"key"`` and ``"id"`` entries are read.
        formatted_issue: Formatted issue dict; ``"status"``, ``"priority"`` and
            ``"comments"`` are read, each with a default when missing.
        issue_content: Markdown body stored as the document's source markdown.
        connector_id: ID of the connector that produced the issue.
        search_space_id: Search space the document belongs to.
        user_id: ID recorded as the document creator.
        enable_summary: Whether the document should be summarized.

    Returns:
        A ``ConnectorDocument`` of type ``DocumentType.JIRA_CONNECTOR`` whose
        unique ID is the issue key.
    """
    key = issue.get("key", "")
    # NOTE(review): the "title" is read from the issue's "id" field, not from a
    # summary/title field. The accompanying tests pin this, so it is kept as-is;
    # confirm whether that is intentional.
    title = issue.get("id", "")
    status = formatted_issue.get("status", "Unknown")

    return ConnectorDocument(
        title=f"{key}: {title}",
        source_markdown=issue_content,
        unique_id=key,
        document_type=DocumentType.JIRA_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        # Used in place of a generated summary (per the field name).
        fallback_summary=(
            f"Jira Issue {key}: {title}\n\n"
            f"Status: {status}\n\n{issue_content}"
        ),
        metadata={
            "issue_id": key,
            "issue_identifier": key,
            "issue_title": title,
            "state": status,
            "priority": formatted_issue.get("priority", "Unknown"),
            "comment_count": len(formatted_issue.get("comments", [])),
            "connector_id": connector_id,
            "document_type": "Jira Issue",
            "connector_type": "Jira",
        },
    )
async def index_jira_issues(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, str | None]:
    """Index Jira issues and comments.

    Fetches issues for the resolved date range, creates placeholder documents,
    builds one ``ConnectorDocument`` per usable issue (skipping issues with a
    missing key/id, empty content, or content already indexed), and hands the
    batch to the indexing pipeline.

    Args:
        session: Database session used for lookups, commits and rollbacks.
        connector_id: ID of the Jira connector to index.
        search_space_id: Search space the documents are written to.
        user_id: ID recorded as the creator of indexed documents.
        start_date: Optional range start; ``""`` and ``"undefined"`` are
            treated as not provided.
        end_date: Optional range end; same normalization as ``start_date``.
        update_last_indexed: Forwarded to ``update_connector_last_indexed``;
            also gates the early "nothing found" timestamp updates.
        on_heartbeat_callback: Optional async hook forwarded to the pipeline.

    Returns:
        ``(documents_indexed, documents_skipped, message)``. ``message`` is
        ``None`` on a clean run, a short warning such as ``"2 failed"`` when
        some documents were duplicates or failed, or an error description
        (with both counts ``0``) when indexing could not proceed.
    """
    task_logger = TaskLoggingService(session, search_space_id)
    log_entry = await task_logger.log_task_start(
        task_name="jira_issues_indexing",
        source="connector_indexing_task",
        message=f"Starting Jira issues indexing for connector {connector_id}",
        metadata={
            "connector_id": connector_id,
            "user_id": str(user_id),
            "start_date": start_date,
            "end_date": end_date,
        },
    )

    try:
        connector = await get_connector_by_id(
            session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR
        )
        if not connector:
            await task_logger.log_task_failure(
                log_entry,
                f"Connector with ID {connector_id} not found",
                "Connector not found",
                {"error_type": "ConnectorNotFound"},
            )
            return 0, 0, f"Connector with ID {connector_id} not found"

        await task_logger.log_task_progress(
            log_entry,
            f"Initializing Jira client for connector {connector_id}",
            {"stage": "client_initialization"},
        )
        jira_client = JiraHistoryConnector(session=session, connector_id=connector_id)

        # Callers may send "undefined" or "" for an unset date; treat both as absent.
        if start_date == "undefined" or start_date == "":
            start_date = None
        if end_date == "undefined" or end_date == "":
            end_date = None

        start_date_str, end_date_str = calculate_date_range(
            connector, start_date, end_date, default_days_back=365
        )

        await task_logger.log_task_progress(
            log_entry,
            f"Fetching Jira issues from {start_date_str} to {end_date_str}",
            {
                "stage": "fetching_issues",
                "start_date": start_date_str,
                "end_date": end_date_str,
            },
        )

        try:
            issues, error = await jira_client.get_issues_by_date_range(
                start_date=start_date_str, end_date=end_date_str, include_comments=True
            )
            if error:
                if "No issues found" in error:
                    # An empty range is a successful run, not a failure.
                    logger.info(f"No Jira issues found: {error}")
                    if update_last_indexed:
                        await update_connector_last_indexed(
                            session, connector, update_last_indexed
                        )
                        await session.commit()
                        logger.info(
                            f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
                        )
                    await task_logger.log_task_success(
                        log_entry,
                        f"No Jira issues found in date range {start_date_str} to {end_date_str}",
                        {"issues_found": 0},
                    )
                    await jira_client.close()
                    return 0, 0, None
                else:
                    logger.error(f"Failed to get Jira issues: {error}")
                    await task_logger.log_task_failure(
                        log_entry,
                        f"Failed to get Jira issues: {error}",
                        "API Error",
                        {"error_type": "APIError"},
                    )
                    await jira_client.close()
                    return 0, 0, f"Failed to get Jira issues: {error}"
            logger.info(f"Retrieved {len(issues)} issues from Jira API")
        except Exception as e:
            logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
            # Close out the task log entry here as well; without this the entry
            # would stay in its "started" state because we return directly.
            await task_logger.log_task_failure(
                log_entry,
                f"Error fetching Jira issues for connector {connector_id}",
                str(e),
                {"error_type": type(e).__name__},
            )
            await jira_client.close()
            return 0, 0, f"Error fetching Jira issues: {e!s}"

        if not issues:
            logger.info("No Jira issues found for the specified date range")
            if update_last_indexed:
                await update_connector_last_indexed(
                    session, connector, update_last_indexed
                )
                await session.commit()
            await jira_client.close()
            return 0, 0, None

        # ── Create placeholders for instant UI feedback ───────────────
        pipeline = IndexingPipelineService(session)
        placeholders = [
            PlaceholderInfo(
                title=f"{issue.get('key', '')}: {issue.get('id', '')}",
                document_type=DocumentType.JIRA_CONNECTOR,
                unique_id=issue.get("key", ""),
                search_space_id=search_space_id,
                connector_id=connector_id,
                created_by_id=user_id,
                metadata={
                    "issue_id": issue.get("key", ""),
                    "connector_id": connector_id,
                    "connector_type": "Jira",
                },
            )
            for issue in issues
            if issue.get("key") and issue.get("id")
        ]
        await pipeline.create_placeholder_documents(placeholders)

        connector_docs: list[ConnectorDocument] = []
        documents_skipped = 0
        duplicate_content_count = 0

        for issue in issues:
            # Reset per issue so the error log below never hits an unbound name
            # or reports the previous issue's key.
            issue_identifier = "Unknown"
            try:
                issue_id = issue.get("key")
                issue_identifier = issue.get("key", "")
                issue_title = issue.get("id", "")

                if not issue_id or not issue_title:
                    logger.warning(
                        f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}"
                    )
                    documents_skipped += 1
                    continue

                formatted_issue = jira_client.format_issue(issue)
                issue_content = jira_client.format_issue_to_markdown(formatted_issue)

                if not issue_content:
                    logger.warning(
                        f"Skipping issue with no content: {issue_identifier} - {issue_title}"
                    )
                    documents_skipped += 1
                    continue

                doc = _build_connector_doc(
                    issue,
                    formatted_issue,
                    issue_content,
                    connector_id=connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    enable_summary=connector.enable_summary,
                )

                # Look up an existing document with the same content hash
                # without triggering an autoflush of pending changes.
                with session.no_autoflush:
                    duplicate_by_content = await check_duplicate_document_by_hash(
                        session, compute_content_hash(doc)
                    )

                if duplicate_by_content:
                    logger.info(
                        f"Jira issue {issue_identifier} already indexed by another connector "
                        f"(existing document ID: {duplicate_by_content.id}, "
                        f"type: {duplicate_by_content.document_type}). Skipping."
                    )
                    duplicate_content_count += 1
                    documents_skipped += 1
                    continue

                connector_docs.append(doc)

            except Exception as e:
                # One bad issue must not abort the whole run.
                logger.error(
                    f"Error building ConnectorDocument for issue {issue_identifier}: {e!s}",
                    exc_info=True,
                )
                documents_skipped += 1
                continue

        await pipeline.migrate_legacy_docs(connector_docs)

        async def _get_llm(s: AsyncSession):
            return await get_user_long_context_llm(s, user_id, search_space_id)

        _, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
            connector_docs,
            _get_llm,
            max_concurrency=3,
            on_heartbeat=on_heartbeat_callback,
            heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
        )

        await update_connector_last_indexed(session, connector, update_last_indexed)

        logger.info(f"Final commit: Total {documents_indexed} Jira issues processed")
        try:
            await session.commit()
            logger.info("Successfully committed all JIRA document changes to database")
        except Exception as e:
            if (
                "duplicate key value violates unique constraint" in str(e).lower()
                or "uniqueviolationerror" in str(e).lower()
            ):
                logger.warning(
                    f"Duplicate content_hash detected during final commit. "
                    f"This may occur if the same issue was indexed by multiple connectors. "
                    f"Rolling back and continuing. Error: {e!s}"
                )
                await session.rollback()
                # Don't fail the entire task - some documents may have been successfully indexed
            else:
                raise

        warning_parts = []
        if duplicate_content_count > 0:
            warning_parts.append(f"{duplicate_content_count} duplicate")
        if documents_failed > 0:
            warning_parts.append(f"{documents_failed} failed")
        warning_message = ", ".join(warning_parts) if warning_parts else None

        await task_logger.log_task_success(
            log_entry,
            f"Successfully completed JIRA indexing for connector {connector_id}",
            {
                "documents_indexed": documents_indexed,
                "documents_skipped": documents_skipped,
                "documents_failed": documents_failed,
                "duplicate_content_count": duplicate_content_count,
            },
        )

        logger.info(
            f"JIRA indexing completed: {documents_indexed} ready, "
            f"{documents_skipped} skipped, {documents_failed} failed "
            f"({duplicate_content_count} duplicate content)"
        )

        await jira_client.close()
        return documents_indexed, documents_skipped, warning_message

    except SQLAlchemyError as db_error:
        await session.rollback()
        await task_logger.log_task_failure(
            log_entry,
            f"Database error during JIRA indexing for connector {connector_id}",
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(f"Database error: {db_error!s}", exc_info=True)
        # The client only exists if the failure happened after it was created.
        if "jira_client" in locals():
            with contextlib.suppress(Exception):
                await jira_client.close()
        return 0, 0, f"Database error: {db_error!s}"
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(
            log_entry,
            f"Failed to index JIRA issues for connector {connector_id}",
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
        if "jira_client" in locals():
            with contextlib.suppress(Exception):
                await jira_client.close()
        return 0, 0, f"Failed to index JIRA issues: {e!s}"

View file

@ -46,6 +46,7 @@ dependencies = [
"redis>=5.2.1", "redis>=5.2.1",
"firecrawl-py>=4.9.0", "firecrawl-py>=4.9.0",
"boto3>=1.35.0", "boto3>=1.35.0",
"azure-storage-blob>=12.23.0",
"fake-useragent>=2.2.0", "fake-useragent>=2.2.0",
"trafilatura>=2.0.0", "trafilatura>=2.0.0",
"fastapi-users[oauth,sqlalchemy]>=15.0.3", "fastapi-users[oauth,sqlalchemy]>=15.0.3",

View file

@ -1,387 +0,0 @@
"""Tests for Jira indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.jira_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.jira_indexer import (
_build_connector_doc,
index_jira_issues,
)
# Apply the ``unit`` marker to every test in this module.
pytestmark = pytest.mark.unit
# Fixed identifiers shared by the tests and helpers below.
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
):
return {"key": issue_key, "id": issue_id, "title": title}
def _make_formatted_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
status: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"key": issue_key,
"id": issue_id,
"title": title,
"status": status,
"priority": priority,
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
    """Every field of the built document mirrors the issue and call arguments."""
    body = "# ENG-42: Fix auth bug\n\nBody"
    raw_issue = _make_issue(issue_key="ENG-42", issue_id="4242", title="Fix auth bug")
    formatted_issue = _make_formatted_issue(
        issue_key="ENG-42",
        issue_id="4242",
        title="Fix auth bug",
        status="Done",
        priority="Urgent",
        comments=[{"id": "c1"}],
    )

    doc = _build_connector_doc(
        raw_issue,
        formatted_issue,
        body,
        connector_id=_CONNECTOR_ID,
        search_space_id=_SEARCH_SPACE_ID,
        user_id=_USER_ID,
        enable_summary=True,
    )

    # Top-level document fields (the title is "<key>: <id>").
    assert doc.title == "ENG-42: 4242"
    assert doc.unique_id == "ENG-42"
    assert doc.document_type == DocumentType.JIRA_CONNECTOR
    assert doc.source_markdown == body
    assert doc.search_space_id == _SEARCH_SPACE_ID
    assert doc.connector_id == _CONNECTOR_ID
    assert doc.created_by_id == _USER_ID
    assert doc.should_summarize is True

    # Metadata entries are checked key by key; extra keys are tolerated.
    expected_metadata = {
        "issue_id": "ENG-42",
        "issue_identifier": "ENG-42",
        "issue_title": "4242",
        "state": "Done",
        "priority": "Urgent",
        "comment_count": 1,
        "connector_id": _CONNECTOR_ID,
        "document_type": "Jira Issue",
        "connector_type": "Jira",
    }
    for field, value in expected_metadata.items():
        assert doc.metadata[field] == value

    assert doc.fallback_summary is not None
    assert "ENG-42" in doc.fallback_summary
    assert body in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
    """``enable_summary=False`` is carried through to ``should_summarize``."""
    build_kwargs = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "enable_summary": False,
    }
    doc = _build_connector_doc(
        _make_issue(), _make_formatted_issue(), "# content", **build_kwargs
    )
    assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_jira_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_key=i.get("key", ""),
issue_id=i.get("id", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('key', '')}: {fi.get('id', '')}\n\nContent"
)
client.close = AsyncMock()
return client
@pytest.fixture
def jira_mocks(monkeypatch):
    """Patch every collaborator of the Jira indexer module with mocks.

    By default the mocked client returns one issue and the mocked pipeline
    reports one indexed document and no failures. Returns a dict of the mocks
    so tests can adjust return values and inspect calls.
    """
    session = AsyncMock()
    # Replaced with a MagicMock so it can be used in a plain ``with`` block.
    session.no_autoflush = MagicMock()

    connector = _mock_connector()
    client = _mock_jira_client(issues=[_make_issue()])

    task_logger = MagicMock()
    task_logger.log_task_start = AsyncMock(return_value=MagicMock())
    for method_name in ("log_task_progress", "log_task_success", "log_task_failure"):
        setattr(task_logger, method_name, AsyncMock())

    # index_batch_parallel resolves to (<ignored>, indexed_count, failed_count).
    batch = AsyncMock(return_value=([], 1, 0))
    pipeline = MagicMock()
    pipeline.index_batch_parallel = batch
    pipeline.migrate_legacy_docs = AsyncMock()
    pipeline.create_placeholder_documents = AsyncMock(return_value=0)

    module_patches = {
        "get_connector_by_id": AsyncMock(return_value=connector),
        "JiraHistoryConnector": MagicMock(return_value=client),
        "check_duplicate_document_by_hash": AsyncMock(return_value=None),
        "update_connector_last_indexed": AsyncMock(),
        "calculate_date_range": MagicMock(return_value=("2025-01-01", "2025-12-31")),
        "TaskLoggingService": MagicMock(return_value=task_logger),
        "IndexingPipelineService": MagicMock(return_value=pipeline),
    }
    for attribute, replacement in module_patches.items():
        monkeypatch.setattr(_mod, attribute, replacement)

    return {
        "session": session,
        "connector": connector,
        "jira_client": client,
        "task_logger": task_logger,
        "pipeline_mock": pipeline,
        "batch_mock": batch,
    }
async def _run_index(mocks, **overrides):
    """Call ``index_jira_issues`` with the fixture session and default arguments.

    Recognized keys in ``overrides`` replace the matching default; any other
    key is ignored.
    """
    defaults = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
        "start_date": "2025-01-01",
        "end_date": "2025-12-31",
        "update_last_indexed": True,
        "on_heartbeat_callback": None,
    }
    call_kwargs = {name: overrides.get(name, value) for name, value in defaults.items()}
    return await index_jira_issues(session=mocks["session"], **call_kwargs)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(jira_mocks):
    """A single issue reaches the pipeline once and yields ``(1, 0, None)``."""
    result = await _run_index(jira_mocks)
    assert result[0] == 1
    assert result[1] == 0
    assert result[2] is None

    batch = jira_mocks["batch_mock"]
    batch.assert_called_once()
    submitted = batch.call_args[0][0]
    assert len(submitted) == 1
    assert submitted[0].document_type == DocumentType.JIRA_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(jira_mocks):
    """The batch call is made with ``max_concurrency=3``."""
    await _run_index(jira_mocks)
    _, keyword_args = jira_mocks["batch_mock"].call_args
    assert keyword_args.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(jira_mocks):
    """``migrate_legacy_docs`` is invoked exactly once during a run."""
    await _run_index(jira_mocks)
    pipeline = jira_mocks["pipeline_mock"]
    pipeline.migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing key/title/content)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_key_are_skipped(jira_mocks):
    """An issue with an empty ``key`` is counted as skipped and not indexed."""
    valid = _make_issue(issue_key="ENG-1", issue_id="10001")
    keyless = {"key": "", "id": "10002", "title": "No key"}
    jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
        [valid, keyless],
        None,
    )

    result = await _run_index(jira_mocks)

    submitted = jira_mocks["batch_mock"].call_args[0][0]
    assert len(submitted) == 1
    assert result[1] == 1
async def test_issues_with_missing_title_are_skipped(jira_mocks):
    """An issue with an empty ``id`` (used as the title) is skipped."""
    valid = _make_issue(issue_key="ENG-1", issue_id="10001")
    untitled = {"key": "ENG-2", "id": "", "title": "Missing id used as title"}
    jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
        [valid, untitled],
        None,
    )

    result = await _run_index(jira_mocks)

    submitted = jira_mocks["batch_mock"].call_args[0][0]
    assert len(submitted) == 1
    assert result[1] == 1
async def test_issues_with_no_content_are_skipped(jira_mocks):
    """An issue whose markdown rendering is empty is skipped."""
    client = jira_mocks["jira_client"]
    client.get_issues_by_date_range.return_value = (
        [
            _make_issue(issue_key="ENG-1", issue_id="10001"),
            _make_issue(issue_key="ENG-2", issue_id="10002"),
        ],
        None,
    )
    # First issue renders normally, second renders to an empty string.
    client.format_issue_to_markdown.side_effect = [
        "# ENG-1: 10001\n\nContent",
        "",
    ]

    result = await _run_index(jira_mocks)

    submitted = jira_mocks["batch_mock"].call_args[0][0]
    assert len(submitted) == 1
    assert result[1] == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(jira_mocks, monkeypatch):
    """An issue whose content hash matches an existing document is skipped."""
    jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
        [
            _make_issue(issue_key="ENG-1", issue_id="10001"),
            _make_issue(issue_key="ENG-2", issue_id="10002"),
        ],
        None,
    )

    existing = MagicMock()
    existing.id = 99
    existing.document_type = "OTHER"
    # Only the second lookup reports a duplicate; every other lookup finds none.
    lookups = iter([None, existing])

    async def _check_dup(session, content_hash):
        return next(lookups, None)

    monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)

    result = await _run_index(jira_mocks)

    submitted = jira_mocks["batch_mock"].call_args[0][0]
    assert len(submitted) == 1
    assert result[1] == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(jira_mocks):
    """The caller's heartbeat callback is passed through as ``on_heartbeat``."""
    callback = AsyncMock()
    await _run_index(jira_mocks, on_heartbeat_callback=callback)
    _, keyword_args = jira_mocks["batch_mock"].call_args
    assert keyword_args.get("on_heartbeat") is callback
# ---------------------------------------------------------------------------
# Slice 6: Empty issues and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(jira_mocks):
    """No issues and no error yields ``(0, 0, None)`` without touching the pipeline."""
    jira_mocks["jira_client"].get_issues_by_date_range.return_value = ([], None)

    result = await _run_index(jira_mocks)

    assert result[0] == 0
    assert result[1] == 0
    assert result[2] is None
    jira_mocks["batch_mock"].assert_not_called()
async def test_no_issues_error_message_returns_success_tuple(jira_mocks):
    """A "No issues found" error from the client is treated as a successful empty run."""
    client_response = ([], "No issues found in date range")
    jira_mocks["jira_client"].get_issues_by_date_range.return_value = client_response

    result = await _run_index(jira_mocks)

    assert result[0] == 0
    assert result[1] == 0
    assert result[2] is None
async def test_api_error_still_returns_3_tuple(jira_mocks):
    """Any other client error yields a 3-tuple with zero counts and an error message."""
    client_response = ([], "API exploded")
    jira_mocks["jira_client"].get_issues_by_date_range.return_value = client_response

    outcome = await _run_index(jira_mocks)

    assert len(outcome) == 3
    assert outcome[0] == 0
    assert outcome[1] == 0
    assert "Failed to get Jira issues" in outcome[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(jira_mocks):
    """Pipeline failures are reported in the returned warning message."""
    jira_mocks["batch_mock"].return_value = ([], 0, 2)

    result = await _run_index(jira_mocks)

    warning = result[2]
    assert warning is not None
    assert "2 failed" in warning

8518
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -243,6 +243,38 @@ export function getConnectorTitle(connectorType: string): string {
); );
} }
/**
* Primary way a user interacts with a connector.
* Drives the two top-level groupings in the connector catalog UI.
*/
export type ConnectorCategory = "knowledge_base" | "tools_live";
/** Display label for each connector category. */
export const CONNECTOR_CATEGORY_LABELS: Record<ConnectorCategory, string> = {
knowledge_base: "Knowledge Base",
tools_live: "Tools & Live Sources",
};
/**
 * Connector types that belong to the "knowledge_base" category.
 * Typed as Set<string> so arbitrary connector-type strings can be looked up.
 */
const KNOWLEDGE_BASE_CONNECTOR_TYPES = new Set<string>([
EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
EnumConnectorName.ONEDRIVE_CONNECTOR,
EnumConnectorName.DROPBOX_CONNECTOR,
EnumConnectorName.NOTION_CONNECTOR,
EnumConnectorName.CONFLUENCE_CONNECTOR,
EnumConnectorName.YOUTUBE_CONNECTOR,
EnumConnectorName.WEBCRAWLER_CONNECTOR,
EnumConnectorName.BOOKSTACK_CONNECTOR,
EnumConnectorName.GITHUB_CONNECTOR,
EnumConnectorName.ELASTICSEARCH_CONNECTOR,
EnumConnectorName.CIRCLEBACK_CONNECTOR,
EnumConnectorName.OBSIDIAN_CONNECTOR,
]);
/**
 * Resolves the catalog category for a connector type.
 *
 * @param connectorType - Connector type identifier to classify.
 * @returns "knowledge_base" when the type is listed in
 * KNOWLEDGE_BASE_CONNECTOR_TYPES; every other (unmapped) type falls back to
 * "tools_live".
 */
export function getConnectorCategory(connectorType: string): ConnectorCategory {
	if (KNOWLEDGE_BASE_CONNECTOR_TYPES.has(connectorType)) {
		return "knowledge_base";
	}
	return "tools_live";
}
// Composio Toolkits (available integrations via Composio) // Composio Toolkits (available integrations via Composio)
export const COMPOSIO_TOOLKITS = [ export const COMPOSIO_TOOLKITS = [
{ {

View file

@ -9,7 +9,10 @@ import { isSelfHosted } from "@/lib/env-config";
import { ConnectorCard } from "../components/connector-card"; import { ConnectorCard } from "../components/connector-card";
import { import {
COMPOSIO_CONNECTORS, COMPOSIO_CONNECTORS,
CONNECTOR_CATEGORY_LABELS,
type ConnectorCategory,
CRAWLERS, CRAWLERS,
getConnectorCategory,
OAUTH_CONNECTORS, OAUTH_CONNECTORS,
OTHER_CONNECTORS, OTHER_CONNECTORS,
} from "../constants/connector-constants"; } from "../constants/connector-constants";
@ -20,19 +23,6 @@ type ComposioConnector = (typeof COMPOSIO_CONNECTORS)[number];
type OtherConnector = (typeof OTHER_CONNECTORS)[number]; type OtherConnector = (typeof OTHER_CONNECTORS)[number];
type CrawlerConnector = (typeof CRAWLERS)[number]; type CrawlerConnector = (typeof CRAWLERS)[number];
const DOCUMENT_FILE_CONNECTOR_TYPES = new Set<string>([
EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
EnumConnectorName.ONEDRIVE_CONNECTOR,
EnumConnectorName.DROPBOX_CONNECTOR,
]);
const OTHER_DOCUMENT_CONNECTOR_TYPES = new Set<string>([
EnumConnectorName.YOUTUBE_CONNECTOR,
EnumConnectorName.NOTION_CONNECTOR,
EnumConnectorName.AIRTABLE_CONNECTOR,
]);
/** /**
* Extract the display name from a full connector name. * Extract the display name from a full connector name.
* Full names are in format "Base Name - identifier" (e.g., "Gmail - john@example.com"). * Full names are in format "Base Name - identifier" (e.g., "Gmail - john@example.com").
@ -106,45 +96,23 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
c.description.toLowerCase().includes(searchQuery.toLowerCase()) c.description.toLowerCase().includes(searchQuery.toLowerCase())
); );
const nativeGoogleDriveConnectors = filteredOAuth.filter( const inCategory =
(c) => c.connectorType === EnumConnectorName.GOOGLE_DRIVE_CONNECTOR (category: ConnectorCategory) =>
); <T extends { connectorType?: string }>(connector: T): boolean =>
const composioGoogleDriveConnectors = filteredComposio.filter( !!connector.connectorType && getConnectorCategory(connector.connectorType) === category;
(c) => c.connectorType === EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
);
const fileStorageConnectors = filteredOAuth.filter(
(c) =>
c.connectorType === EnumConnectorName.ONEDRIVE_CONNECTOR ||
c.connectorType === EnumConnectorName.DROPBOX_CONNECTOR
);
const otherDocumentYouTubeConnectors = filteredCrawlers.filter( const knowledgeBase = {
(c) => c.connectorType === EnumConnectorName.YOUTUBE_CONNECTOR oauth: filteredOAuth.filter(inCategory("knowledge_base")),
); composio: filteredComposio.filter(inCategory("knowledge_base")),
const otherDocumentNotionConnectors = filteredOAuth.filter( other: filteredOther.filter(inCategory("knowledge_base")),
(c) => c.connectorType === EnumConnectorName.NOTION_CONNECTOR crawlers: filteredCrawlers.filter(inCategory("knowledge_base")),
); };
const otherDocumentAirtableConnectors = filteredOAuth.filter( const toolsLive = {
(c) => c.connectorType === EnumConnectorName.AIRTABLE_CONNECTOR oauth: filteredOAuth.filter(inCategory("tools_live")),
); composio: filteredComposio.filter(inCategory("tools_live")),
other: filteredOther.filter(inCategory("tools_live")),
const moreIntegrationsComposio = filteredComposio.filter( crawlers: filteredCrawlers.filter(inCategory("tools_live")),
(c) => };
!DOCUMENT_FILE_CONNECTOR_TYPES.has(c.connectorType) &&
!OTHER_DOCUMENT_CONNECTOR_TYPES.has(c.connectorType)
);
const moreIntegrationsOAuth = filteredOAuth.filter(
(c) =>
!DOCUMENT_FILE_CONNECTOR_TYPES.has(c.connectorType) &&
!OTHER_DOCUMENT_CONNECTOR_TYPES.has(c.connectorType)
);
const moreIntegrationsOther = filteredOther;
const moreIntegrationsCrawlers = filteredCrawlers.filter(
(c) =>
!c.connectorType ||
(!DOCUMENT_FILE_CONNECTOR_TYPES.has(c.connectorType) &&
!OTHER_DOCUMENT_CONNECTOR_TYPES.has(c.connectorType))
);
const renderOAuthCard = (connector: OAuthConnector | ComposioConnector) => { const renderOAuthCard = (connector: OAuthConnector | ComposioConnector) => {
const isConnected = connectedTypes.has(connector.connectorType); const isConnected = connectedTypes.has(connector.connectorType);
@ -275,20 +243,18 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
); );
}; };
const hasDocumentFileConnectors = const hasKnowledgeBase =
nativeGoogleDriveConnectors.length > 0 || knowledgeBase.oauth.length > 0 ||
composioGoogleDriveConnectors.length > 0 || knowledgeBase.composio.length > 0 ||
fileStorageConnectors.length > 0; knowledgeBase.other.length > 0 ||
const hasMoreIntegrations = knowledgeBase.crawlers.length > 0;
otherDocumentYouTubeConnectors.length > 0 || const hasToolsLive =
otherDocumentNotionConnectors.length > 0 || toolsLive.oauth.length > 0 ||
otherDocumentAirtableConnectors.length > 0 || toolsLive.composio.length > 0 ||
moreIntegrationsComposio.length > 0 || toolsLive.other.length > 0 ||
moreIntegrationsOAuth.length > 0 || toolsLive.crawlers.length > 0;
moreIntegrationsOther.length > 0 ||
moreIntegrationsCrawlers.length > 0;
const hasAnyResults = hasDocumentFileConnectors || hasMoreIntegrations; const hasAnyResults = hasKnowledgeBase || hasToolsLive;
if (!hasAnyResults && searchQuery) { if (!hasAnyResults && searchQuery) {
return ( return (
@ -302,36 +268,34 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
return ( return (
<div className="space-y-8"> <div className="space-y-8">
{/* File Storage Integrations */} {hasKnowledgeBase && (
{hasDocumentFileConnectors && (
<section> <section>
<div className="flex items-center gap-2 mb-4"> <div className="flex items-center gap-2 mb-4">
<h3 className="text-sm font-semibold text-muted-foreground"> <h3 className="text-sm font-semibold text-muted-foreground">
File Storage Integrations {CONNECTOR_CATEGORY_LABELS.knowledge_base}
</h3> </h3>
</div> </div>
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3"> <div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
{nativeGoogleDriveConnectors.map(renderOAuthCard)} {knowledgeBase.oauth.map(renderOAuthCard)}
{composioGoogleDriveConnectors.map(renderOAuthCard)} {knowledgeBase.composio.map(renderOAuthCard)}
{fileStorageConnectors.map(renderOAuthCard)} {knowledgeBase.crawlers.map(renderCrawlerCard)}
{knowledgeBase.other.map(renderOtherCard)}
</div> </div>
</section> </section>
)} )}
{/* More Integrations */} {hasToolsLive && (
{hasMoreIntegrations && (
<section> <section>
<div className="flex items-center gap-2 mb-4"> <div className="flex items-center gap-2 mb-4">
<h3 className="text-sm font-semibold text-muted-foreground">More Integrations</h3> <h3 className="text-sm font-semibold text-muted-foreground">
{CONNECTOR_CATEGORY_LABELS.tools_live}
</h3>
</div> </div>
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3"> <div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
{otherDocumentYouTubeConnectors.map(renderCrawlerCard)} {toolsLive.oauth.map(renderOAuthCard)}
{otherDocumentNotionConnectors.map(renderOAuthCard)} {toolsLive.composio.map(renderOAuthCard)}
{otherDocumentAirtableConnectors.map(renderOAuthCard)} {toolsLive.crawlers.map(renderCrawlerCard)}
{moreIntegrationsComposio.map(renderOAuthCard)} {toolsLive.other.map(renderOtherCard)}
{moreIntegrationsOAuth.map(renderOAuthCard)}
{moreIntegrationsOther.map(renderOtherCard)}
{moreIntegrationsCrawlers.map(renderCrawlerCard)}
</div> </div>
</section> </section>
)} )}

View file

@ -0,0 +1,79 @@
"use client";
import { Download } from "lucide-react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
import { BACKEND_URL } from "@/lib/env-config";
interface DownloadOriginalButtonProps {
	/** ID of the document whose stored original file is offered for download. */
	documentId: number;
}

/**
 * Icon button that downloads a document's stored ORIGINAL file.
 *
 * On mount and whenever `documentId` changes, the document's files are fetched;
 * the button renders only when one of them has kind "ORIGINAL" and a filename.
 * A failed lookup is treated the same as "no original": nothing is rendered.
 */
export function DownloadOriginalButton({ documentId }: DownloadOriginalButtonProps) {
	// The filename is stored together with the document it was fetched for, so a
	// result that belongs to a previous `documentId` is never shown (or used as
	// the download name) while the lookup for the new document is in flight.
	const [original, setOriginal] = useState<{ documentId: number; filename: string } | null>(
		null
	);
	const [downloading, setDownloading] = useState(false);

	useEffect(() => {
		// Guards against updating state after unmount or after `documentId` changed.
		let active = true;
		documentsApiService
			.getDocumentFiles(documentId)
			.then((files) => {
				if (!active) return;
				const match = files.find((file) => file.kind === "ORIGINAL");
				const filename = match?.original_filename ?? null;
				setOriginal(filename ? { documentId, filename } : null);
			})
			.catch(() => {
				// Best effort: without file info the button simply stays hidden.
				if (active) setOriginal(null);
			});
		return () => {
			active = false;
		};
	}, [documentId]);

	const originalFilename =
		original !== null && original.documentId === documentId ? original.filename : null;

	if (!originalFilename) return null;

	const handleDownload = async () => {
		setDownloading(true);
		try {
			const response = await authenticatedFetch(
				`${BACKEND_URL}/api/v1/documents/${documentId}/download-original`,
				{ method: "GET" }
			);
			if (!response.ok) throw new Error("Download failed");
			const blob = await response.blob();
			const url = URL.createObjectURL(blob);
			try {
				// Trigger the browser download through a temporary anchor element.
				const anchor = document.createElement("a");
				anchor.href = url;
				anchor.download = originalFilename;
				document.body.appendChild(anchor);
				anchor.click();
				anchor.remove();
			} finally {
				// Always release the object URL, even if the anchor logic throws.
				URL.revokeObjectURL(url);
			}
			toast.success("Download started");
		} catch {
			toast.error("Failed to download original file");
		} finally {
			setDownloading(false);
		}
	};

	return (
		<Button
			variant="ghost"
			size="icon"
			className="size-6"
			onClick={handleDownload}
			disabled={downloading}
			title={`Download original (${originalFilename})`}
		>
			{downloading ? <Spinner size="xs" /> : <Download className="size-3.5" />}
			<span className="sr-only">Download original file</span>
		</Button>
	);
}

View file

@ -15,6 +15,7 @@ import dynamic from "next/dynamic";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { DownloadOriginalButton } from "@/components/documents/download-original-button";
import { VersionHistoryButton } from "@/components/documents/version-history"; import { VersionHistoryButton } from "@/components/documents/version-history";
import { SourceCodeEditor } from "@/components/editor/source-code-editor"; import { SourceCodeEditor } from "@/components/editor/source-code-editor";
import { import {
@ -584,6 +585,9 @@ export function EditorPanelContent({
documentType={editorDoc.document_type} documentType={editorDoc.document_type}
/> />
)} )}
{!isLocalFileMode && !isMemoryMode && documentId && (
<DownloadOriginalButton documentId={documentId} />
)}
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
@ -668,6 +672,9 @@ export function EditorPanelContent({
documentType={editorDoc.document_type} documentType={editorDoc.document_type}
/> />
)} )}
{!isLocalFileMode && !isMemoryMode && documentId && (
<DownloadOriginalButton documentId={documentId} />
)}
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"

View file

@ -281,6 +281,23 @@ export const deleteDocumentResponse = z.object({
message: z.literal("Document deleted successfully"), message: z.literal("Document deleted successfully"),
}); });
/**
 * Document files (stored originals / derived artifacts).
 *
 * `documentFileKindEnum` enumerates the accepted `kind` values,
 * `documentFileRead` validates a single file record, and
 * `getDocumentFilesResponse` validates an array of those records.
 */
export const documentFileKindEnum = z.enum(["ORIGINAL", "REDACTED", "FILLED_FORM"]);

export const documentFileRead = z.object({
	id: z.number(),
	document_id: z.number(),
	kind: documentFileKindEnum,
	original_filename: z.string(),
	// May be a string, null, or absent altogether.
	mime_type: z.string().nullish(),
	size_bytes: z.number(),
	created_at: z.string(),
});

export const getDocumentFilesResponse = documentFileRead.array();
export type Document = z.infer<typeof document>; export type Document = z.infer<typeof document>;
export type DocumentTitleRead = z.infer<typeof documentTitleRead>; export type DocumentTitleRead = z.infer<typeof documentTitleRead>;
export type GetDocumentsRequest = z.infer<typeof getDocumentsRequest>; export type GetDocumentsRequest = z.infer<typeof getDocumentsRequest>;
@ -314,3 +331,6 @@ export type GetDocumentChunksRequest = z.infer<typeof getDocumentChunksRequest>;
export type GetDocumentChunksResponse = z.infer<typeof getDocumentChunksResponse>; export type GetDocumentChunksResponse = z.infer<typeof getDocumentChunksResponse>;
export type ChunkRead = z.infer<typeof chunkRead>; export type ChunkRead = z.infer<typeof chunkRead>;
export type ProcessingMode = z.infer<typeof processingModeEnum>; export type ProcessingMode = z.infer<typeof processingModeEnum>;
// Static types inferred from the document-file zod schemas (file kind, single file record, list response).
export type DocumentFileKind = z.infer<typeof documentFileKindEnum>;
export type DocumentFileRead = z.infer<typeof documentFileRead>;
export type GetDocumentFilesResponse = z.infer<typeof getDocumentFilesResponse>;

View file

@ -30,6 +30,8 @@ import {
searchDocumentsResponse, searchDocumentsResponse,
searchDocumentTitlesRequest, searchDocumentTitlesRequest,
searchDocumentTitlesResponse, searchDocumentTitlesResponse,
type DocumentFileRead,
getDocumentFilesResponse,
type UpdateDocumentRequest, type UpdateDocumentRequest,
type UploadDocumentRequest, type UploadDocumentRequest,
updateDocumentRequest, updateDocumentRequest,
@ -381,6 +383,14 @@ class DocumentsApiService {
}); });
}; };
/**
 * Fetch the files stored for a document (e.g. its original upload).
 *
 * Issues a GET to `/api/v1/documents/{documentId}/files`, passing
 * `getDocumentFilesResponse` as the response schema. Callers use the result to
 * decide whether to show the "Download original" affordance.
 *
 * @param documentId - Id of the document whose files are listed.
 * @returns The file records returned for that document.
 */
getDocumentFiles = async (documentId: number): Promise<DocumentFileRead[]> => {
	const endpoint = `/api/v1/documents/${documentId}/files`;
	const files = await baseApiService.get(endpoint, getDocumentFilesResponse);
	return files;
};
listDocumentVersions = async (documentId: number) => { listDocumentVersions = async (documentId: number) => {
return baseApiService.get(`/api/v1/documents/${documentId}/versions`); return baseApiService.get(`/api/v1/documents/${documentId}/versions`);
}; };