Merge pull request #229 from MODSetter/dev

fix: linting
Rohan Verma 2025-07-26 01:43:21 +05:30 committed by GitHub
commit 617a7a34b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
92 changed files with 6274 additions and 5163 deletions

.github/workflows/code-quality.yml (vendored, new file, +224 lines)

@@ -0,0 +1,224 @@
name: Code Quality Checks

on:
  pull_request:
    branches: [main, dev]
    types: [opened, synchronize, reopened, ready_for_review]

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  file-quality:
    name: File Quality Checks
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Fetch base branch
        run: |
          # Ensure we have the base branch reference for comparison
          git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install pre-commit
        run: pip install pre-commit
      - name: Cache pre-commit hooks
        uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
          restore-keys: |
            pre-commit-
      - name: Install hook environments (cache)
        run: pre-commit install-hooks
      - name: Run file quality checks on changed files
        run: |
          # Get list of changed files and run specific hooks on them
          if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
            BASE_REF="${{ github.base_ref }}"
          elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
            BASE_REF="origin/${{ github.base_ref }}"
          else
            echo "Base branch reference not found, running file quality hooks on all files"
            pre-commit run --all-files check-yaml check-json check-toml check-merge-conflict check-added-large-files debug-statements check-case-conflict
            exit 0
          fi
          echo "Running file quality hooks on changed files against $BASE_REF"
          # Run each hook individually on changed files
          SKIP=detect-secrets,bandit,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
            pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
          # Exit with the same code as pre-commit
          exit ${exit_code:-0}

  security-scan:
    name: Security Scan
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Fetch base branch
        run: |
          git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install pre-commit
        run: pip install pre-commit
      - name: Cache pre-commit hooks
        uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
          restore-keys: |
            pre-commit-security-
      - name: Install hook environments (cache)
        run: pre-commit install-hooks
      - name: Run security scans on changed files
        run: |
          # Get base ref for comparison
          if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
            BASE_REF="${{ github.base_ref }}"
          elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
            BASE_REF="origin/${{ github.base_ref }}"
          else
            echo "Base branch reference not found, running security scans on all files"
            echo "⚠️ This may take longer than normal"
            pre-commit run --all-files detect-secrets bandit
            exit 0
          fi
          echo "Running security scans on changed files against $BASE_REF"
          # Run only security hooks on changed files
          SKIP=check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
            pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
          # Exit with the same code as pre-commit
          exit ${exit_code:-0}

  python-backend:
    name: Python Backend Quality
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install UV
        uses: astral-sh/setup-uv@v3
      - name: Check if backend files changed
        id: backend-changes
        uses: dorny/paths-filter@v3
        with:
          filters: |
            backend:
              - 'surfsense_backend/**'
      - name: Cache dependencies
        if: steps.backend-changes.outputs.backend == 'true'
        uses: actions/cache@v4
        with:
          path: |
            ~/.cache/uv
            surfsense_backend/.venv
          key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
      - name: Install dependencies
        if: steps.backend-changes.outputs.backend == 'true'
        working-directory: surfsense_backend
        run: uv sync
      - name: Install pre-commit for backend checks
        if: steps.backend-changes.outputs.backend == 'true'
        run: pip install pre-commit
      - name: Cache pre-commit hooks
        if: steps.backend-changes.outputs.backend == 'true'
        uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
          restore-keys: |
            pre-commit-backend-
      - name: Install hook environments (cache)
        if: steps.backend-changes.outputs.backend == 'true'
        run: pre-commit install-hooks
      - name: Run Python backend quality checks
        if: steps.backend-changes.outputs.backend == 'true'
        run: |
          # Get base ref for comparison
          if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
            BASE_REF="${{ github.base_ref }}"
          elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
            BASE_REF="origin/${{ github.base_ref }}"
          else
            echo "Base branch reference not found, running Python backend checks on all files"
            pre-commit run --all-files ruff ruff-format
            exit 0
          fi
          echo "Running Python backend checks on changed files against $BASE_REF"
          # Run only ruff hooks on changed Python files
          SKIP=detect-secrets,bandit,check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
            pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
          # Exit with the same code as pre-commit
          exit ${exit_code:-0}

  quality-gate:
    name: Quality Gate
    runs-on: ubuntu-latest
    needs: [file-quality, security-scan, python-backend]
    if: always()
    steps:
      - name: Check all jobs status
        run: |
          if [[ "${{ needs.file-quality.result }}" == "failure" ||
                "${{ needs.security-scan.result }}" == "failure" ||
                "${{ needs.python-backend.result }}" == "failure" ]]; then
            echo "❌ Code quality checks failed"
            exit 1
          else
            echo "✅ All code quality checks passed"
          fi

.github/workflows (pre-commit workflow)

@@ -3,7 +3,7 @@ name: pre-commit
 on:
   push:
   pull_request:
-    branches: [main]
+    branches: [main, dev]
 jobs:
   pre-commit:

.gitignore (vendored)

@@ -1,3 +1,5 @@
 .flashrank_cache*
 podcasts/
 .env
+.ruff_cache/

.pre-commit-config.yaml

@@ -6,8 +6,6 @@ repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
     hooks:
-      - id: trailing-whitespace
-        exclude: '\.md$'
       - id: check-yaml
         args: [--multi, --unsafe]
       - id: check-json
@@ -31,52 +29,36 @@ repos:
            .*\.env\.template|
            .*/tests/.*|
            .*test.*\.py|
-           test_.*\.py|
            .github/workflows/.*\.yml|
            .github/workflows/.*\.yaml|
            .*pnpm-lock\.yaml|
            .*alembic\.ini|
-           .*alembic/versions/.*\.py|
            .*\.mdx$
          )$
-  # Python Backend Hooks (surfsense_backend)
-  - repo: https://github.com/psf/black
-    rev: 25.1.0
-    hooks:
-      - id: black
-        files: ^surfsense_backend/
-        language_version: python3
-  - repo: https://github.com/pycqa/isort
-    rev: 6.0.1
-    hooks:
-      - id: isort
-        files: ^surfsense_backend/
-        args: ["--profile", "black", "--line-length", "88"]
+  # Python Backend Hooks (surfsense_backend) - Using Ruff for linting and formatting
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.12.4
+    rev: v0.12.5
    hooks:
       - id: ruff
+        name: ruff-check
        files: ^surfsense_backend/
-        args: [--fix, --exit-non-zero-on-fix]
+        exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
+        args: [--fix]
       - id: ruff-format
+        name: ruff-format
        files: ^surfsense_backend/
+        exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.17.0
-    hooks:
-      - id: mypy
-        files: ^surfsense_backend/
-        additional_dependencies: ['types-requests']
-        args: [--ignore-missing-imports, --disallow-untyped-defs]
   - repo: https://github.com/PyCQA/bandit
     rev: 1.8.6
     hooks:
       - id: bandit
        files: ^surfsense_backend/
-        args: ['-r', '-f', 'json']
-        exclude: ^surfsense_backend/(tests/|alembic/)
+        args: ['-f', 'json', '--severity-level', 'high', '--confidence-level', 'high']
+        exclude: ^surfsense_backend/(tests/|test_.*\.py|.*test.*\.py|alembic/)
   # Frontend/Extension Hooks (TypeScript/JavaScript)
   - repo: https://github.com/pre-commit/mirrors-prettier
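Black, isort, and mypy leave the hook list here; ruff-check (with --fix) and ruff-format take over linting and formatting, and Bandit is narrowed from a recursive scan to high-severity, high-confidence findings only. Most of the churn in the backend diffs below is mechanical ruff-format output: double quotes everywhere, and calls wrapped once they pass the default 88-column limit, roughly like this (illustrative snippet, not from the repo):

def execute(sql: str) -> None:  # stand-in for alembic's op.execute
    print(sql)


table_name, column_name = "chats", "type"

# Before formatting, this call sat on one long line past 88 columns.
# After ruff-format, the argument moves onto its own indented line,
# which is exactly the reshaping seen in the migration diffs below.
execute(
    f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
)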

surfsense_backend/alembic/env.py

@@ -1,8 +1,8 @@
 import asyncio
-from logging.config import fileConfig
 import os
 import sys
+from logging.config import fileConfig

 from sqlalchemy import pool
 from sqlalchemy.engine import Connection
 from sqlalchemy.ext.asyncio import async_engine_from_config
@@ -11,10 +11,10 @@ from alembic import context
 # Ensure the app directory is in the Python path
 # This allows Alembic to find your models
-sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
+sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), "..")))

 # Import your models base
 from app.db import Base  # Assuming your Base is defined in app.db

 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.

alembic migration (revision 10)

@@ -4,17 +4,15 @@ Revision ID: 10
 Revises: 9
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op
-import sqlalchemy as sa

 # revision identifiers, used by Alembic.
 revision: str = "10"
-down_revision: Union[str, None] = "9"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "9"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None

 # Define the ENUM type name
 CHAT_TYPE_ENUM = "chattype"

@@ -27,12 +25,7 @@ def upgrade() -> None:
     old_enum_name = f"{CHAT_TYPE_ENUM}_old"

     # New enum values
-    new_values = (
-        "QNA",
-        "REPORT_GENERAL",
-        "REPORT_DEEP",
-        "REPORT_DEEPER"
-    )
+    new_values = ("QNA", "REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER")
     new_values_sql = ", ".join([f"'{v}'" for v in new_values])

     # Table and column info

@@ -46,19 +39,31 @@ def upgrade() -> None:
     op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({new_values_sql})")

     # Step 3: Add a temporary column with the new type
-    op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
+    op.execute(
+        f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
+    )

     # Step 4: Update the temporary column with mapped values
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'")
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'")
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'")
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'")
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'"
+    )
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'"
+    )
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'"
+    )
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'"
+    )

     # Step 5: Drop the old column
     op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")

     # Step 6: Rename the new column to the original name
-    op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
+    op.execute(
+        f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
+    )

     # Step 7: Drop the old enum type
     op.execute(f"DROP TYPE {old_enum_name}")

@@ -71,12 +76,7 @@ def downgrade() -> None:
     old_enum_name = f"{CHAT_TYPE_ENUM}_old"

     # Original enum values
-    original_values = (
-        "GENERAL",
-        "DEEP",
-        "DEEPER",
-        "DEEPEST"
-    )
+    original_values = ("GENERAL", "DEEP", "DEEPER", "DEEPEST")
     original_values_sql = ", ".join([f"'{v}'" for v in original_values])

     # Table and column info

@@ -90,19 +90,31 @@ def downgrade() -> None:
     op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({original_values_sql})")

     # Step 3: Add a temporary column with the original type
-    op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
+    op.execute(
+        f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
+    )

     # Step 4: Update the temporary column with mapped values back to old values
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'")
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'")
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'")
-    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'")
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'"
+    )
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'"
+    )
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'"
+    )
+    op.execute(
+        f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'"
+    )

     # Step 5: Drop the old column
     op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")

     # Step 6: Rename the new column to the original name
-    op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
+    op.execute(
+        f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
+    )

     # Step 7: Drop the old enum type
     op.execute(f"DROP TYPE {old_enum_name}")

alembic migration (revision 11)

@@ -4,16 +4,17 @@ Revision ID: 11
 Revises: 10
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 import sqlalchemy as sa
+
 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "11"
-down_revision: Union[str, None] = "10"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "10"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

alembic migration (revision 12)

@@ -4,16 +4,17 @@ Revision ID: 12
 Revises: 11
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

+from sqlalchemy import inspect
+
 from alembic import op
-from sqlalchemy import inspect

 # revision identifiers, used by Alembic.
 revision: str = "12"
-down_revision: Union[str, None] = "11"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "11"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

alembic migration (revision 13)

@@ -4,15 +4,15 @@ Revision ID: 13
 Revises: 12
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "13"
-down_revision: Union[str, None] = "12"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "12"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

alembic migration (revision 1)

@@ -5,7 +5,7 @@ Revises:
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

@@ -15,9 +15,9 @@ from alembic import op
 # revision identifiers, used by Alembic.
 revision: str = "1"
-down_revision: Union[str, None] = None
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = None
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

@@ -63,10 +63,8 @@ def downgrade() -> None:
         "CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR')"
     )
     op.execute(
-        (
-            "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
-            "connector_type::text::searchsourceconnectortype"
-        )
+        "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
+        "connector_type::text::searchsourceconnectortype"
     )
     op.execute("DROP TYPE searchsourceconnectortype_old")

alembic migration (revision 2)

@@ -5,15 +5,15 @@ Revises: e55302644c51
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "2"
-down_revision: Union[str, None] = "e55302644c51"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "e55302644c51"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

@@ -49,10 +49,8 @@ def downgrade() -> None:
         "CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR')"
     )
     op.execute(
-        (
-            "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
-            "connector_type::text::searchsourceconnectortype"
-        )
+        "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
+        "connector_type::text::searchsourceconnectortype"
     )
     op.execute("DROP TYPE searchsourceconnectortype_old")

alembic migration (revision 3)

@@ -5,15 +5,15 @@ Revises: 2
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "3"
-down_revision: Union[str, None] = "2"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "2"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None

 # Define the ENUM type name and the new value
 ENUM_NAME = "documenttype"  # Make sure this matches the name in your DB (usually lowercase class name)

alembic migration (revision 4)

@@ -5,37 +5,26 @@ Revises: 3
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "4"
-down_revision: Union[str, None] = "3"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "3"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:
-    ENUM_NAME = "searchsourceconnectortype"
-    NEW_VALUE = "LINKUP_API"
-    op.execute(
-        f"""
-        DO $$
-        BEGIN
-            IF NOT EXISTS (
-                SELECT 1 FROM pg_enum
-                WHERE enumlabel = '{NEW_VALUE}'
-                AND enumtypid = (
-                    SELECT oid FROM pg_type WHERE typname = '{ENUM_NAME}'
-                )
-            ) THEN
-                ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
-            END IF;
-        END$$;
-        """
-    )
+    # ### commands auto generated by Alembic - please adjust! ###
+    # Manually add the command to add the enum value
+    op.execute("ALTER TYPE searchsourceconnectortype ADD VALUE 'LINKUP_API'")
+
+    # Pass for the rest, as autogenerate didn't run to add other schema details
+    pass
+    # ### end Alembic commands ###


 def downgrade() -> None:

@@ -49,10 +38,8 @@ def downgrade() -> None:
         "CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR', 'LINEAR_CONNECTOR')"
     )
     op.execute(
-        (
-            "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
-            "connector_type::text::searchsourceconnectortype"
-        )
+        "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
+        "connector_type::text::searchsourceconnectortype"
     )
     op.execute("DROP TYPE searchsourceconnectortype_old")

alembic migration (revision 5)

@@ -4,54 +4,73 @@ Revision ID: 5
 Revises: 4
 """

-from typing import Sequence, Union
-from alembic import op
+from collections.abc import Sequence
+
 import sqlalchemy as sa
+
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = '5'
-down_revision: Union[str, None] = '4'
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+revision: str = "5"
+down_revision: str | None = "4"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:
     # Alter Chat table
-    op.alter_column('chats', 'title',
-                    existing_type=sa.String(200),
-                    type_=sa.String(),
-                    existing_nullable=False)
+    op.alter_column(
+        "chats",
+        "title",
+        existing_type=sa.String(200),
+        type_=sa.String(),
+        existing_nullable=False,
+    )

     # Alter Document table
-    op.alter_column('documents', 'title',
-                    existing_type=sa.String(200),
-                    type_=sa.String(),
-                    existing_nullable=False)
+    op.alter_column(
+        "documents",
+        "title",
+        existing_type=sa.String(200),
+        type_=sa.String(),
+        existing_nullable=False,
+    )

     # Alter Podcast table
-    op.alter_column('podcasts', 'title',
-                    existing_type=sa.String(200),
-                    type_=sa.String(),
-                    existing_nullable=False)
+    op.alter_column(
+        "podcasts",
+        "title",
+        existing_type=sa.String(200),
+        type_=sa.String(),
+        existing_nullable=False,
+    )


 def downgrade() -> None:
     # Revert Chat table
-    op.alter_column('chats', 'title',
-                    existing_type=sa.String(),
-                    type_=sa.String(200),
-                    existing_nullable=False)
+    op.alter_column(
+        "chats",
+        "title",
+        existing_type=sa.String(),
+        type_=sa.String(200),
+        existing_nullable=False,
+    )

     # Revert Document table
-    op.alter_column('documents', 'title',
-                    existing_type=sa.String(),
-                    type_=sa.String(200),
-                    existing_nullable=False)
+    op.alter_column(
+        "documents",
+        "title",
+        existing_type=sa.String(),
+        type_=sa.String(200),
+        existing_nullable=False,
+    )

     # Revert Podcast table
-    op.alter_column('podcasts', 'title',
-                    existing_type=sa.String(),
-                    type_=sa.String(200),
-                    existing_nullable=False)
+    op.alter_column(
+        "podcasts",
+        "title",
+        existing_type=sa.String(),
+        type_=sa.String(200),
+        existing_nullable=False,
+    )

alembic migration (revision 6)

@@ -5,18 +5,19 @@ Revises: 5
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 import sqlalchemy as sa
-from alembic import op
 from sqlalchemy import inspect
 from sqlalchemy.dialects.postgresql import JSON
+
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "6"
-down_revision: Union[str, None] = "5"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "5"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

alembic migration (revision 7)

@@ -5,17 +5,18 @@ Revises: 6
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 import sqlalchemy as sa
-from alembic import op
 from sqlalchemy import inspect
+
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "7"
-down_revision: Union[str, None] = "6"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "6"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

alembic migration (revision 8)

@@ -4,17 +4,18 @@ Revision ID: 8
 Revises: 7
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 import sqlalchemy as sa
-from alembic import op
 from sqlalchemy import inspect
+
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "8"
-down_revision: Union[str, None] = "7"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "7"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


 def upgrade() -> None:

alembic migration (revision 9)

@@ -4,15 +4,15 @@ Revision ID: 9
 Revises: 8
 """

-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "9"
-down_revision: Union[str, None] = "8"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "8"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None

 # Define the ENUM type name and the new value
 CONNECTOR_ENUM = "searchsourceconnectortype"

alembic migration (revision e55302644c51)

@@ -1,12 +1,12 @@
-from typing import Sequence, Union
+from collections.abc import Sequence

 from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "e55302644c51"
-down_revision: Union[str, None] = "1"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "1"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None

 # Define the ENUM type name and the new value
 ENUM_NAME = "documenttype"

surfsense_backend/app/agents/podcaster/configuration.py

@@ -3,7 +3,6 @@
 from __future__ import annotations

 from dataclasses import dataclass, fields
-from typing import Optional

 from langchain_core.runnables import RunnableConfig

@@ -21,7 +20,7 @@ class Configuration:
     @classmethod
     def from_runnable_config(
-        cls, config: Optional[RunnableConfig] = None
+        cls, config: RunnableConfig | None = None
     ) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
         configurable = (config.get("configurable") or {}) if config else {}

surfsense_backend/app/agents/podcaster/graph.py

@@ -1,14 +1,11 @@
 from langgraph.graph import StateGraph

 from .configuration import Configuration
+from .nodes import create_merged_podcast_audio, create_podcast_transcript
 from .state import State
-from .nodes import create_merged_podcast_audio, create_podcast_transcript


 def build_graph():
     # Define a new graph
     workflow = StateGraph(State, config_schema=Configuration)

@@ -27,5 +24,6 @@ def build_graph():
     return graph

+
 # Compile the graph once when the module is loaded
 graph = build_graph()

surfsense_backend/app/agents/podcaster/nodes.py

@@ -1,23 +1,26 @@
-from typing import Any, Dict
+import asyncio
 import json
 import os
 import uuid
 from pathlib import Path
-import asyncio
+from typing import Any

+from ffmpeg.asyncio import FFmpeg
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableConfig
 from litellm import aspeech
-from ffmpeg.asyncio import FFmpeg

-from .configuration import Configuration
-from .state import PodcastTranscriptEntry, State, PodcastTranscripts
-from .prompts import get_podcast_generation_prompt
 from app.config import config as app_config
 from app.services.llm_service import get_user_long_context_llm

+from .configuration import Configuration
+from .prompts import get_podcast_generation_prompt
+from .state import PodcastTranscriptEntry, PodcastTranscripts, State

-async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
+
+async def create_podcast_transcript(
+    state: State, config: RunnableConfig
+) -> dict[str, Any]:
     """Each node does work."""

     # Get configuration from runnable config

@@ -37,7 +40,9 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
     # Create the messages
     messages = [
         SystemMessage(content=prompt),
-        HumanMessage(content=f"<source_content>{state.source_content}</source_content>")
+        HumanMessage(
+            content=f"<source_content>{state.source_content}</source_content>"
+        ),
     ]

     # Generate the podcast transcript

@@ -45,9 +50,11 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
     # First try the direct approach
     try:
-        podcast_transcript = PodcastTranscripts.model_validate(json.loads(llm_response.content))
+        podcast_transcript = PodcastTranscripts.model_validate(
+            json.loads(llm_response.content)
+        )
     except (json.JSONDecodeError, ValueError) as e:
-        print(f"Direct JSON parsing failed, trying fallback approach: {str(e)}")
+        print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")

         # Fallback: Parse the JSON response manually
         try:

@@ -55,8 +62,8 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
             content = llm_response.content

             # Find the JSON in the content (handle case where LLM might add additional text)
-            json_start = content.find('{')
-            json_end = content.rfind('}') + 1
+            json_start = content.find("{")
+            json_end = content.rfind("}") + 1

             if json_start >= 0 and json_end > json_start:
                 json_str = content[json_start:json_end]

@@ -66,7 +73,7 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
                 # Convert to Pydantic model
                 podcast_transcript = PodcastTranscripts.model_validate(parsed_data)

-                print(f"Successfully parsed podcast transcript using fallback approach")
+                print("Successfully parsed podcast transcript using fallback approach")
             else:
                 # If JSON structure not found, raise a clear error
                 error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"

@@ -75,36 +82,35 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
         except (json.JSONDecodeError, ValueError) as e2:
             # Log the error and re-raise it
-            error_message = f"Error parsing LLM response (fallback also failed): {str(e2)}"
-            print(f"Error parsing LLM response: {str(e2)}")
+            error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
+            print(f"Error parsing LLM response: {e2!s}")
             print(f"Raw response: {llm_response.content}")
             raise

-    return {
-        "podcast_transcript": podcast_transcript.podcast_transcripts
-    }
+    return {"podcast_transcript": podcast_transcript.podcast_transcripts}


-async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def create_merged_podcast_audio(
+    state: State, config: RunnableConfig
+) -> dict[str, Any]:
     """Generate audio for each transcript and merge them into a single podcast file."""

     configuration = Configuration.from_runnable_config(config)

     starting_transcript = PodcastTranscriptEntry(
-        speaker_id=1,
-        dialog=f"Welcome to {configuration.podcast_title} Podcast."
+        speaker_id=1, dialog=f"Welcome to {configuration.podcast_title} Podcast."
     )

     transcript = state.podcast_transcript

     # Merge the starting transcript with the podcast transcript
     # Check if transcript is a PodcastTranscripts object or already a list
-    if hasattr(transcript, 'podcast_transcripts'):
+    if hasattr(transcript, "podcast_transcripts"):
         transcript_entries = transcript.podcast_transcripts
     else:
         transcript_entries = transcript

-    merged_transcript = [starting_transcript] + transcript_entries
+    merged_transcript = [starting_transcript, *transcript_entries]

     # Create a temporary directory for audio files
     temp_dir = Path("temp_audio")

@@ -118,7 +124,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
     # Map of speaker_id to voice
     voice_mapping = {
         0: "alloy",  # Default/intro voice
-        1: "echo",   # First speaker
+        1: "echo",  # First speaker
         # 2: "fable",  # Second speaker
         # 3: "onyx",  # Third speaker
         # 4: "nova",  # Fourth speaker

@@ -130,7 +136,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
     async def generate_speech_for_segment(segment, index):
         # Handle both dictionary and PodcastTranscriptEntry objects
-        if hasattr(segment, 'speaker_id'):
+        if hasattr(segment, "speaker_id"):
             speaker_id = segment.speaker_id
             dialog = segment.dialog
         else:

@@ -165,16 +171,19 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
             )

             # Save the audio to a file - use proper streaming method
-            with open(filename, 'wb') as f:
+            with open(filename, "wb") as f:
                 f.write(response.content)

             return filename
         except Exception as e:
-            print(f"Error generating speech for segment {index}: {str(e)}")
+            print(f"Error generating speech for segment {index}: {e!s}")
             raise

     # Generate all audio files concurrently
-    tasks = [generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript)]
+    tasks = [
+        generate_speech_for_segment(segment, i)
+        for i, segment in enumerate(merged_transcript)
+    ]
     audio_files = await asyncio.gather(*tasks)

     # Merge audio files using ffmpeg

@@ -191,7 +200,9 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
     for i in range(len(audio_files)):
         filter_complex.append(f"[{i}:0]")

-    filter_complex_str = "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
+    filter_complex_str = (
+        "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
+    )
     ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
     ffmpeg = ffmpeg.output(output_path, map="[outa]")

@@ -201,17 +212,18 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
         print(f"Successfully created podcast audio: {output_path}")
     except Exception as e:
-        print(f"Error merging audio files: {str(e)}")
+        print(f"Error merging audio files: {e!s}")
         raise
     finally:
         # Clean up temporary files
         for audio_file in audio_files:
             try:
                 os.remove(audio_file)
-            except:
+            except Exception as e:
+                print(f"Error removing audio file {audio_file}: {e!s}")
                 pass

     return {
         "podcast_transcript": merged_transcript,
-        "final_podcast_file_path": output_path
+        "final_podcast_file_path": output_path,
     }
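Three recurring fixes show up in this file: f"{str(e)}" becomes the conversion-flag form f"{e!s}", [a] + b list concatenation becomes [a, *b] unpacking, and the bare except: in the cleanup loop gains an explicit exception type plus a log line. In miniature:

entries = ["intro"]
rest = ["seg1", "seg2"]

# Iterable unpacking instead of list concatenation:
merged = [entries[0], *rest]  # same result as [entries[0]] + rest

try:
    raise ValueError("boom")
except Exception as e:  # never a bare `except:`
    # `!s` applies str() as a conversion flag, replacing f"{str(e)}".
    print(f"Error merging audio files: {e!s}")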

surfsense_backend/app/agents/podcaster/state.py

@@ -3,14 +3,16 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import List, Optional

 from pydantic import BaseModel, Field
 from sqlalchemy.ext.asyncio import AsyncSession


 class PodcastTranscriptEntry(BaseModel):
     """
     Represents a single entry in a podcast transcript.
     """
+
     speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
     dialog: str = Field(..., description="The dialog text spoken by the speaker")

@@ -19,11 +21,12 @@ class PodcastTranscripts(BaseModel):
     """
     Represents the full podcast transcript structure.
     """
-    podcast_transcripts: List[PodcastTranscriptEntry] = Field(
-        ...,
-        description="List of transcript entries with alternating speakers"
+
+    podcast_transcripts: list[PodcastTranscriptEntry] = Field(
+        ..., description="List of transcript entries with alternating speakers"
     )


 @dataclass
 class State:
     """Defines the input state for the agent, representing a narrower interface to the outside world.

@@ -32,8 +35,9 @@ class State:
     See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
     for more information.
     """
+
     # Runtime context
     db_session: AsyncSession
     source_content: str
-    podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
-    final_podcast_file_path: Optional[str] = None
+    podcast_transcript: list[PodcastTranscriptEntry] | None = None
+    final_podcast_file_path: str | None = None

surfsense_backend/app/agents/researcher/configuration.py

@@ -4,17 +4,20 @@ from __future__ import annotations
 from dataclasses import dataclass, fields
 from enum import Enum
-from typing import Optional, List, Any

 from langchain_core.runnables import RunnableConfig


 class SearchMode(Enum):
     """Enum defining the type of search mode."""
+
     CHUNKS = "CHUNKS"
     DOCUMENTS = "DOCUMENTS"


 class ResearchMode(Enum):
     """Enum defining the type of research mode."""
+
     QNA = "QNA"
     REPORT_GENERAL = "REPORT_GENERAL"
     REPORT_DEEP = "REPORT_DEEP"

@@ -28,16 +31,16 @@ class Configuration:
     # Input parameters provided at invocation
     user_query: str
     num_sections: int
-    connectors_to_search: List[str]
+    connectors_to_search: list[str]
     user_id: str
     search_space_id: int
     search_mode: SearchMode
     research_mode: ResearchMode
-    document_ids_to_add_in_context: List[int]
+    document_ids_to_add_in_context: list[int]

     @classmethod
     def from_runnable_config(
-        cls, config: Optional[RunnableConfig] = None
+        cls, config: RunnableConfig | None = None
     ) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
         configurable = (config.get("configurable") or {}) if config else {}
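from_runnable_config reads config["configurable"] and builds the dataclass from it; given the fields import above, the keys are presumably filtered against the declared dataclass fields, as in the standard LangGraph configuration template. A self-contained sketch of that pattern (DemoConfiguration is illustrative, not the repo's class):

from __future__ import annotations

from dataclasses import dataclass, fields

from langchain_core.runnables import RunnableConfig


@dataclass
class DemoConfiguration:
    user_query: str = ""
    num_sections: int = 1

    @classmethod
    def from_runnable_config(cls, config: RunnableConfig | None = None) -> DemoConfiguration:
        # Same shape as the diff: pull the "configurable" mapping, then keep
        # only keys that correspond to declared dataclass fields.
        configurable = (config.get("configurable") or {}) if config else {}
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in configurable.items() if k in names})


cfg = DemoConfiguration.from_runnable_config(
    {"configurable": {"user_query": "hello", "num_sections": 3, "ignored": True}}
)
print(cfg)  # DemoConfiguration(user_query='hello', num_sections=3)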

surfsense_backend/app/agents/researcher/graph.py

@@ -1,15 +1,25 @@
+from typing import Any, TypedDict
+
 from langgraph.graph import StateGraph

-from .state import State
-from .nodes import reformulate_user_query, write_answer_outline, process_sections, handle_qna_workflow, generate_further_questions
 from .configuration import Configuration, ResearchMode
-from typing import TypedDict, List, Dict, Any, Optional
+from .nodes import (
+    generate_further_questions,
+    handle_qna_workflow,
+    process_sections,
+    reformulate_user_query,
+    write_answer_outline,
+)
+from .state import State


 # Define what keys are in our state dict
 class GraphState(TypedDict):
     # Intermediate data produced during workflow
-    answer_outline: Optional[Any]
+    answer_outline: Any | None

     # Final output
-    final_written_report: Optional[str]
+    final_written_report: str | None


 def build_graph():
     """

@@ -51,8 +61,8 @@ def build_graph():
         route_after_reformulate,
         {
             "handle_qna_workflow": "handle_qna_workflow",
-            "write_answer_outline": "write_answer_outline"
-        }
+            "write_answer_outline": "write_answer_outline",
+        },
     )

     # QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__

@@ -71,5 +81,6 @@ def build_graph():
     return graph

+
 # Compile the graph once when the module is loaded
 graph = build_graph()
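The wiring being reorganized here routes QNA runs straight to an answer node and report runs through outline writing, via a conditional edge. A stripped-down, runnable sketch of the same topology (node bodies are stubs; only the routing mirrors the real graph):

from typing import Any, TypedDict

from langgraph.graph import END, START, StateGraph


class DemoState(TypedDict):
    research_mode: str
    final_written_report: str | None


def reformulate_user_query(state: DemoState) -> dict[str, Any]:
    return {}


def handle_qna_workflow(state: DemoState) -> dict[str, Any]:
    return {"final_written_report": "direct answer"}


def write_answer_outline(state: DemoState) -> dict[str, Any]:
    return {"final_written_report": "report from outline"}


def route_after_reformulate(state: DemoState) -> str:
    # QNA skips outline/section writing entirely.
    return "handle_qna_workflow" if state["research_mode"] == "QNA" else "write_answer_outline"


workflow = StateGraph(DemoState)
workflow.add_node("reformulate_user_query", reformulate_user_query)
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
workflow.add_node("write_answer_outline", write_answer_outline)
workflow.add_edge(START, "reformulate_user_query")
workflow.add_conditional_edges(
    "reformulate_user_query",
    route_after_reformulate,
    {
        "handle_qna_workflow": "handle_qna_workflow",
        "write_answer_outline": "write_answer_outline",
    },
)
workflow.add_edge("handle_qna_workflow", END)
workflow.add_edge("write_answer_outline", END)
graph = workflow.compile()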

View file

@ -1,10 +1,7 @@
import asyncio import asyncio
import json import json
from typing import Any, Dict, List from typing import Any
from app.db import Document, SearchSpace
from app.services.connector_service import ConnectorService
from app.services.query_service import QueryService
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter from langgraph.types import StreamWriter
@ -13,6 +10,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
# Additional imports for document fetching # Additional imports for document fetching
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import Document, SearchSpace
from app.services.connector_service import ConnectorService
from app.services.query_service import QueryService
from .configuration import Configuration, SearchMode from .configuration import Configuration, SearchMode
from .prompts import ( from .prompts import (
get_answer_outline_system_prompt, get_answer_outline_system_prompt,
@ -26,8 +27,8 @@ from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_na
async def fetch_documents_by_ids( async def fetch_documents_by_ids(
document_ids: List[int], user_id: str, db_session: AsyncSession document_ids: list[int], user_id: str, db_session: AsyncSession
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
""" """
Fetch documents by their IDs with ownership check using DOCUMENTS mode approach. Fetch documents by their IDs with ownership check using DOCUMENTS mode approach.
@ -358,13 +359,13 @@ async def fetch_documents_by_ids(
return source_objects, formatted_documents return source_objects, formatted_documents
except Exception as e: except Exception as e:
print(f"Error fetching documents by IDs: {str(e)}") print(f"Error fetching documents by IDs: {e!s}")
return [], [] return [], []
async def write_answer_outline( async def write_answer_outline(
state: State, config: RunnableConfig, writer: StreamWriter state: State, config: RunnableConfig, writer: StreamWriter
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Create a structured answer outline based on the user query. Create a structured answer outline based on the user query.
@ -502,27 +503,27 @@ async def write_answer_outline(
except (json.JSONDecodeError, ValueError) as e: except (json.JSONDecodeError, ValueError) as e:
# Log the error and re-raise it # Log the error and re-raise it
error_message = f"Error parsing LLM response: {str(e)}" error_message = f"Error parsing LLM response: {e!s}"
writer({"yield_value": streaming_service.format_error(error_message)}) writer({"yield_value": streaming_service.format_error(error_message)})
print(f"Error parsing LLM response: {str(e)}") print(f"Error parsing LLM response: {e!s}")
print(f"Raw response: {response.content}") print(f"Raw response: {response.content}")
raise raise
async def fetch_relevant_documents( async def fetch_relevant_documents(
research_questions: List[str], research_questions: list[str],
user_id: str, user_id: str,
search_space_id: int, search_space_id: int,
db_session: AsyncSession, db_session: AsyncSession,
connectors_to_search: List[str], connectors_to_search: list[str],
writer: StreamWriter = None, writer: StreamWriter = None,
state: State = None, state: State = None,
top_k: int = 10, top_k: int = 10,
connector_service: ConnectorService = None, connector_service: ConnectorService = None,
search_mode: SearchMode = SearchMode.CHUNKS, search_mode: SearchMode = SearchMode.CHUNKS,
user_selected_sources: List[Dict[str, Any]] = None, user_selected_sources: list[dict[str, Any]] | None = None,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Fetch relevant documents for research questions using the provided connectors. Fetch relevant documents for research questions using the provided connectors.
@ -833,7 +834,10 @@ async def fetch_relevant_documents(
elif connector == "LINKUP_API": elif connector == "LINKUP_API":
linkup_mode = "standard" linkup_mode = "standard"
source_object, linkup_chunks = await connector_service.search_linkup( (
source_object,
linkup_chunks,
) = await connector_service.search_linkup(
user_query=reformulated_query, user_query=reformulated_query,
user_id=user_id, user_id=user_id,
mode=linkup_mode, mode=linkup_mode,
@ -904,7 +908,7 @@ async def fetch_relevant_documents(
) )
except Exception as e: except Exception as e:
error_message = f"Error searching connector {connector}: {str(e)}" error_message = f"Error searching connector {connector}: {e!s}"
print(error_message) print(error_message)
# Stream error message # Stream error message
@ -913,7 +917,7 @@ async def fetch_relevant_documents(
writer( writer(
{ {
"yield_value": streaming_service.format_error( "yield_value": streaming_service.format_error(
f"Error searching {friendly_name}: {str(e)}" f"Error searching {friendly_name}: {e!s}"
) )
} }
) )
@ -948,36 +952,48 @@ async def fetch_relevant_documents(
if source_id and source_type: if source_id and source_type:
source_key = f"{source_type}_{source_id}" source_key = f"{source_type}_{source_id}"
current_sources_count = len(source_obj.get('sources', [])) current_sources_count = len(source_obj.get("sources", []))
if source_key not in seen_source_keys: if source_key not in seen_source_keys:
seen_source_keys.add(source_key) seen_source_keys.add(source_key)
deduplicated_sources.append(source_obj) deduplicated_sources.append(source_obj)
print(f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}") print(
f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}"
)
else: else:
# Check if this source object has more sources than the existing one # Check if this source object has more sources than the existing one
existing_index = None existing_index = None
for i, existing_source in enumerate(deduplicated_sources): for i, existing_source in enumerate(deduplicated_sources):
existing_id = existing_source.get('id') existing_id = existing_source.get("id")
existing_type = existing_source.get('type') existing_type = existing_source.get("type")
if existing_id == source_id and existing_type == source_type: if existing_id == source_id and existing_type == source_type:
existing_index = i existing_index = i
break break
if existing_index is not None: if existing_index is not None:
existing_sources_count = len(deduplicated_sources[existing_index].get('sources', [])) existing_sources_count = len(
deduplicated_sources[existing_index].get("sources", [])
)
if current_sources_count > existing_sources_count: if current_sources_count > existing_sources_count:
# Replace the existing source object with the new one that has more sources # Replace the existing source object with the new one that has more sources
deduplicated_sources[existing_index] = source_obj deduplicated_sources[existing_index] = source_obj
print(f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}") print(
f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}"
)
else: else:
print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}") print(
f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}"
)
else: else:
print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)") print(
f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)"
)
else: else:
# If there's no ID or type, just add it to be safe # If there's no ID or type, just add it to be safe
deduplicated_sources.append(source_obj) deduplicated_sources.append(source_obj)
print(f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}") print(
f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}"
)
# Stream info about deduplicated sources # Stream info about deduplicated sources
if streaming_service and writer: if streaming_service and writer:
@@ -1039,7 +1055,7 @@ async def fetch_relevant_documents(
 async def process_sections(
     state: State, config: RunnableConfig, writer: StreamWriter
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Process all sections in parallel and combine the results.
@@ -1100,13 +1116,13 @@ async def process_sections(
     )
     if configuration.num_sections == 1:
-        TOP_K = 10
+        top_k = 10
     elif configuration.num_sections == 3:
-        TOP_K = 20
+        top_k = 20
     elif configuration.num_sections == 6:
-        TOP_K = 30
+        top_k = 30
     else:
-        TOP_K = 10
+        top_k = 10
     relevant_documents = []
     user_selected_documents = []
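The branch ladder maps a section count to a retrieval depth. Should it ever grow, the same policy fits in a lookup table; a hedged equivalent sketch (not part of the PR, `configuration.num_sections` as above):

# Equivalent lookup, assuming the same section-count -> top_k policy.
TOP_K_BY_SECTIONS = {1: 10, 3: 20, 6: 30}
top_k = TOP_K_BY_SECTIONS.get(configuration.num_sections, 10)  # default 10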
@@ -1155,13 +1171,13 @@ async def process_sections(
             connectors_to_search=configuration.connectors_to_search,
             writer=writer,
             state=state,
-            top_k=TOP_K,
+            top_k=top_k,
             connector_service=connector_service,
             search_mode=configuration.search_mode,
             user_selected_sources=user_selected_sources,
         )
     except Exception as e:
-        error_message = f"Error fetching relevant documents: {str(e)}"
+        error_message = f"Error fetching relevant documents: {e!s}"
         print(error_message)
         writer({"yield_value": streaming_service.format_error(error_message)})
         # Log the error and continue with an empty list of documents
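The recurring `{str(e)}` to `{e!s}` rewrite throughout this commit is ruff's RUF010: inside an f-string, the `!s` conversion flag applies `str()` to the value, so the two spellings are equivalent:

e = ValueError("boom")
assert f"{str(e)}" == f"{e!s}" == "boom"  # !s applies str() as a conversion flag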
@@ -1251,7 +1267,7 @@ async def process_sections(
     for i, result in enumerate(section_results):
         if isinstance(result, Exception):
             section_title = answer_outline.answer_outline[i].section_title
-            error_message = f"Error processing section '{section_title}': {str(result)}"
+            error_message = f"Error processing section '{section_title}': {result!s}"
             print(error_message)
             writer({"yield_value": streaming_service.format_error(error_message)})
             processed_results.append(error_message)
@@ -1260,8 +1276,8 @@ async def process_sections(
     # Combine the results into a final report with section titles
     final_report = []
-    for i, (section, content) in enumerate(
-        zip(answer_outline.answer_outline, processed_results)
+    for _i, (section, content) in enumerate(
+        zip(answer_outline.answer_outline, processed_results, strict=False)
     ):
         # Skip adding the section header since the content already contains the title
         final_report.append(content)
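`zip(..., strict=False)` spells out the default truncating behaviour (ruff B905; the `strict` parameter exists since Python 3.10). `strict=True` would instead raise when the outline and the results drift out of sync:

titles = ["Intro", "Body", "End"]
results = ["a", "b"]
list(zip(titles, results, strict=False))   # [("Intro", "a"), ("Body", "b")], silently truncates
# list(zip(titles, results, strict=True))  # would raise ValueError on the length mismatch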
@@ -1299,15 +1315,15 @@ async def process_sections(
 async def process_section_with_documents(
     section_id: int,
     section_title: str,
-    section_questions: List[str],
+    section_questions: list[str],
     user_id: str,
     search_space_id: int,
-    relevant_documents: List[Dict[str, Any]],
+    relevant_documents: list[dict[str, Any]],
     user_query: str,
     state: State = None,
     writer: StreamWriter = None,
     sub_section_type: SubSectionType = SubSectionType.MIDDLE,
-    section_contents: Dict[int, Dict[str, Any]] = None,
+    section_contents: dict[int, dict[str, Any]] | None = None,
 ) -> str:
     """
     Process a single section using pre-fetched documents.
@@ -1388,7 +1404,7 @@ async def process_section_with_documents(
         # Variables to track streaming state
         complete_content = ""  # Tracks the complete content received so far
-        async for chunk_type, chunk in sub_section_writer_graph.astream(
+        async for _chunk_type, chunk in sub_section_writer_graph.astream(
             sub_state, config, stream_mode=["values"]
         ):
             if "final_answer" in chunk:
@@ -1448,24 +1464,24 @@ async def process_section_with_documents(
         return complete_content
     except Exception as e:
-        print(f"Error processing section '{section_title}': {str(e)}")
+        print(f"Error processing section '{section_title}': {e!s}")
         # Send error update via streaming if available
         if state and state.streaming_service and writer:
             writer(
                 {
                     "yield_value": state.streaming_service.format_error(
-                        f'Error processing section "{section_title}": {str(e)}'
+                        f'Error processing section "{section_title}": {e!s}'
                     )
                 }
             )
-        return f"Error processing section: {section_title}. Details: {str(e)}"
+        return f"Error processing section: {section_title}. Details: {e!s}"

 async def reformulate_user_query(
     state: State, config: RunnableConfig, writer: StreamWriter
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Reformulates the user query based on the chat history.
     """
@@ -1490,7 +1506,7 @@ async def reformulate_user_query(
 async def handle_qna_workflow(
     state: State, config: RunnableConfig, writer: StreamWriter
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Handle the QNA research workflow.
@@ -1532,7 +1548,7 @@ async def handle_qna_workflow(
     )
     # Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
-    TOP_K = 15
+    top_k = 15
     relevant_documents = []
     user_selected_documents = []
@@ -1584,13 +1600,13 @@ async def handle_qna_workflow(
             connectors_to_search=configuration.connectors_to_search,
             writer=writer,
             state=state,
-            top_k=TOP_K,
+            top_k=top_k,
             connector_service=connector_service,
             search_mode=configuration.search_mode,
             user_selected_sources=user_selected_sources,
         )
     except Exception as e:
-        error_message = f"Error fetching relevant documents for QNA: {str(e)}"
+        error_message = f"Error fetching relevant documents for QNA: {e!s}"
         print(error_message)
         writer({"yield_value": streaming_service.format_error(error_message)})
         # Continue with empty documents - the QNA agent will handle this gracefully
@@ -1688,16 +1704,16 @@ async def handle_qna_workflow(
         }
     except Exception as e:
-        error_message = f"Error generating QNA answer: {str(e)}"
+        error_message = f"Error generating QNA answer: {e!s}"
         print(error_message)
         writer({"yield_value": streaming_service.format_error(error_message)})
-        return {"final_written_report": f"Error generating answer: {str(e)}"}
+        return {"final_written_report": f"Error generating answer: {e!s}"}

 async def generate_further_questions(
     state: State, config: RunnableConfig, writer: StreamWriter
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Generate contextually relevant follow-up questions based on chat history and available documents.
@@ -1748,7 +1764,7 @@ async def generate_further_questions(
             chat_history_xml += f"<assistant>{message.content}</assistant>\n"
         else:
             # Handle other message types if needed
-            chat_history_xml += f"<message>{str(message)}</message>\n"
+            chat_history_xml += f"<message>{message!s}</message>\n"
     chat_history_xml += "</chat_history>"
     # Format available documents for the prompt
@@ -1868,7 +1884,7 @@ async def generate_further_questions(
     except (json.JSONDecodeError, ValueError) as e:
         # Log the error and return empty list
-        error_message = f"Error parsing further questions response: {str(e)}"
+        error_message = f"Error parsing further questions response: {e!s}"
         print(error_message)
         writer(
             {"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
@@ -1880,7 +1896,7 @@ async def generate_further_questions(
     except Exception as e:
         # Handle any other errors
-        error_message = f"Error generating further questions: {str(e)}"
+        error_message = f"Error generating further questions: {e!s}"
         print(error_message)
         writer(
             {"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
View file
@@ -1,5 +1,4 @@
-"""QnA Agent.
-"""
+"""QnA Agent."""

 from .graph import graph
View file
@@ -3,7 +3,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, fields
-from typing import Optional, List, Any
+from typing import Any

 from langchain_core.runnables import RunnableConfig
@@ -15,13 +15,15 @@ class Configuration:
     # Configuration parameters for the Q&A agent
     user_query: str  # The user's question to answer
     reformulated_query: str  # The reformulated query
-    relevant_documents: List[Any]  # Documents provided directly to the agent for answering
+    relevant_documents: list[
+        Any
+    ]  # Documents provided directly to the agent for answering
     user_id: str  # User identifier
     search_space_id: int  # Search space identifier

     @classmethod
     def from_runnable_config(
-        cls, config: Optional[RunnableConfig] = None
+        cls, config: RunnableConfig | None = None
     ) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
         configurable = (config.get("configurable") or {}) if config else {}
View file
@@ -1,7 +1,8 @@
 from langgraph.graph import StateGraph
-from .state import State
-from .nodes import rerank_documents, answer_question
 from .configuration import Configuration
+from .nodes import answer_question, rerank_documents
+from .state import State

 # Define a new graph
 workflow = StateGraph(State, config_schema=Configuration)
View file
@@ -1,17 +1,21 @@
-from app.services.reranker_service import RerankerService
-from .configuration import Configuration
-from langchain_core.runnables import RunnableConfig
-from .state import State
-from typing import Any, Dict
-from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
-from langchain_core.messages import HumanMessage, SystemMessage
-from ..utils import (
-    optimize_documents_for_token_limit,
-    calculate_token_count,
-    format_documents_section
-)
+from typing import Any

-async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.runnables import RunnableConfig
+
+from app.services.reranker_service import RerankerService
+
+from ..utils import (
+    calculate_token_count,
+    format_documents_section,
+    optimize_documents_for_token_limit,
+)
+from .configuration import Configuration
+from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
+from .state import State
+
+
+async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
     """
     Rerank the documents based on relevance to the user's question.
@@ -30,9 +34,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
     # If no documents were provided, return empty list
     if not documents or len(documents) == 0:
-        return {
-            "reranked_documents": []
-        }
+        return {"reranked_documents": []}

     # Get reranker service from app config
     reranker_service = RerankerService.get_reranker_instance()
@@ -51,28 +53,34 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
                 "document": {
                     "id": doc.get("document", {}).get("id", ""),
                     "title": doc.get("document", {}).get("title", ""),
-                    "document_type": doc.get("document", {}).get("document_type", ""),
-                    "metadata": doc.get("document", {}).get("metadata", {})
-                }
-            } for i, doc in enumerate(documents)
+                    "document_type": doc.get("document", {}).get(
+                        "document_type", ""
+                    ),
+                    "metadata": doc.get("document", {}).get("metadata", {}),
+                },
+            }
+            for i, doc in enumerate(documents)
         ]

         # Rerank documents using the user's query
-        reranked_docs = reranker_service.rerank_documents(user_query + "\n" + reformulated_query, reranker_input_docs)
+        reranked_docs = reranker_service.rerank_documents(
+            user_query + "\n" + reformulated_query, reranker_input_docs
+        )

         # Sort by score in descending order
         reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

-        print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
+        print(
+            f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
+        )
     except Exception as e:
-        print(f"Error during reranking: {str(e)}")
+        print(f"Error during reranking: {e!s}")
         # Use original docs if reranking fails

-    return {
-        "reranked_documents": reranked_docs
-    }
+    return {"reranked_documents": reranked_docs}

-async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     """
     Answer the user's question using the provided documents.
@@ -117,14 +125,15 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
     # Use initial system prompt for token calculation
     initial_system_prompt = get_qna_citation_system_prompt()
-    base_messages = state.chat_history + [
+    base_messages = [
+        *state.chat_history,
         SystemMessage(content=initial_system_prompt),
-        HumanMessage(content=base_human_message_template)
+        HumanMessage(content=base_human_message_template),
     ]

     # Optimize documents to fit within token limits
-    optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
-        documents, base_messages, llm.model
+    optimized_documents, has_optimized_documents = (
+        optimize_documents_for_token_limit(documents, base_messages, llm.model)
     )

     # Update state based on optimization result
@@ -134,19 +143,26 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
         has_documents = False

     # Choose system prompt based on final document availability
-    system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt()
+    system_prompt = (
+        get_qna_citation_system_prompt()
+        if has_documents
+        else get_qna_no_documents_system_prompt()
+    )

     # Generate documents section
-    documents_text = format_documents_section(
-        documents,
-        "Source material from your personal knowledge base"
-    ) if has_documents else ""
+    documents_text = (
+        format_documents_section(
+            documents, "Source material from your personal knowledge base"
+        )
+        if has_documents
+        else ""
+    )

     # Create final human message content
     instruction_text = (
         "Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
-        if has_documents else
-        "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
+        if has_documents
+        else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
     )

     human_message_content = f"""
@@ -161,20 +177,18 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
     """

     # Create final messages for the LLM
-    messages_with_chat_history = state.chat_history + [
+    messages_with_chat_history = [
+        *state.chat_history,
         SystemMessage(content=system_prompt),
-        HumanMessage(content=human_message_content)
+        HumanMessage(content=human_message_content),
     ]

     # Log final token count
     total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
     print(f"Final token count: {total_tokens}")

     # Call the LLM and get the response
     response = await llm.ainvoke(messages_with_chat_history)
     final_answer = response.content

-    return {
-        "final_answer": final_answer
-    }
+    return {"final_answer": final_answer}
View file
@@ -3,9 +3,11 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import List, Optional, Any
+from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession

 @dataclass
 class State:
     """Defines the dynamic state for the Q&A agent during execution.
@@ -19,7 +21,7 @@ class State:
     # Runtime context
     db_session: AsyncSession
-    chat_history: Optional[List[Any]] = field(default_factory=list)
+    chat_history: list[Any] | None = field(default_factory=list)

     # OUTPUT: Populated by agent nodes
-    reranked_documents: Optional[List[Any]] = None
-    final_answer: Optional[str] = None
+    reranked_documents: list[Any] | None = None
+    final_answer: str | None = None
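These annotation rewrites are ruff's pyupgrade rules (UP006/UP007): since Python 3.9 the built-in generics are subscriptable, and since 3.10 `X | None` replaces `Optional[X]`. The two spellings are equivalent at runtime:

from typing import Any, List, Optional

old_style: Optional[List[Any]] = None   # requires the typing imports
new_style: list[Any] | None = None      # built-in generics + PEP 604 union, Python 3.10+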
View file
@@ -3,10 +3,13 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import List, Optional, Any
+from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession

 from app.services.streaming_service import StreamingService

 @dataclass
 class State:
     """Defines the dynamic state for the agent during execution.
@@ -15,23 +18,23 @@ class State:
     See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
     for more information.
     """

     # Runtime context (not part of actual graph state)
     db_session: AsyncSession

     # Streaming service
     streaming_service: StreamingService

-    chat_history: Optional[List[Any]] = field(default_factory=list)
-    reformulated_query: Optional[str] = field(default=None)
+    chat_history: list[Any] | None = field(default_factory=list)
+    reformulated_query: str | None = field(default=None)

     # Using field to explicitly mark as part of state
-    answer_outline: Optional[Any] = field(default=None)
-    further_questions: Optional[Any] = field(default=None)
+    answer_outline: Any | None = field(default=None)
+    further_questions: Any | None = field(default=None)

     # Temporary field to hold reranked documents from sub-agents for further question generation
-    reranked_documents: Optional[List[Any]] = field(default=None)
+    reranked_documents: list[Any] | None = field(default=None)

     # OUTPUT: Populated by agent nodes
     # Using field to explicitly mark as part of state
-    final_written_report: Optional[str] = field(default=None)
+    final_written_report: str | None = field(default=None)
View file
@@ -4,13 +4,14 @@ from __future__ import annotations
 from dataclasses import dataclass, fields
 from enum import Enum
-from typing import Optional, List, Any
+from typing import Any

 from langchain_core.runnables import RunnableConfig

 class SubSectionType(Enum):
     """Enum defining the type of sub-section."""

     START = "START"
     MIDDLE = "MIDDLE"
     END = "END"
@@ -22,17 +23,16 @@ class Configuration:
     # Input parameters provided at invocation
     sub_section_title: str
-    sub_section_questions: List[str]
+    sub_section_questions: list[str]
     sub_section_type: SubSectionType
     user_query: str
-    relevant_documents: List[Any]  # Documents provided directly to the agent
+    relevant_documents: list[Any]  # Documents provided directly to the agent
     user_id: str
     search_space_id: int

     @classmethod
     def from_runnable_config(
-        cls, config: Optional[RunnableConfig] = None
+        cls, config: RunnableConfig | None = None
     ) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
         configurable = (config.get("configurable") or {}) if config else {}
View file
@@ -1,7 +1,8 @@
 from langgraph.graph import StateGraph
-from .state import State
-from .nodes import write_sub_section, rerank_documents
 from .configuration import Configuration
+from .nodes import rerank_documents, write_sub_section
+from .state import State

 # Define a new graph
 workflow = StateGraph(State, config_schema=Configuration)
View file
@@ -1,18 +1,21 @@
-from .configuration import Configuration
-from langchain_core.runnables import RunnableConfig
-from .state import State
-from typing import Any, Dict
-from app.services.reranker_service import RerankerService
-from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
-from langchain_core.messages import HumanMessage, SystemMessage
-from .configuration import SubSectionType
-from ..utils import (
-    optimize_documents_for_token_limit,
-    calculate_token_count,
-    format_documents_section
-)
+from typing import Any

-async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.runnables import RunnableConfig
+
+from app.services.reranker_service import RerankerService
+
+from ..utils import (
+    calculate_token_count,
+    format_documents_section,
+    optimize_documents_for_token_limit,
+)
+from .configuration import Configuration, SubSectionType
+from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
+from .state import State
+
+
+async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
     """
     Rerank the documents based on relevance to the sub-section title.
@@ -30,9 +33,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
     # If no documents were provided, return empty list
     if not documents or len(documents) == 0:
-        return {
-            "reranked_documents": []
-        }
+        return {"reranked_documents": []}

     # Get reranker service from app config
     reranker_service = RerankerService.get_reranker_instance()
@@ -46,7 +47,9 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
         # rerank_query = "\n".join(sub_section_questions)
         # rerank_query = configuration.user_query
-        rerank_query = configuration.user_query + "\n" + "\n".join(sub_section_questions)
+        rerank_query = (
+            configuration.user_query + "\n" + "\n".join(sub_section_questions)
+        )

         # Convert documents to format expected by reranker if needed
         reranker_input_docs = [
@@ -57,28 +60,34 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
                 "document": {
                     "id": doc.get("document", {}).get("id", ""),
                     "title": doc.get("document", {}).get("title", ""),
-                    "document_type": doc.get("document", {}).get("document_type", ""),
-                    "metadata": doc.get("document", {}).get("metadata", {})
-                }
-            } for i, doc in enumerate(documents)
+                    "document_type": doc.get("document", {}).get(
+                        "document_type", ""
+                    ),
+                    "metadata": doc.get("document", {}).get("metadata", {}),
+                },
+            }
+            for i, doc in enumerate(documents)
         ]

         # Rerank documents using the section title
-        reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
+        reranked_docs = reranker_service.rerank_documents(
+            rerank_query, reranker_input_docs
+        )

         # Sort by score in descending order
         reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

-        print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}")
+        print(
+            f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}"
+        )
     except Exception as e:
-        print(f"Error during reranking: {str(e)}")
+        print(f"Error during reranking: {e!s}")
         # Use original docs if reranking fails

-    return {
-        "reranked_documents": reranked_docs
-    }
+    return {"reranked_documents": reranked_docs}

-async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, Any]:
     """
     Write the sub-section using the provided documents.
@@ -118,7 +127,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     section_position_context_map = {
         SubSectionType.START: "This is the INTRODUCTION section.",
         SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.",
-        SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure."
+        SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure.",
     }

     section_position_context = section_position_context_map.get(sub_section_type, "")
@@ -152,14 +161,15 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     # Use initial system prompt for token calculation
     initial_system_prompt = get_citation_system_prompt()
-    base_messages = state.chat_history + [
+    base_messages = [
+        *state.chat_history,
         SystemMessage(content=initial_system_prompt),
-        HumanMessage(content=base_human_message_template)
+        HumanMessage(content=base_human_message_template),
     ]

     # Optimize documents to fit within token limits
-    optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
-        documents, base_messages, llm.model
+    optimized_documents, has_optimized_documents = (
+        optimize_documents_for_token_limit(documents, base_messages, llm.model)
     )

     # Update state based on optimization result
@@ -169,16 +179,22 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
         has_documents = False

     # Choose system prompt based on final document availability
-    system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt()
+    system_prompt = (
+        get_citation_system_prompt()
+        if has_documents
+        else get_no_documents_system_prompt()
+    )

     # Generate documents section
-    documents_text = format_documents_section(documents, "Source material") if has_documents else ""
+    documents_text = (
+        format_documents_section(documents, "Source material") if has_documents else ""
+    )

     # Create final human message content
     instruction_text = (
         "Please write content for this sub-section using the provided source material and cite all information appropriately."
-        if has_documents else
-        "Please write content for this sub-section based on our conversation history and your general knowledge."
+        if has_documents
+        else "Please write content for this sub-section based on our conversation history and your general knowledge."
     )

     human_message_content = f"""
@@ -206,9 +222,10 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     """

     # Create final messages for the LLM
-    messages_with_chat_history = state.chat_history + [
+    messages_with_chat_history = [
+        *state.chat_history,
         SystemMessage(content=system_prompt),
-        HumanMessage(content=human_message_content)
+        HumanMessage(content=human_message_content),
     ]

     # Log final token count
@@ -219,7 +236,4 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     response = await llm.ainvoke(messages_with_chat_history)
     final_answer = response.content

-    return {
-        "final_answer": final_answer
-    }
+    return {"final_answer": final_answer}
View file
@@ -3,9 +3,11 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import List, Optional, Any
+from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession

 @dataclass
 class State:
     """Defines the dynamic state for the agent during execution.
@@ -14,11 +16,11 @@ class State:
     See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
     for more information.
     """

     # Runtime context
     db_session: AsyncSession
-    chat_history: Optional[List[Any]] = field(default_factory=list)
+    chat_history: list[Any] | None = field(default_factory=list)

     # OUTPUT: Populated by agent nodes
-    reranked_documents: Optional[List[Any]] = None
-    final_answer: Optional[str] = None
+    reranked_documents: list[Any] | None = None
+    final_answer: str | None = None
View file
@@ -1,23 +1,33 @@
-from typing import List, Dict, Any, Tuple, NamedTuple
+from typing import Any, NamedTuple

 from langchain_core.messages import BaseMessage
+from litellm import get_model_info, token_counter
 from pydantic import BaseModel, Field
-from litellm import token_counter, get_model_info

 class Section(BaseModel):
     """A section in the answer outline."""

     section_id: int = Field(..., description="The zero-based index of the section")
     section_title: str = Field(..., description="The title of the section")
-    questions: List[str] = Field(..., description="Questions to research for this section")
+    questions: list[str] = Field(
+        ..., description="Questions to research for this section"
+    )

 class AnswerOutline(BaseModel):
     """The complete answer outline with all sections."""

-    answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
+    answer_outline: list[Section] = Field(
+        ..., description="List of sections in the answer outline"
+    )

 class DocumentTokenInfo(NamedTuple):
     """Information about a document and its token cost."""

     index: int
-    document: Dict[str, Any]
+    document: dict[str, Any]
     formatted_content: str
     token_count: int
@@ -36,7 +46,7 @@ def get_connector_emoji(connector_name: str) -> str:
         "JIRA_CONNECTOR": "🎫",
         "DISCORD_CONNECTOR": "🗨️",
         "TAVILY_API": "🔍",
-        "LINKUP_API": "🔗"
+        "LINKUP_API": "🔗",
     }
     return connector_emojis.get(connector_name, "🔎")
@@ -55,31 +65,26 @@ def get_connector_friendly_name(connector_name: str) -> str:
         "JIRA_CONNECTOR": "Jira",
         "DISCORD_CONNECTOR": "Discord",
         "TAVILY_API": "Tavily Search",
-        "LINKUP_API": "Linkup Search"
+        "LINKUP_API": "Linkup Search",
     }
     return connector_friendly_names.get(connector_name, connector_name)

-def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
+def convert_langchain_messages_to_dict(
+    messages: list[BaseMessage],
+) -> list[dict[str, str]]:
     """Convert LangChain messages to format expected by token_counter."""
-    role_mapping = {
-        'system': 'system',
-        'human': 'user',
-        'ai': 'assistant'
-    }
+    role_mapping = {"system": "system", "human": "user", "ai": "assistant"}

     converted_messages = []
     for msg in messages:
-        role = role_mapping.get(getattr(msg, 'type', None), 'user')
-        converted_messages.append({
-            "role": role,
-            "content": str(msg.content)
-        })
+        role = role_mapping.get(getattr(msg, "type", None), "user")
+        converted_messages.append({"role": role, "content": str(msg.content)})

     return converted_messages
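Since `token_counter` comes from litellm and expects OpenAI-style role/content dicts, the converter above bridges LangChain messages into that shape. A hedged usage sketch (the model name is illustrative):

from langchain_core.messages import HumanMessage, SystemMessage
from litellm import token_counter

msgs = [SystemMessage(content="You are helpful."), HumanMessage(content="Hi!")]
n = token_counter(messages=convert_langchain_messages_to_dict(msgs), model="gpt-4o-mini")
print(n)  # approximate prompt token count for the chosen model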
-def format_document_for_citation(document: Dict[str, Any]) -> str:
+def format_document_for_citation(document: dict[str, Any]) -> str:
     """Format a single document for citation in the standard XML format."""
     content = document.get("content", "")
     doc_info = document.get("document", {})
@@ -97,7 +102,9 @@ def format_document_for_citation(document: dict[str, Any]) -> str:
     </document>"""

-def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
+def format_documents_section(
+    documents: list[dict[str, Any]], section_title: str = "Source material"
+) -> str:
     """Format multiple documents into a complete documents section."""
     if not documents:
         return ""
@@ -110,7 +117,9 @@ def format_documents_section(
     </documents>"""

-def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
+def calculate_document_token_costs(
+    documents: list[dict[str, Any]], model: str
+) -> list[DocumentTokenInfo]:
     """Pre-calculate token costs for each document."""
     document_token_info = []
@@ -119,24 +128,24 @@ def calculate_document_token_costs(
         # Calculate token count for this document
         token_count = token_counter(
-            messages=[{"role": "user", "content": formatted_doc}],
-            model=model
+            messages=[{"role": "user", "content": formatted_doc}], model=model
         )

-        document_token_info.append(DocumentTokenInfo(
-            index=i,
-            document=doc,
-            formatted_content=formatted_doc,
-            token_count=token_count
-        ))
+        document_token_info.append(
+            DocumentTokenInfo(
+                index=i,
+                document=doc,
+                formatted_content=formatted_doc,
+                token_count=token_count,
+            )
+        )

     return document_token_info

 def find_optimal_documents_with_binary_search(
-    document_tokens: List[DocumentTokenInfo],
-    available_tokens: int
-) -> List[DocumentTokenInfo]:
+    document_tokens: list[DocumentTokenInfo], available_tokens: int
+) -> list[DocumentTokenInfo]:
     """Use binary search to find the maximum number of documents that fit within token limit."""
     if not document_tokens or available_tokens <= 0:
         return []
@@ -147,8 +156,7 @@ def find_optimal_documents_with_binary_search(
     while left <= right:
         mid = (left + right) // 2
         current_docs = document_tokens[:mid]
-        current_token_sum = sum(
-            doc_info.token_count for doc_info in current_docs)
+        current_token_sum = sum(doc_info.token_count for doc_info in current_docs)

         if current_token_sum <= available_tokens:
             optimal_docs = current_docs
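Because documents are taken as a prefix of the already-ranked list, prefix token sums grow monotonically, which is what makes binary search on the prefix length valid here. A small self-contained check of that idea, with made-up token counts:

# Hypothetical token counts for five ranked documents.
counts = [120, 300, 80, 50, 400]
budget = 500

lo, hi, best = 0, len(counts), 0
while lo <= hi:
    mid = (lo + hi) // 2
    if sum(counts[:mid]) <= budget:  # prefix sums are monotone, so this is bisectable
        best, lo = mid, mid + 1
    else:
        hi = mid - 1
print(best)  # 3: the first three documents (500 tokens) fit the budget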
@@ -163,20 +171,18 @@ def get_model_context_window(model_name: str) -> int:
     """Get the total context window size for a model (input + output tokens)."""
     try:
         model_info = get_model_info(model_name)
-        context_window = model_info.get(
-            'max_input_tokens', 4096)  # Default fallback
+        context_window = model_info.get("max_input_tokens", 4096)  # Default fallback
         return context_window
     except Exception as e:
         print(
-            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
+            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
+        )
         return 4096  # Conservative fallback

 def optimize_documents_for_token_limit(
-    documents: List[Dict[str, Any]],
-    base_messages: List[BaseMessage],
-    model_name: str
-) -> Tuple[List[Dict[str, Any]], bool]:
+    documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
+) -> tuple[list[dict[str, Any]], bool]:
     """
     Optimize documents to fit within token limits using binary search.
@@ -201,7 +207,8 @@ def optimize_documents_for_token_limit(
     available_tokens_for_docs = context_window - base_tokens

     print(
-        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
+        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
+    )

     if available_tokens_for_docs <= 0:
         print("No tokens available for documents after base content and output buffer")
@@ -212,8 +219,7 @@ def optimize_documents_for_token_limit(
     # Find optimal number of documents using binary search
     optimal_doc_info = find_optimal_documents_with_binary_search(
-        document_token_info,
-        available_tokens_for_docs
+        document_token_info, available_tokens_for_docs
     )

     # Extract the original document objects
@@ -221,12 +227,13 @@ def optimize_documents_for_token_limit(
     has_documents_remaining = len(optimized_documents) > 0

     print(
-        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
+        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
+    )

     return optimized_documents, has_documents_remaining
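The budget arithmetic is simply the model's context window minus the tokens already committed to the base messages. With illustrative numbers:

context_window = 8192          # e.g. from get_model_context_window(model_name)
base_tokens = 1200             # system + human messages and chat history
available_tokens_for_docs = context_window - base_tokens
print(available_tokens_for_docs)  # 6992 tokens left to spend on documents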
-def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
+def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
     """Calculate token count for a list of LangChain messages."""
     model = model_name
     messages_dict = convert_langchain_messages_to_dict(messages)
View file
@@ -2,22 +2,13 @@ from contextlib import asynccontextmanager
 from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from sqlalchemy.ext.asyncio import AsyncSession

-from app.db import User, create_db_and_tables, get_async_session
-from app.schemas import UserCreate, UserRead, UserUpdate
-from app.routes import router as crud_router
 from app.config import config
-
-from app.users import (
-    SECRET,
-    auth_backend,
-    fastapi_users,
-    current_active_user
-)
+from app.db import User, create_db_and_tables, get_async_session
+from app.routes import router as crud_router
+from app.schemas import UserCreate, UserRead, UserUpdate
+from app.users import SECRET, auth_backend, current_active_user, fastapi_users

 @asynccontextmanager
@@ -64,12 +55,10 @@ app.include_router(
 if config.AUTH_TYPE == "GOOGLE":
     from app.users import google_oauth_client

     app.include_router(
         fastapi_users.get_oauth_router(
-            google_oauth_client,
-            auth_backend,
-            SECRET,
-            is_verified_by_default=True
+            google_oauth_client, auth_backend, SECRET, is_verified_by_default=True
         ),
         prefix="/auth/google",
         tags=["auth"],
@@ -79,5 +68,8 @@ app.include_router(crud_router, prefix="/api/v1", tags=["crud"])

 @app.get("/verify-token")
-async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
+async def authenticated_route(
+    user: User = Depends(current_active_user),
+    session: AsyncSession = Depends(get_async_session),
+):
     return {"message": "Token is valid"}
View file
@@ -1,13 +1,11 @@
 import os
-from pathlib import Path
 import shutil
+from pathlib import Path

 from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
 from dotenv import load_dotenv
 from rerankers import Reranker

 # Get the base directory of the project
 BASE_DIR = Path(__file__).resolve().parent.parent.parent
@@ -25,30 +23,30 @@ def is_ffmpeg_installed():
     return shutil.which("ffmpeg") is not None

 class Config:
     # Check if ffmpeg is installed
     if not is_ffmpeg_installed():
         import static_ffmpeg

         # ffmpeg installed on first call to add_paths(), threadsafe.
         static_ffmpeg.add_paths()

         # check if ffmpeg is installed again
         if not is_ffmpeg_installed():
-            raise ValueError("FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster.")
+            raise ValueError(
+                "FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster."
+            )

     # Database
     DATABASE_URL = os.getenv("DATABASE_URL")

     NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")

     # AUTH: Google OAuth
     AUTH_TYPE = os.getenv("AUTH_TYPE")
     if AUTH_TYPE == "GOOGLE":
         GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
         GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")

     # LLM instances are now managed per-user through the LLMConfig system
     # Legacy environment variables removed in favor of user-specific configurations
@@ -56,10 +54,10 @@ class Config:
     EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
     embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
     chunker_instance = RecursiveChunker(
-        chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
+        chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
     )
     code_chunker_instance = CodeChunker(
-        chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
+        chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
     )

     # Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
@@ -97,17 +95,18 @@ class Config:
     STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
     STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")

     # Validation Checks
     # Check embedding dimension
-    if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
+    if (
+        hasattr(embedding_model_instance, "dimension")
+        and embedding_model_instance.dimension > 2000
+    ):
         raise ValueError(
             f"Embedding dimension for Model: {EMBEDDING_MODEL} "
             f"has {embedding_model_instance.dimension} dimensions, which "
             f"exceeds the maximum of 2000 allowed by PGVector."
         )

     @classmethod
     def get_settings(cls):
         """Get all settings as a dictionary."""
View file
@@ -1,26 +1,25 @@
 import os

 def _parse_bool(value):
     """Parse boolean value from string."""
     return value.lower() == "true" if value else False

 def _parse_int(value, var_name):
     """Parse integer value with error handling."""
     try:
         return int(value)
     except ValueError:
-        raise ValueError(f"Invalid integer value for {var_name}: {value}")
+        raise ValueError(f"Invalid integer value for {var_name}: {value}") from None
 def _parse_headers(value):
     """Parse headers from comma-separated string."""
     try:
-        return [
-            tuple(h.split(":", 1))
-            for h in value.split(",")
-            if ":" in h
-        ]
+        return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h]
     except Exception:
-        raise ValueError(f"Invalid headers format: {value}")
+        raise ValueError(f"Invalid headers format: {value}") from None
 def load_uvicorn_config(args=None):
@@ -28,16 +27,16 @@ def load_uvicorn_config(args=None):
     Load Uvicorn configuration from environment variables and CLI args.
     Returns a dict suitable for passing to uvicorn.Config.
     """
-    config_kwargs = dict(
-        app="app.app:app",
-        host=os.getenv("UVICORN_HOST", "0.0.0.0"),
-        port=int(os.getenv("UVICORN_PORT", 8000)),
-        log_level=os.getenv("UVICORN_LOG_LEVEL", "info"),
-        reload=args.reload if args else False,
-        reload_dirs=["app"] if (args and args.reload) else None,
-    )
+    config_kwargs = {
+        "app": "app.app:app",
+        "host": os.getenv("UVICORN_HOST", "0.0.0.0"),
+        "port": int(os.getenv("UVICORN_PORT", 8000)),
+        "log_level": os.getenv("UVICORN_LOG_LEVEL", "info"),
+        "reload": args.reload if args else False,
+        "reload_dirs": ["app"] if (args and args.reload) else None,
+    }

     # Configuration mapping for advanced options
     config_mapping = {
         "UVICORN_PROXY_HEADERS": ("proxy_headers", _parse_bool),
         "UVICORN_FORWARDED_ALLOW_IPS": ("forwarded_allow_ips", str),
@@ -51,15 +50,33 @@ def load_uvicorn_config(args=None):
         "UVICORN_LOG_CONFIG": ("log_config", str),
         "UVICORN_SERVER_HEADER": ("server_header", _parse_bool),
         "UVICORN_DATE_HEADER": ("date_header", _parse_bool),
-        "UVICORN_LIMIT_CONCURRENCY": ("limit_concurrency", lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY")),
-        "UVICORN_LIMIT_MAX_REQUESTS": ("limit_max_requests", lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS")),
-        "UVICORN_TIMEOUT_KEEP_ALIVE": ("timeout_keep_alive", lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE")),
-        "UVICORN_TIMEOUT_NOTIFY": ("timeout_notify", lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY")),
+        "UVICORN_LIMIT_CONCURRENCY": (
+            "limit_concurrency",
+            lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY"),
+        ),
+        "UVICORN_LIMIT_MAX_REQUESTS": (
+            "limit_max_requests",
+            lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS"),
+        ),
+        "UVICORN_TIMEOUT_KEEP_ALIVE": (
+            "timeout_keep_alive",
+            lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE"),
+        ),
+        "UVICORN_TIMEOUT_NOTIFY": (
+            "timeout_notify",
+            lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY"),
+        ),
         "UVICORN_SSL_KEYFILE": ("ssl_keyfile", str),
         "UVICORN_SSL_CERTFILE": ("ssl_certfile", str),
         "UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str),
-        "UVICORN_SSL_VERSION": ("ssl_version", lambda x: _parse_int(x, "UVICORN_SSL_VERSION")),
-        "UVICORN_SSL_CERT_REQS": ("ssl_cert_reqs", lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS")),
+        "UVICORN_SSL_VERSION": (
+            "ssl_version",
+            lambda x: _parse_int(x, "UVICORN_SSL_VERSION"),
+        ),
+        "UVICORN_SSL_CERT_REQS": (
+            "ssl_cert_reqs",
+            lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS"),
+        ),
         "UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str),
         "UVICORN_SSL_CIPHERS": ("ssl_ciphers", str),
         "UVICORN_HEADERS": ("headers", _parse_headers),
@@ -76,7 +93,6 @@ def load_uvicorn_config(args=None):
         try:
             config_kwargs[config_key] = parser(value)
         except ValueError as e:
-            raise ValueError(f"Configuration error for {env_var}: {e}")
+            raise ValueError(f"Configuration error for {env_var}: {e}") from e

     return config_kwargs
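Since the function returns plain kwargs, wiring it into a server entry point is one unpacking away. A hedged sketch, assuming this module is importable and using the `app.app:app` default shown above:

import uvicorn

kwargs = load_uvicorn_config()  # reads the UVICORN_* environment variables
server = uvicorn.Server(uvicorn.Config(**kwargs))
server.run()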
View file
@@ -6,11 +6,12 @@ A module for interacting with Discord's HTTP API to retrieve guilds, channels, a
 Requires a Discord bot token.
 """
 
+import asyncio
+import datetime
 import logging
+
 import discord
 from discord.ext import commands
-import datetime
-import asyncio
 
 logger = logging.getLogger(__name__)
@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
 class DiscordConnector(commands.Bot):
     """Class for retrieving guild, channel, and message history from Discord."""
 
-    def __init__(self, token: str = None):
+    def __init__(self, token: str | None = None):
         """
         Initialize the DiscordConnector with a bot token.
@@ -30,7 +31,9 @@ class DiscordConnector(commands.Bot):
         intents.messages = True  # Required to fetch messages
         intents.message_content = True  # Required to read message content
         intents.members = True  # Required to fetch member information
-        super().__init__(command_prefix="!", intents=intents)  # command_prefix is required but not strictly used here
+        super().__init__(
+            command_prefix="!", intents=intents
+        )  # command_prefix is required but not strictly used here
         self.token = token
         self._bot_task = None  # Holds the async bot task
         self._is_running = False  # Flag to track if the bot is running
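
Note: `token: str | None = None` is PEP 604 union syntax (Python 3.10+), equivalent to `Optional[str]` but without the `typing` import. A minimal illustration, not tied to this codebase:

    from typing import Optional

    def f(token: str | None = None) -> None: ...   # PEP 604 spelling (3.10+)
    def g(token: Optional[str] = None) -> None: ...  # equivalent older spelling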
@@ -48,7 +51,7 @@ class DiscordConnector(commands.Bot):
         @self.event
         async def on_disconnect():
             logger.debug("Bot disconnected from Discord gateway.")
-            self._is_running = False # Reset flag on disconnect
+            self._is_running = False  # Reset flag on disconnect
 
         @self.event
         async def on_resumed():
@@ -63,17 +66,23 @@ class DiscordConnector(commands.Bot):
         try:
             if self._is_running:
-                logger.warning("Bot is already running. Use close_bot() to stop it before starting again.")
+                logger.warning(
+                    "Bot is already running. Use close_bot() to stop it before starting again."
+                )
                 return
             await self.start(self.token)
             logger.info("Discord bot started successfully.")
         except discord.LoginFailure:
-            logger.error("Failed to log in: Invalid token was provided. Please check your bot token.")
+            logger.error(
+                "Failed to log in: Invalid token was provided. Please check your bot token."
+            )
             self._is_running = False
             raise
         except discord.PrivilegedIntentsRequired as e:
-            logger.error(f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page.")
+            logger.error(
+                f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page."
+            )
             self._is_running = False
             raise
         except discord.ConnectionClosed as e:
@@ -96,7 +105,6 @@ class DiscordConnector(commands.Bot):
         else:
             logger.info("Bot is not running or already disconnected.")
 
     def set_token(self, token: str) -> None:
         """
         Set the discord bot token.
@@ -106,7 +114,9 @@ class DiscordConnector(commands.Bot):
         """
         logger.info("Setting Discord bot token.")
         self.token = token
-        logger.info("Token set successfully. You can now start the bot with start_bot().")
+        logger.info(
+            "Token set successfully. You can now start the bot with start_bot()."
+        )
 
     async def _wait_until_ready(self):
         """Helper to wait until the bot is connected and ready."""
@@ -115,16 +125,20 @@ class DiscordConnector(commands.Bot):
         # Give the event loop a chance to switch to the bot's startup task.
         # This allows self.start() to begin initializing the client.
         # Terrible solution, but necessary to avoid blocking the event loop.
-        await asyncio.sleep(1) # Yield control to the event loop
+        await asyncio.sleep(1)  # Yield control to the event loop
         try:
             await asyncio.wait_for(self.wait_until_ready(), timeout=60.0)
             logger.info("Bot is ready.")
-        except asyncio.TimeoutError:
-            logger.error(f"Bot did not become ready within 60 seconds. Connection may have failed.")
+        except TimeoutError:
+            logger.error(
+                "Bot did not become ready within 60 seconds. Connection may have failed."
+            )
             raise
         except Exception as e:
-            logger.error(f"An unexpected error occurred while waiting for the bot to be ready: {e}")
+            logger.error(
+                f"An unexpected error occurred while waiting for the bot to be ready: {e}"
+            )
             raise
 
     async def get_guilds(self) -> list[dict]:
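
Note: catching `TimeoutError` instead of `asyncio.TimeoutError` relies on Python 3.11+, where `asyncio.TimeoutError` became an alias of the builtin. A small sketch of the behavior:

    import asyncio

    async def wait_ready(event: asyncio.Event) -> None:
        try:
            await asyncio.wait_for(event.wait(), timeout=60.0)
        except TimeoutError:  # also catches asyncio.TimeoutError on 3.11+
            print("not ready within 60 seconds")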
@@ -143,7 +157,9 @@ class DiscordConnector(commands.Bot):
         guilds_data = []
         for guild in self.guilds:
-            member_count = guild.member_count if guild.member_count is not None else "N/A"
+            member_count = (
+                guild.member_count if guild.member_count is not None else "N/A"
+            )
             guilds_data.append(
                 {
                     "id": str(guild.id),
@@ -184,14 +200,16 @@ class DiscordConnector(commands.Bot):
                 {"id": str(channel.id), "name": channel.name, "type": "text"}
             )
 
-        logger.info(f"Fetched {len(channels_data)} text channels from guild {guild_id}.")
+        logger.info(
+            f"Fetched {len(channels_data)} text channels from guild {guild_id}."
+        )
         return channels_data
 
     async def get_channel_history(
         self,
         channel_id: str,
-        start_date: str = None,
-        end_date: str = None,
+        start_date: str | None = None,
+        end_date: str | None = None,
     ) -> list[dict]:
         """
         Fetch message history from a text channel.
@@ -227,20 +245,26 @@ class DiscordConnector(commands.Bot):
         if start_date:
             try:
-                start_datetime = datetime.datetime.fromisoformat(start_date).replace(tzinfo=datetime.timezone.utc)
+                start_datetime = datetime.datetime.fromisoformat(start_date).replace(
+                    tzinfo=datetime.UTC
+                )
                 after = start_datetime
             except ValueError:
                 logger.warning(f"Invalid start_date format: {start_date}. Ignoring.")
         if end_date:
             try:
-                end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(tzinfo=datetime.timezone.utc)
+                end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(
+                    tzinfo=datetime.UTC
+                )
                 before = end_datetime
             except ValueError:
                 logger.warning(f"Invalid end_date format: {end_date}. Ignoring.")
 
         try:
-            async for message in channel.history(limit=None, before=before, after=after):
+            async for message in channel.history(
+                limit=None, before=before, after=after
+            ):
                 messages_data.append(
                     {
                         "id": str(message.id),
@@ -251,7 +275,9 @@ class DiscordConnector(commands.Bot):
                     }
                 )
         except discord.Forbidden:
-            logger.error(f"Bot does not have permissions to read message history in channel {channel_id}.")
+            logger.error(
+                f"Bot does not have permissions to read message history in channel {channel_id}."
+            )
             raise
         except discord.HTTPException as e:
             logger.error(f"Failed to fetch messages from channel {channel_id}: {e}")
@@ -278,7 +304,9 @@ class DiscordConnector(commands.Bot):
             permissions to view members.
         """
         await self._wait_until_ready()
-        logger.info(f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}")
+        logger.info(
+            f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}"
+        )
 
         guild = self.get_guild(int(guild_id))
         if not guild:
@@ -294,7 +322,9 @@ class DiscordConnector(commands.Bot):
             return {
                 "id": str(member.id),
                 "name": member.name,
-                "joined_at": member.joined_at.isoformat() if member.joined_at else None,
+                "joined_at": member.joined_at.isoformat()
+                if member.joined_at
+                else None,
                 "roles": roles,
             }
         logger.warning(f"User {user_id} not found in guild {guild_id}.")
@@ -303,8 +333,12 @@ class DiscordConnector(commands.Bot):
             logger.warning(f"User {user_id} not found in guild {guild_id}.")
             return None
         except discord.Forbidden:
-            logger.error(f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled.")
+            logger.error(
+                f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled."
+            )
             raise
         except discord.HTTPException as e:
-            logger.error(f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}")
+            logger.error(
+                f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}"
+            )
             return None
View file
@@ -1,54 +1,91 @@
 import base64
 import logging
-from typing import List, Optional, Dict, Any
-from github3 import login as github_login, exceptions as github_exceptions
-from github3.repos.contents import Contents
+from typing import Any
+
+from github3 import exceptions as github_exceptions, login as github_login
 from github3.exceptions import ForbiddenError, NotFoundError
+from github3.repos.contents import Contents
 
 logger = logging.getLogger(__name__)
 
 # List of common code file extensions to target
 CODE_EXTENSIONS = {
-    '.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
-    '.cs', '.go', '.rb', '.php', '.swift', '.kt', '.scala', '.rs', '.m',
-    '.sh', '.bash', '.ps1', '.lua', '.pl', '.pm', '.r', '.dart', '.sql'
+    ".py",
+    ".js",
+    ".jsx",
+    ".ts",
+    ".tsx",
+    ".java",
+    ".c",
+    ".cpp",
+    ".h",
+    ".hpp",
+    ".cs",
+    ".go",
+    ".rb",
+    ".php",
+    ".swift",
+    ".kt",
+    ".scala",
+    ".rs",
+    ".m",
+    ".sh",
+    ".bash",
+    ".ps1",
+    ".lua",
+    ".pl",
+    ".pm",
+    ".r",
+    ".dart",
+    ".sql",
 }
 
 # List of common documentation/text file extensions
 DOC_EXTENSIONS = {
-    '.md', '.txt', '.rst', '.adoc', '.html', '.htm', '.xml', '.json', '.yaml', '.yml', '.toml'
+    ".md",
+    ".txt",
+    ".rst",
+    ".adoc",
+    ".html",
+    ".htm",
+    ".xml",
+    ".json",
+    ".yaml",
+    ".yml",
+    ".toml",
 }
 
 # Maximum file size in bytes (e.g., 1MB)
 MAX_FILE_SIZE = 1 * 1024 * 1024
 
 
 class GitHubConnector:
     """Connector for interacting with the GitHub API."""
 
     # Directories to skip during file traversal
     SKIPPED_DIRS = {
         # Version control
-        '.git',
+        ".git",
         # Dependencies
-        'node_modules',
-        'vendor',
+        "node_modules",
+        "vendor",
         # Build artifacts / Caches
-        'build',
-        'dist',
-        'target',
-        '__pycache__',
+        "build",
+        "dist",
+        "target",
+        "__pycache__",
         # Virtual environments
-        'venv',
-        '.venv',
-        'env',
+        "venv",
+        ".venv",
+        "env",
         # IDE/Editor config
-        '.vscode',
-        '.idea',
-        '.project',
-        '.settings',
+        ".vscode",
+        ".idea",
+        ".project",
+        ".settings",
         # Temporary / Logs
-        'tmp',
-        'logs',
+        "tmp",
+        "logs",
         # Add other project-specific irrelevant directories if needed
     }
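
Note: dropping `List`, `Dict`, and `Optional` from the `typing` import works because the annotations in this file use PEP 585 builtin generics (`list[dict[str, Any]]`, Python 3.9+) and PEP 604 unions (`str | None`, Python 3.10+). A sketch mirroring the signatures below:

    from typing import Any

    def get_user_repositories() -> list[dict[str, Any]]:  # was List[Dict[str, Any]]
        return []

    def get_file_content(path: str) -> str | None:  # was Optional[str]
        return None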
@@ -68,35 +105,39 @@ class GitHubConnector:
             logger.info("Successfully authenticated with GitHub API.")
         except (github_exceptions.AuthenticationFailed, ForbiddenError) as e:
             logger.error(f"GitHub authentication failed: {e}")
-            raise ValueError("Invalid GitHub token or insufficient permissions.")
+            raise ValueError("Invalid GitHub token or insufficient permissions.") from e
         except Exception as e:
             logger.error(f"Failed to initialize GitHub client: {e}")
-            raise
+            raise e
 
-    def get_user_repositories(self) -> List[Dict[str, Any]]:
+    def get_user_repositories(self) -> list[dict[str, Any]]:
         """Fetches repositories accessible by the authenticated user."""
         repos_data = []
         try:
             # type='owner' fetches repos owned by the user
             # type='member' fetches repos the user is a collaborator on (including orgs)
             # type='all' fetches both
-            for repo in self.gh.repositories(type='all', sort='updated'):
-                repos_data.append({
-                    "id": repo.id,
-                    "name": repo.name,
-                    "full_name": repo.full_name,
-                    "private": repo.private,
-                    "url": repo.html_url,
-                    "description": repo.description or "",
-                    "last_updated": repo.updated_at if repo.updated_at else None,
-                })
+            for repo in self.gh.repositories(type="all", sort="updated"):
+                repos_data.append(
+                    {
+                        "id": repo.id,
+                        "name": repo.name,
+                        "full_name": repo.full_name,
+                        "private": repo.private,
+                        "url": repo.html_url,
+                        "description": repo.description or "",
+                        "last_updated": repo.updated_at if repo.updated_at else None,
+                    }
+                )
             logger.info(f"Fetched {len(repos_data)} repositories.")
             return repos_data
         except Exception as e:
             logger.error(f"Failed to fetch GitHub repositories: {e}")
-            return [] # Return empty list on error
+            return []  # Return empty list on error
 
-    def get_repository_files(self, repo_full_name: str, path: str = '') -> List[Dict[str, Any]]:
+    def get_repository_files(
+        self, repo_full_name: str, path: str = ""
+    ) -> list[dict[str, Any]]:
         """
         Recursively fetches details of relevant files (code, docs) within a repository path.
@@ -110,54 +151,72 @@ class GitHubConnector:
         """
         files_list = []
         try:
-            owner, repo_name = repo_full_name.split('/')
+            owner, repo_name = repo_full_name.split("/")
             repo = self.gh.repository(owner, repo_name)
             if not repo:
                 logger.warning(f"Repository '{repo_full_name}' not found.")
                 return []
 
-            contents = repo.directory_contents(directory_path=path) # Use directory_contents for clarity
+            contents = repo.directory_contents(
+                directory_path=path
+            )  # Use directory_contents for clarity
 
             # contents returns a list of tuples (name, content_obj)
-            for item_name, content_item in contents:
+            for _item_name, content_item in contents:
                 if not isinstance(content_item, Contents):
                     continue
 
-                if content_item.type == 'dir':
+                if content_item.type == "dir":
                     # Check if the directory name is in the skipped list
                     if content_item.name in self.SKIPPED_DIRS:
                         logger.debug(f"Skipping directory: {content_item.path}")
-                        continue # Skip recursion for this directory
+                        continue  # Skip recursion for this directory
                     # Recursively fetch contents of subdirectory
-                    files_list.extend(self.get_repository_files(repo_full_name, path=content_item.path))
-                elif content_item.type == 'file':
+                    files_list.extend(
+                        self.get_repository_files(
+                            repo_full_name, path=content_item.path
+                        )
+                    )
+                elif content_item.type == "file":
                     # Check if the file extension is relevant and size is within limits
-                    file_extension = '.' + content_item.name.split('.')[-1].lower() if '.' in content_item.name else ''
+                    file_extension = (
+                        "." + content_item.name.split(".")[-1].lower()
+                        if "." in content_item.name
+                        else ""
+                    )
                     is_code = file_extension in CODE_EXTENSIONS
                     is_doc = file_extension in DOC_EXTENSIONS
 
                     if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE:
-                        files_list.append({
-                            "path": content_item.path,
-                            "sha": content_item.sha,
-                            "url": content_item.html_url,
-                            "size": content_item.size,
-                            "type": "code" if is_code else "doc"
-                        })
+                        files_list.append(
+                            {
+                                "path": content_item.path,
+                                "sha": content_item.sha,
+                                "url": content_item.html_url,
+                                "size": content_item.size,
+                                "type": "code" if is_code else "doc",
+                            }
+                        )
                     elif content_item.size > MAX_FILE_SIZE:
-                        logger.debug(f"Skipping large file: {content_item.path} ({content_item.size} bytes)")
+                        logger.debug(
+                            f"Skipping large file: {content_item.path} ({content_item.size} bytes)"
+                        )
                     else:
-                        logger.debug(f"Skipping irrelevant file type: {content_item.path}")
+                        logger.debug(
+                            f"Skipping irrelevant file type: {content_item.path}"
+                        )
 
         except (NotFoundError, ForbiddenError) as e:
             logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
         except Exception as e:
-            logger.error(f"Failed to get files for {repo_full_name} at path '{path}': {e}")
+            logger.error(
+                f"Failed to get files for {repo_full_name} at path '{path}': {e}"
+            )
             # Return what we have collected so far in case of partial failure
 
         return files_list
 
-    def get_file_content(self, repo_full_name: str, file_path: str) -> Optional[str]:
+    def get_file_content(self, repo_full_name: str, file_path: str) -> str | None:
         """
         Fetches the decoded content of a specific file.
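
Note: a condensed sketch of the extension-and-size filter used in `get_repository_files` above; `classify` is a hypothetical helper, not part of the diff, and the sets are abbreviated:

    CODE_EXTENSIONS = {".py", ".js"}
    DOC_EXTENSIONS = {".md", ".txt"}
    MAX_FILE_SIZE = 1 * 1024 * 1024

    def classify(name: str, size: int) -> str | None:
        ext = "." + name.split(".")[-1].lower() if "." in name else ""
        if size > MAX_FILE_SIZE:
            return None  # skipped as too large
        if ext in CODE_EXTENSIONS:
            return "code"
        if ext in DOC_EXTENSIONS:
            return "doc"
        return None  # irrelevant file type

    assert classify("main.py", 120) == "code"
    assert classify("README.md", 80) == "doc"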
@@ -169,43 +228,69 @@ class GitHubConnector:
             The decoded file content as a string, or None if fetching fails or file is too large.
         """
         try:
-            owner, repo_name = repo_full_name.split('/')
+            owner, repo_name = repo_full_name.split("/")
             repo = self.gh.repository(owner, repo_name)
             if not repo:
-                logger.warning(f"Repository '{repo_full_name}' not found when fetching file '{file_path}'.")
+                logger.warning(
+                    f"Repository '{repo_full_name}' not found when fetching file '{file_path}'."
+                )
                 return None
 
-            content_item = repo.file_contents(path=file_path) # Use file_contents for clarity
+            content_item = repo.file_contents(
+                path=file_path
+            )  # Use file_contents for clarity
 
-            if not content_item or not isinstance(content_item, Contents) or content_item.type != 'file':
-                logger.warning(f"File '{file_path}' not found or is not a file in '{repo_full_name}'.")
+            if (
+                not content_item
+                or not isinstance(content_item, Contents)
+                or content_item.type != "file"
+            ):
+                logger.warning(
+                    f"File '{file_path}' not found or is not a file in '{repo_full_name}'."
+                )
                 return None
 
             if content_item.size > MAX_FILE_SIZE:
-                logger.warning(f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch.")
+                logger.warning(
+                    f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch."
+                )
                 return None
 
             # Content is base64 encoded
             if content_item.content:
                 try:
-                    decoded_content = base64.b64decode(content_item.content).decode('utf-8')
+                    decoded_content = base64.b64decode(content_item.content).decode(
+                        "utf-8"
+                    )
                     return decoded_content
                 except UnicodeDecodeError:
-                    logger.warning(f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'.")
+                    logger.warning(
+                        f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'."
+                    )
                     try:
                         # Try a fallback encoding
-                        decoded_content = base64.b64decode(content_item.content).decode('latin-1')
+                        decoded_content = base64.b64decode(content_item.content).decode(
+                            "latin-1"
+                        )
                         return decoded_content
                     except Exception as decode_err:
-                        logger.error(f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}")
-                        return None # Give up if fallback fails
+                        logger.error(
+                            f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}"
+                        )
+                        return None  # Give up if fallback fails
             else:
-                logger.warning(f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty.")
-                return "" # Return empty string for empty files
+                logger.warning(
+                    f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty."
+                )
+                return ""  # Return empty string for empty files
 
         except (NotFoundError, ForbiddenError) as e:
-            logger.warning(f"Cannot access file '{file_path}' in '{repo_full_name}': {e}")
+            logger.warning(
+                f"Cannot access file '{file_path}' in '{repo_full_name}': {e}"
+            )
             return None
         except Exception as e:
-            logger.error(f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}")
+            logger.error(
+                f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}"
+            )
             return None
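
Note: a standalone sketch of the decode-with-fallback pattern in `get_file_content`; `decode_b64` is a hypothetical helper. Since latin-1 maps every byte value, the fallback itself cannot raise a decode error:

    import base64

    def decode_b64(content: str) -> str | None:
        raw = base64.b64decode(content)
        try:
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            return raw.decode("latin-1")  # total mapping; never fails on bytes

    assert decode_b64(base64.b64encode("héllo".encode()).decode()) == "héllo"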
View file
@@ -7,7 +7,7 @@ Allows fetching issue lists and their comments, projects and more.
 
 import base64
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import requests
@@ -17,9 +17,9 @@ class JiraConnector:
     def __init__(
         self,
-        base_url: Optional[str] = None,
-        email: Optional[str] = None,
-        api_token: Optional[str] = None,
+        base_url: str | None = None,
+        email: str | None = None,
+        api_token: str | None = None,
     ):
         """
         Initialize the JiraConnector class.
@@ -65,7 +65,7 @@ class JiraConnector:
         """
         self.api_token = api_token
 
-    def get_headers(self) -> Dict[str, str]:
+    def get_headers(self) -> dict[str, str]:
         """
         Get headers for Jira API requests using Basic Authentication.
@@ -92,8 +92,8 @@ class JiraConnector:
         }
 
     def make_api_request(
-        self, endpoint: str, params: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, Any]:
+        self, endpoint: str, params: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
         """
         Make a request to the Jira API.
@@ -138,7 +138,7 @@ class JiraConnector:
         """
         return self.make_api_request("project/search")
 
-    def get_all_issues(self, project_key: Optional[str] = None) -> List[Dict[str, Any]]:
+    def get_all_issues(self, project_key: str | None = None) -> list[dict[str, Any]]:
         """
         Fetch all issues from Jira.
@@ -204,8 +204,8 @@ class JiraConnector:
         start_date: str,
         end_date: str,
         include_comments: bool = True,
-        project_key: Optional[str] = None,
-    ) -> tuple[List[Dict[str, Any]], Optional[str]]:
+        project_key: str | None = None,
+    ) -> tuple[list[dict[str, Any]], str | None]:
         """
         Fetch issues within a date range.
@@ -226,9 +226,9 @@ class JiraConnector:
         )
 
         # TODO : This JQL needs some improvement to work as expected
-        jql = f"{date_filter}"
+        _jql = f"{date_filter}"
         if project_key:
-            jql = (
+            _jql = (
                 f'project = "{project_key}" AND {date_filter} ORDER BY created DESC'
             )
@@ -250,7 +250,7 @@ class JiraConnector:
             fields.append("comment")
 
         params = {
             # "jql": "", TODO : Add a JQL query to filter from a date range
             "fields": ",".join(fields),
             "maxResults": 100,
             "startAt": 0,
@@ -283,9 +283,9 @@ class JiraConnector:
 
             return all_issues, None
         except Exception as e:
-            return [], f"Error fetching issues: {str(e)}"
+            return [], f"Error fetching issues: {e!s}"
 
-    def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
+    def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
         """
         Format an issue for easier consumption.
@@ -401,7 +401,7 @@ class JiraConnector:
         return formatted
 
-    def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
+    def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
         """
         Convert an issue to markdown format.
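
Note: `f"{e!s}"` applies the `!s` conversion (an explicit `str()` call) inside the f-string; for exceptions it produces the same text as `str(e)`:

    try:
        raise ValueError("boom")
    except Exception as e:
        assert f"Error fetching issues: {e!s}" == "Error fetching issues: boom"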
View file
@@ -5,15 +5,16 @@ A module for retrieving issues and comments from Linear.
 Allows fetching issue lists and their comments with date range filtering.
 """
 
-import requests
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Any, Union
+from typing import Any
+
+import requests
 
 
 class LinearConnector:
     """Class for retrieving issues and comments from Linear."""
 
-    def __init__(self, token: str = None):
+    def __init__(self, token: str | None = None):
         """
         Initialize the LinearConnector class.
@@ -32,7 +33,7 @@ class LinearConnector:
         """
         self.token = token
 
-    def get_headers(self) -> Dict[str, str]:
+    def get_headers(self) -> dict[str, str]:
         """
         Get headers for Linear API requests.
@@ -45,12 +46,11 @@ class LinearConnector:
         if not self.token:
             raise ValueError("Linear token not initialized. Call set_token() first.")
 
-        return {
-            'Content-Type': 'application/json',
-            'Authorization': self.token
-        }
+        return {"Content-Type": "application/json", "Authorization": self.token}
 
-    def execute_graphql_query(self, query: str, variables: Dict[str, Any] = None) -> Dict[str, Any]:
+    def execute_graphql_query(
+        self, query: str, variables: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
         """
         Execute a GraphQL query against the Linear API.
@@ -69,23 +69,21 @@ class LinearConnector:
             raise ValueError("Linear token not initialized. Call set_token() first.")
 
         headers = self.get_headers()
-        payload = {'query': query}
+        payload = {"query": query}
 
         if variables:
-            payload['variables'] = variables
+            payload["variables"] = variables
 
-        response = requests.post(
-            self.api_url,
-            headers=headers,
-            json=payload
-        )
+        response = requests.post(self.api_url, headers=headers, json=payload)
 
         if response.status_code == 200:
             return response.json()
         else:
-            raise Exception(f"Query failed with status code {response.status_code}: {response.text}")
+            raise Exception(
+                f"Query failed with status code {response.status_code}: {response.text}"
+            )
 
-    def get_all_issues(self, include_comments: bool = True) -> List[Dict[str, Any]]:
+    def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
         """
         Fetch all issues from Linear.
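
Note: the collapsed `requests.post(...)` call above is the standard GraphQL-over-HTTP pattern. A minimal sketch with a hypothetical query; the endpoint mirrors the `api_url` this class is assumed to define, and `<token>` stands in for a real credential:

    import requests

    api_url = "https://api.linear.app/graphql"
    headers = {"Content-Type": "application/json", "Authorization": "<token>"}
    payload = {"query": "query { viewer { id } }", "variables": {}}

    response = requests.post(api_url, headers=headers, json=payload)
    if response.status_code != 200:
        raise Exception(
            f"Query failed with status code {response.status_code}: {response.text}"
        )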
@@ -151,17 +149,18 @@ class LinearConnector:
         result = self.execute_graphql_query(query)
 
         # Extract issues from the response
-        if "data" in result and "issues" in result["data"] and "nodes" in result["data"]["issues"]:
+        if (
+            "data" in result
+            and "issues" in result["data"]
+            and "nodes" in result["data"]["issues"]
+        ):
             return result["data"]["issues"]["nodes"]
 
         return []
 
     def get_issues_by_date_range(
-        self,
-        start_date: str,
-        end_date: str,
-        include_comments: bool = True
-    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+        self, start_date: str, end_date: str, include_comments: bool = True
+    ) -> tuple[list[dict[str, Any]], str | None]:
         """
         Fetch issues within a date range.
@@ -263,7 +262,12 @@ class LinearConnector:
 
                 # Check for errors
                 if "errors" in result:
-                    error_message = "; ".join([error.get("message", "Unknown error") for error in result["errors"]])
+                    error_message = "; ".join(
+                        [
+                            error.get("message", "Unknown error")
+                            for error in result["errors"]
+                        ]
+                    )
                     return [], f"GraphQL errors: {error_message}"
 
                 # Extract issues from the response
@@ -278,7 +282,9 @@ class LinearConnector:
                     if "pageInfo" in issues_page:
                         page_info = issues_page["pageInfo"]
                         has_next_page = page_info.get("hasNextPage", False)
-                        cursor = page_info.get("endCursor") if has_next_page else None
+                        cursor = (
+                            page_info.get("endCursor") if has_next_page else None
+                        )
                     else:
                         has_next_page = False
                 else:
@@ -290,12 +296,12 @@ class LinearConnector:
 
             return all_issues, None
         except Exception as e:
-            return [], f"Error fetching issues: {str(e)}"
+            return [], f"Error fetching issues: {e!s}"
 
         except ValueError as e:
-            return [], f"Invalid date format: {str(e)}. Please use YYYY-MM-DD."
+            return [], f"Invalid date format: {e!s}. Please use YYYY-MM-DD."
 
-    def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
+    def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
         """
         Format an issue for easier consumption.
@@ -311,21 +317,35 @@ class LinearConnector:
             "identifier": issue.get("identifier", ""),
             "title": issue.get("title", ""),
             "description": issue.get("description", ""),
-            "state": issue.get("state", {}).get("name", "Unknown") if issue.get("state") else "Unknown",
-            "state_type": issue.get("state", {}).get("type", "Unknown") if issue.get("state") else "Unknown",
+            "state": issue.get("state", {}).get("name", "Unknown")
+            if issue.get("state")
+            else "Unknown",
+            "state_type": issue.get("state", {}).get("type", "Unknown")
+            if issue.get("state")
+            else "Unknown",
             "created_at": issue.get("createdAt", ""),
             "updated_at": issue.get("updatedAt", ""),
             "creator": {
-                "id": issue.get("creator", {}).get("id", "") if issue.get("creator") else "",
-                "name": issue.get("creator", {}).get("name", "Unknown") if issue.get("creator") else "Unknown",
-                "email": issue.get("creator", {}).get("email", "") if issue.get("creator") else ""
-            } if issue.get("creator") else {"id": "", "name": "Unknown", "email": ""},
+                "id": issue.get("creator", {}).get("id", "")
+                if issue.get("creator")
+                else "",
+                "name": issue.get("creator", {}).get("name", "Unknown")
+                if issue.get("creator")
+                else "Unknown",
+                "email": issue.get("creator", {}).get("email", "")
+                if issue.get("creator")
+                else "",
+            }
+            if issue.get("creator")
+            else {"id": "", "name": "Unknown", "email": ""},
             "assignee": {
                 "id": issue.get("assignee", {}).get("id", ""),
                 "name": issue.get("assignee", {}).get("name", "Unknown"),
-                "email": issue.get("assignee", {}).get("email", "")
-            } if issue.get("assignee") else None,
-            "comments": []
+                "email": issue.get("assignee", {}).get("email", ""),
+            }
+            if issue.get("assignee")
+            else None,
+            "comments": [],
         }
 
         # Extract comments if available
@@ -337,16 +357,24 @@ class LinearConnector:
                     "created_at": comment.get("createdAt", ""),
                     "updated_at": comment.get("updatedAt", ""),
                     "user": {
-                        "id": comment.get("user", {}).get("id", "") if comment.get("user") else "",
-                        "name": comment.get("user", {}).get("name", "Unknown") if comment.get("user") else "Unknown",
-                        "email": comment.get("user", {}).get("email", "") if comment.get("user") else ""
-                    } if comment.get("user") else {"id": "", "name": "Unknown", "email": ""}
+                        "id": comment.get("user", {}).get("id", "")
+                        if comment.get("user")
+                        else "",
+                        "name": comment.get("user", {}).get("name", "Unknown")
+                        if comment.get("user")
+                        else "Unknown",
+                        "email": comment.get("user", {}).get("email", "")
+                        if comment.get("user")
+                        else "",
+                    }
+                    if comment.get("user")
+                    else {"id": "", "name": "Unknown", "email": ""},
                 }
                 formatted["comments"].append(formatted_comment)
 
         return formatted
 
-    def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
+    def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
         """
         Convert an issue to markdown format.
@@ -363,37 +391,37 @@ class LinearConnector:
         # Build the markdown content
         markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n"
 
-        if issue.get('state'):
+        if issue.get("state"):
             markdown += f"**Status:** {issue['state']}\n\n"
 
-        if issue.get('assignee') and issue['assignee'].get('name'):
+        if issue.get("assignee") and issue["assignee"].get("name"):
             markdown += f"**Assignee:** {issue['assignee']['name']}\n"
 
-        if issue.get('creator') and issue['creator'].get('name'):
+        if issue.get("creator") and issue["creator"].get("name"):
             markdown += f"**Created by:** {issue['creator']['name']}\n"
 
-        if issue.get('created_at'):
-            created_date = self.format_date(issue['created_at'])
+        if issue.get("created_at"):
+            created_date = self.format_date(issue["created_at"])
             markdown += f"**Created:** {created_date}\n"
 
-        if issue.get('updated_at'):
-            updated_date = self.format_date(issue['updated_at'])
+        if issue.get("updated_at"):
+            updated_date = self.format_date(issue["updated_at"])
             markdown += f"**Updated:** {updated_date}\n\n"
 
-        if issue.get('description'):
+        if issue.get("description"):
             markdown += f"## Description\n\n{issue['description']}\n\n"
 
-        if issue.get('comments'):
+        if issue.get("comments"):
             markdown += f"## Comments ({len(issue['comments'])})\n\n"
-            for comment in issue['comments']:
+            for comment in issue["comments"]:
                 user_name = "Unknown"
-                if comment.get('user') and comment['user'].get('name'):
-                    user_name = comment['user']['name']
+                if comment.get("user") and comment["user"].get("name"):
+                    user_name = comment["user"]["name"]
 
                 comment_date = "Unknown date"
-                if comment.get('created_at'):
-                    comment_date = self.format_date(comment['created_at'])
+                if comment.get("created_at"):
+                    comment_date = self.format_date(comment["created_at"])
 
                 markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
@@ -414,8 +442,8 @@ class LinearConnector:
             return "Unknown date"
 
         try:
-            dt = datetime.fromisoformat(iso_date.replace('Z', '+00:00'))
-            return dt.strftime('%Y-%m-%d %H:%M:%S')
+            dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
+            return dt.strftime("%Y-%m-%d %H:%M:%S")
         except ValueError:
             return iso_date
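
Note: `format_date` normalizes a trailing `Z` before parsing because `datetime.fromisoformat` only accepts the `Z` suffix directly from Python 3.11 onward; the `replace` keeps the code portable to older interpreters:

    from datetime import datetime

    iso = "2024-05-01T12:00:00Z"
    dt = datetime.fromisoformat(iso.replace("Z", "+00:00"))
    assert dt.strftime("%Y-%m-%d %H:%M:%S") == "2024-05-01 12:00:00"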
View file
@@ -1,5 +1,6 @@
 from notion_client import Client
 
+
 class NotionHistoryConnector:
     def __init__(self, token):
         """
@@ -26,10 +27,7 @@ class NotionHistoryConnector:
         search_params = {}
 
         # Filter for pages only (not databases)
-        search_params["filter"] = {
-            "value": "page",
-            "property": "object"
-        }
+        search_params["filter"] = {"value": "page", "property": "object"}
 
         # Add date filters if provided
         if start_date or end_date:
@@ -45,7 +43,7 @@ class NotionHistoryConnector:
         if date_filter:
             search_params["sort"] = {
                 "direction": "descending",
-                "timestamp": "last_edited_time"
+                "timestamp": "last_edited_time",
             }
 
         # First, get a list of all pages the integration has access to
@@ -60,11 +58,13 @@ class NotionHistoryConnector:
             # Get detailed page information
             page_content = self.get_page_content(page_id)
 
-            all_page_data.append({
-                "page_id": page_id,
-                "title": self.get_page_title(page),
-                "content": page_content
-            })
+            all_page_data.append(
+                {
+                    "page_id": page_id,
+                    "title": self.get_page_title(page),
+                    "content": page_content,
+                }
+            )
 
         return all_page_data
@@ -81,9 +81,11 @@ class NotionHistoryConnector:
         # Title can be in different properties depending on the page type
         if "properties" in page:
             # Try to find a title property
-            for prop_name, prop_data in page["properties"].items():
+            for _prop_name, prop_data in page["properties"].items():
                 if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
-                    return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
+                    return " ".join(
+                        [text_obj["plain_text"] for text_obj in prop_data["title"]]
+                    )
 
         # If no title found, return the page ID as fallback
         return f"Untitled page ({page['id']})"
@@ -105,7 +107,9 @@ class NotionHistoryConnector:
         # Paginate through all blocks
         while has_more:
             if cursor:
-                response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
+                response = self.notion.blocks.children.list(
+                    block_id=page_id, start_cursor=cursor
+                )
             else:
                 response = self.notion.blocks.children.list(block_id=page_id)
@@ -153,7 +157,7 @@ class NotionHistoryConnector:
             "id": block_id,
             "type": block_type,
             "content": content,
-            "children": child_blocks
+            "children": child_blocks,
         }
 
     def extract_block_content(self, block):
@@ -170,7 +174,9 @@ class NotionHistoryConnector:
         # Different block types have different structures
         if block_type in block and "rich_text" in block[block_type]:
-            return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
+            return "".join(
+                [text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]
+            )
         elif block_type == "image":
             # Instead of returning the raw URL which may contain sensitive AWS credentials,
             # return a placeholder or reference to the image
@@ -183,13 +189,16 @@ class NotionHistoryConnector:
             # Only return the domain part of external URLs to avoid potential sensitive parameters
             try:
                 from urllib.parse import urlparse
+
                 parsed_url = urlparse(url)
                 return f"[External Image from {parsed_url.netloc}]"
-            except:
+            except Exception:
                 return "[External Image]"
         elif block_type == "code":
             language = block["code"]["language"]
-            code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
+            code_text = "".join(
+                [text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]
+            )
             return f"```{language}\n{code_text}\n```"
         elif block_type == "equation":
             return block["equation"]["expression"]
View file
@@ -5,20 +5,21 @@ A module for retrieving conversation history from Slack channels.
 Allows fetching channel lists and message history with date range filtering.
 """
 
-import time # Added import
-import logging # Added import
+import logging  # Added import
+import time  # Added import
+from datetime import datetime
+from typing import Any
 
 from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
-from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Any
 
 logger = logging.getLogger(__name__)  # Added logger
 
 
 class SlackHistory:
     """Class for retrieving conversation history from Slack channels."""
 
-    def __init__(self, token: str = None):
+    def __init__(self, token: str | None = None):
         """
         Initialize the SlackHistory class.
@@ -36,7 +37,7 @@ class SlackHistory:
         """
         self.client = WebClient(token=token)
 
-    def get_all_channels(self, include_private: bool = True) -> List[Dict[str, Any]]:
+    def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
         """
         Fetch all channels that the bot has access to, with rate limit handling.
@@ -54,7 +55,7 @@ class SlackHistory:
         if not self.client:
             raise ValueError("Slack client not initialized. Call set_token() first.")
 
-        channels_list = [] # Changed from dict to list
+        channels_list = []  # Changed from dict to list
         types = "public_channel"
         if include_private:
             types += ",private_channel"
@@ -65,14 +66,14 @@ class SlackHistory:
         while is_first_request or next_cursor:
             try:
-                if not is_first_request: # Add delay only for paginated requests
-                    logger.info(f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}")
+                if not is_first_request:  # Add delay only for paginated requests
+                    logger.info(
+                        f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}"
+                    )
                     time.sleep(3)
 
-                current_limit = 1000 # Max limit
+                current_limit = 1000  # Max limit
                 api_result = self.client.conversations_list(
-                    types=types,
-                    cursor=next_cursor,
-                    limit=current_limit
+                    types=types, cursor=next_cursor, limit=current_limit
                 )
                 channels_on_page = api_result["channels"]
@@ -86,12 +87,13 @@ class SlackHistory:
                         # It indicates if the authenticated user (bot) is a member.
                         # For public channels, this might be true or the API might not focus on it
                         # if the bot can read it anyway. For private, it's crucial.
-                        "is_member": channel.get("is_member", False)
+                        "is_member": channel.get("is_member", False),
                     }
                     channels_list.append(channel_data)
                 else:
-                    logger.warning(f"Channel found with missing name or id. Data: {channel}")
+                    logger.warning(
+                        f"Channel found with missing name or id. Data: {channel}"
+                    )
 
                 next_cursor = api_result.get("response_metadata", {}).get("next_cursor")
                 is_first_request = False  # Subsequent requests are not the first
@@ -101,21 +103,29 @@ class SlackHistory:
             except SlackApiError as e:
                 if e.response is not None and e.response.status_code == 429:
-                    retry_after_header = e.response.headers.get('Retry-After')
-                    wait_duration = 60 # Default wait time
+                    retry_after_header = e.response.headers.get("Retry-After")
+                    wait_duration = 60  # Default wait time
                     if retry_after_header and retry_after_header.isdigit():
                         wait_duration = int(retry_after_header)
-                    logger.warning(f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}")
+                    logger.warning(
+                        f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}"
+                    )
                     time.sleep(wait_duration)
                     # The loop will continue, retrying with the same cursor
                 else:
                     # Not a 429 error, or no response object, re-raise
-                    raise SlackApiError(f"Error retrieving channels: {e}", e.response)
+                    raise SlackApiError(
+                        f"Error retrieving channels: {e}", e.response
+                    ) from e
             except Exception as general_error:
                 # Handle other potential errors like network issues if necessary, or re-raise
-                logger.error(f"An unexpected error occurred during channel fetching: {general_error}")
-                raise RuntimeError(f"An unexpected error occurred during channel fetching: {general_error}")
+                logger.error(
+                    f"An unexpected error occurred during channel fetching: {general_error}"
+                )
+                raise RuntimeError(
+                    f"An unexpected error occurred during channel fetching: {general_error}"
+                ) from general_error
 
         return channels_list
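
Note: a condensed sketch of the Retry-After handling used throughout this file; `call_api` is a hypothetical zero-argument callable standing in for a Slack SDK method:

    import time

    from slack_sdk.errors import SlackApiError

    def fetch_with_retry(call_api):
        while True:
            try:
                return call_api()
            except SlackApiError as e:
                if e.response is not None and e.response.status_code == 429:
                    header = e.response.headers.get("Retry-After")
                    wait = int(header) if header and header.isdigit() else 60
                    time.sleep(wait)  # honor Slack's Retry-After, default 60s
                else:
                    raise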
@@ -123,9 +133,9 @@ class SlackHistory:
         self,
         channel_id: str,
         limit: int = 1000,
-        oldest: Optional[int] = None,
-        latest: Optional[int] = None
-    ) -> List[Dict[str, Any]]:
+        oldest: int | None = None,
+        latest: int | None = None,
+    ) -> list[dict[str, Any]]:
         """
         Fetch conversation history for a channel.
@@ -151,7 +161,7 @@ class SlackHistory:
         while True:
             try:
                 # Proactive delay for conversations.history (Tier 3)
-                time.sleep(1.2) # Wait 1.2 seconds before each history call.
+                time.sleep(1.2)  # Wait 1.2 seconds before each history call.
 
                 kwargs = {
                     "channel": channel_id,
@@ -165,14 +175,17 @@ class SlackHistory:
                     kwargs["cursor"] = next_cursor
 
                 current_api_call_successful = False
-                result = None # Ensure result is defined
+                result = None  # Ensure result is defined
 
                 try:
                     result = self.client.conversations_history(**kwargs)
                     current_api_call_successful = True
                 except SlackApiError as e_history:
-                    if e_history.response is not None and e_history.response.status_code == 429:
-                        retry_after_str = e_history.response.headers.get('Retry-After')
-                        wait_time = 60 # Default
+                    if (
+                        e_history.response is not None
+                        and e_history.response.status_code == 429
+                    ):
+                        retry_after_str = e_history.response.headers.get("Retry-After")
+                        wait_time = 60  # Default
                         if retry_after_str and retry_after_str.isdigit():
                             wait_time = int(retry_after_str)
                         logger.warning(
@@ -182,10 +195,10 @@ class SlackHistory:
                         time.sleep(wait_time)
                         # current_api_call_successful remains False, loop will retry this page
                     else:
-                        raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors
+                        raise  # Re-raise to outer handler for not_in_channel or other SlackApiErrors
 
                 if not current_api_call_successful:
-                    continue # Retry the current page fetch due to handled rate limit
+                    continue  # Retry the current page fetch due to handled rate limit
 
                 # Process result if successful
                 batch = result["messages"]
@@ -194,29 +207,36 @@ class SlackHistory:
                 if result.get("has_more", False) and len(messages) < limit:
                     next_cursor = result["response_metadata"]["next_cursor"]
                 else:
-                    break # Exit pagination loop
+                    break  # Exit pagination loop
 
-            except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
-                if (e.response is not None and
-                        hasattr(e.response, 'data') and
-                        isinstance(e.response.data, dict) and
-                        e.response.data.get('error') == 'not_in_channel'):
+            except SlackApiError as e:  # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
+                if (
+                    e.response is not None
+                    and hasattr(e.response, "data")
+                    and isinstance(e.response.data, dict)
+                    and e.response.data.get("error") == "not_in_channel"
+                ):
                     logger.warning(
                         f"Bot is not in channel '{channel_id}'. Cannot fetch history. "
                         "Please add the bot to this channel."
                     )
                     return []
                 # For other SlackApiErrors from inner block or this level
-                raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
-            except Exception as general_error: # Catch any other unexpected errors
-                logger.error(f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}")
+                raise SlackApiError(
+                    f"Error retrieving history for channel {channel_id}: {e}",
+                    e.response,
+                ) from e
+            except Exception as general_error:  # Catch any other unexpected errors
+                logger.error(
+                    f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}"
+                )
                 # Re-raise the general error to allow higher-level handling or visibility
-                raise
+                raise general_error from general_error
 
         return messages[:limit]
 
     @staticmethod
-    def convert_date_to_timestamp(date_str: str) -> Optional[int]:
+    def convert_date_to_timestamp(date_str: str) -> int | None:
         """
         Convert a date string in format YYYY-MM-DD to Unix timestamp.
@@ -233,12 +253,8 @@ class SlackHistory:
             return None
 
     def get_history_by_date_range(
-        self,
-        channel_id: str,
-        start_date: str,
-        end_date: str,
-        limit: int = 1000
-    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+        self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
+    ) -> tuple[list[dict[str, Any]], str | None]:
         """
         Fetch conversation history within a date range.
@@ -253,7 +269,10 @@ class SlackHistory:
         """
         oldest = self.convert_date_to_timestamp(start_date)
         if not oldest:
-            return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
+            return (
+                [],
+                f"Invalid start date format: {start_date}. Please use YYYY-MM-DD.",
+            )
 
         latest = self.convert_date_to_timestamp(end_date)
         if not latest:
@@ -264,18 +283,15 @@ class SlackHistory:
 
         try:
             messages = self.get_conversation_history(
-                channel_id=channel_id,
-                limit=limit,
-                oldest=oldest,
-                latest=latest
+                channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
             )
             return messages, None
         except SlackApiError as e:
-            return [], f"Slack API error: {str(e)}"
+            return [], f"Slack API error: {e!s}"
         except ValueError as e:
             return [], str(e)
 
-    def get_user_info(self, user_id: str) -> Dict[str, Any]:
+    def get_user_info(self, user_id: str) -> dict[str, Any]:
         """
         Get information about a user.
@@ -299,25 +315,37 @@ class SlackHistory:
                 # time.sleep(0.6) # Optional: ~100 req/min if ever needed.
                 result = self.client.users_info(user=user_id)
-                return result["user"] # Success, return and exit loop implicitly
+                return result["user"]  # Success, return and exit loop implicitly
             except SlackApiError as e_user_info:
-                if e_user_info.response is not None and e_user_info.response.status_code == 429:
-                    retry_after_str = e_user_info.response.headers.get('Retry-After')
-                    wait_time = 30 # Default for Tier 4, can be adjusted
+                if (
+                    e_user_info.response is not None
+                    and e_user_info.response.status_code == 429
+                ):
+                    retry_after_str = e_user_info.response.headers.get("Retry-After")
+                    wait_time = 30  # Default for Tier 4, can be adjusted
                     if retry_after_str and retry_after_str.isdigit():
                         wait_time = int(retry_after_str)
-                    logger.warning(f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds.")
+                    logger.warning(
+                        f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds."
+                    )
                     time.sleep(wait_time)
-                    continue # Retry the API call
+                    continue  # Retry the API call
                 else:
                     # Not a 429 error, or no response object, re-raise
-                    raise SlackApiError(f"Error retrieving user info for {user_id}: {e_user_info}", e_user_info.response)
-            except Exception as general_error: # Catch any other unexpected errors
-                logger.error(f"Unexpected error in get_user_info for user {user_id}: {general_error}")
-                raise # Re-raise unexpected errors
+                    raise SlackApiError(
+                        f"Error retrieving user info for {user_id}: {e_user_info}",
+                        e_user_info.response,
+                    ) from e_user_info
+            except Exception as general_error:  # Catch any other unexpected errors
+                logger.error(
+                    f"Unexpected error in get_user_info for user {user_id}: {general_error}"
+                )
                raise general_error from general_error  # Re-raise unexpected errors
 
-    def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
+    def format_message(
+        self, msg: dict[str, Any], include_user_info: bool = False
+    ) -> dict[str, Any]:
         """
         Format a message for easier consumption.
@@ -331,7 +359,9 @@ class SlackHistory:
         formatted = {
             "text": msg.get("text", ""),
             "timestamp": msg.get("ts"),
-            "datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
+            "datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime(
+                "%Y-%m-%d %H:%M:%S"
+            ),
             "user_id": msg.get("user", "UNKNOWN"),
             "has_attachments": bool(msg.get("attachments")),
             "has_files": bool(msg.get("files")),
@@ -1,23 +1,24 @@
 import unittest
-from unittest.mock import patch, Mock
 from datetime import datetime
+from unittest.mock import Mock, patch
+
+from github3.exceptions import ForbiddenError  # Import the specific exception

 # Adjust the import path based on the actual location if test_github_connector.py
 # is not in the same directory as github_connector.py or if paths are set up differently.
 # Assuming surfsend_backend/app/connectors/test_github_connector.py
 from surfsense_backend.app.connectors.github_connector import GitHubConnector
-from github3.exceptions import ForbiddenError # Import the specific exception

 class TestGitHubConnector(unittest.TestCase):
-    @patch('surfsense_backend.app.connectors.github_connector.github_login')
+    @patch("surfsense_backend.app.connectors.github_connector.github_login")
     def test_get_user_repositories_uses_type_all(self, mock_github_login):
         # Mock the GitHub client object and its methods
         mock_gh_instance = Mock()
         mock_github_login.return_value = mock_gh_instance
         # Mock the self.gh.me() call in __init__ to prevent an actual API call
         mock_gh_instance.me.return_value = Mock()  # Simple mock to pass initialization

         # Prepare mock repository data
         mock_repo1_data = Mock()
@@ -27,7 +28,9 @@ class TestGitHubConnector(unittest.TestCase):
         mock_repo1_data.private = False
         mock_repo1_data.html_url = "http://example.com/user/repo1"
         mock_repo1_data.description = "Test repo 1"
-        mock_repo1_data.updated_at = datetime(2023, 1, 1, 10, 30, 0) # Added time component
+        mock_repo1_data.updated_at = datetime(
+            2023, 1, 1, 10, 30, 0
+        )  # Added time component

         mock_repo2_data = Mock()
         mock_repo2_data.id = 2
@@ -36,7 +39,9 @@ class TestGitHubConnector(unittest.TestCase):
         mock_repo2_data.private = True
         mock_repo2_data.html_url = "http://example.com/org/org-repo"
         mock_repo2_data.description = "Org repo"
-        mock_repo2_data.updated_at = datetime(2023, 1, 2, 12, 0, 0) # Added time component
+        mock_repo2_data.updated_at = datetime(
+            2023, 1, 2, 12, 0, 0
+        )  # Added time component

         # Configure the mock for gh.repositories() call
         # This method is an iterator, so it should return an iterable (e.g., a list)
@@ -46,26 +51,38 @@ class TestGitHubConnector(unittest.TestCase):
         repositories = connector.get_user_repositories()

         # Assert that gh.repositories was called correctly
-        mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
+        mock_gh_instance.repositories.assert_called_once_with(
+            type="all", sort="updated"
+        )

         # Assert the structure and content of the returned data
         expected_repositories = [
             {
-                "id": 1, "name": "repo1", "full_name": "user/repo1", "private": False,
-                "url": "http://example.com/user/repo1", "description": "Test repo 1",
-                "last_updated": datetime(2023, 1, 1, 10, 30, 0)
+                "id": 1,
+                "name": "repo1",
+                "full_name": "user/repo1",
+                "private": False,
+                "url": "http://example.com/user/repo1",
+                "description": "Test repo 1",
+                "last_updated": datetime(2023, 1, 1, 10, 30, 0),
             },
             {
-                "id": 2, "name": "org-repo", "full_name": "org/org-repo", "private": True,
-                "url": "http://example.com/org/org-repo", "description": "Org repo",
-                "last_updated": datetime(2023, 1, 2, 12, 0, 0)
-            }
+                "id": 2,
+                "name": "org-repo",
+                "full_name": "org/org-repo",
+                "private": True,
+                "url": "http://example.com/org/org-repo",
+                "description": "Org repo",
+                "last_updated": datetime(2023, 1, 2, 12, 0, 0),
+            },
         ]
         self.assertEqual(repositories, expected_repositories)
         self.assertEqual(len(repositories), 2)

-    @patch('surfsense_backend.app.connectors.github_connector.github_login')
-    def test_get_user_repositories_handles_empty_description_and_none_updated_at(self, mock_github_login):
+    @patch("surfsense_backend.app.connectors.github_connector.github_login")
+    def test_get_user_repositories_handles_empty_description_and_none_updated_at(
+        self, mock_github_login
+    ):
         # Mock the GitHub client object and its methods
         mock_gh_instance = Mock()
         mock_github_login.return_value = mock_gh_instance
@@ -77,24 +94,30 @@ class TestGitHubConnector(unittest.TestCase):
         mock_repo_data.full_name = "user/repo_no_desc"
         mock_repo_data.private = False
         mock_repo_data.html_url = "http://example.com/user/repo_no_desc"
         mock_repo_data.description = None  # Test None description
         mock_repo_data.updated_at = None  # Test None updated_at

         mock_gh_instance.repositories.return_value = [mock_repo_data]
         connector = GitHubConnector(token="fake_token")
         repositories = connector.get_user_repositories()

-        mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
+        mock_gh_instance.repositories.assert_called_once_with(
+            type="all", sort="updated"
+        )
         expected_repositories = [
             {
-                "id": 1, "name": "repo_no_desc", "full_name": "user/repo_no_desc", "private": False,
-                "url": "http://example.com/user/repo_no_desc", "description": "", # Expect empty string
-                "last_updated": None # Expect None
+                "id": 1,
+                "name": "repo_no_desc",
+                "full_name": "user/repo_no_desc",
+                "private": False,
+                "url": "http://example.com/user/repo_no_desc",
+                "description": "",  # Expect empty string
+                "last_updated": None,  # Expect None
             }
         ]
         self.assertEqual(repositories, expected_repositories)

-    @patch('surfsense_backend.app.connectors.github_connector.github_login')
+    @patch("surfsense_backend.app.connectors.github_connector.github_login")
     def test_github_connector_initialization_failure_forbidden(self, mock_github_login):
         # Test that __init__ raises ValueError on auth failure (ForbiddenError)
         mock_gh_instance = Mock()
@@ -104,17 +127,21 @@ class TestGitHubConnector(unittest.TestCase):
         # The actual response structure might vary, but github3.py's ForbiddenError
         # can be instantiated with just a response object that has a status_code.
         mock_response = Mock()
         mock_response.status_code = 403  # Typically Forbidden

         # Setup the side_effect for self.gh.me()
         mock_gh_instance.me.side_effect = ForbiddenError(mock_response)

         with self.assertRaises(ValueError) as context:
             GitHubConnector(token="invalid_token_forbidden")
-        self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
+        self.assertIn(
+            "Invalid GitHub token or insufficient permissions.", str(context.exception)
+        )

-    @patch('surfsense_backend.app.connectors.github_connector.github_login')
-    def test_github_connector_initialization_failure_authentication_failed(self, mock_github_login):
+    @patch("surfsense_backend.app.connectors.github_connector.github_login")
+    def test_github_connector_initialization_failure_authentication_failed(
+        self, mock_github_login
+    ):
         # Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError)
         # For github3.py, AuthenticationFailed is more specific for token issues.
         from github3.exceptions import AuthenticationFailed
@@ -123,15 +150,17 @@ class TestGitHubConnector(unittest.TestCase):
         mock_github_login.return_value = mock_gh_instance
         mock_response = Mock()
         mock_response.status_code = 401  # Typically Unauthorized
         mock_gh_instance.me.side_effect = AuthenticationFailed(mock_response)

         with self.assertRaises(ValueError) as context:
             GitHubConnector(token="invalid_token_authfailed")
-        self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
+        self.assertIn(
+            "Invalid GitHub token or insufficient permissions.", str(context.exception)
+        )

-    @patch('surfsense_backend.app.connectors.github_connector.github_login')
+    @patch("surfsense_backend.app.connectors.github_connector.github_login")
     def test_get_user_repositories_handles_api_exception(self, mock_github_login):
         mock_gh_instance = Mock()
         mock_github_login.return_value = mock_gh_instance
@@ -142,13 +171,18 @@ class TestGitHubConnector(unittest.TestCase):
         connector = GitHubConnector(token="fake_token")

         # We expect it to log an error and return an empty list
-        with patch('surfsense_backend.app.connectors.github_connector.logger') as mock_logger:
+        with patch(
+            "surfsense_backend.app.connectors.github_connector.logger"
+        ) as mock_logger:
             repositories = connector.get_user_repositories()
             self.assertEqual(repositories, [])
             mock_logger.error.assert_called_once()
-            self.assertIn("Failed to fetch GitHub repositories: API Error", mock_logger.error.call_args[0][0])
+            self.assertIn(
+                "Failed to fetch GitHub repositories: API Error",
+                mock_logger.error.call_args[0][0],
+            )

-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
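Reviewer note: the stacked @patch decorators in these tests rely on unittest.mock's bottom-up injection order: the decorator closest to the function supplies the first mock argument (the first parameter after self in a TestCase). A small runnable sketch of that rule, using standard-library targets purely for illustration:

from unittest.mock import patch

@patch("os.path.exists")  # outer decorator -> second argument
@patch("os.getcwd")       # decorator closest to the function -> first argument
def demo(mock_getcwd, mock_exists):
    mock_getcwd.return_value = "/tmp"
    # Each target gets its own independent mock object.
    print(mock_getcwd is not mock_exists)  # True

demo()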
@@ -1,35 +1,39 @@
 import unittest
-import time # Imported to be available for patching target module
-from unittest.mock import patch, Mock, call
+from unittest.mock import Mock, call, patch

 from slack_sdk.errors import SlackApiError

 # Since test_slack_history.py is in the same directory as slack_history.py
 from .slack_history import SlackHistory

 class TestSlackHistoryGetAllChannels(unittest.TestCase):
-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_pagination_with_delay(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_pagination_with_delay(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value

         # Mock API responses now include is_private and is_member
         page1_response = {
             "channels": [
                 {"name": "general", "id": "C1", "is_private": False, "is_member": True},
-                {"name": "dev", "id": "C0", "is_private": False, "is_member": True}
+                {"name": "dev", "id": "C0", "is_private": False, "is_member": True},
             ],
-            "response_metadata": {"next_cursor": "cursor123"}
+            "response_metadata": {"next_cursor": "cursor123"},
         }
         page2_response = {
-            "channels": [{"name": "random", "id": "C2", "is_private": True, "is_member": True}],
-            "response_metadata": {"next_cursor": ""}
+            "channels": [
+                {"name": "random", "id": "C2", "is_private": True, "is_member": True}
+            ],
+            "response_metadata": {"next_cursor": ""},
         }
         mock_client_instance.conversations_list.side_effect = [
             page1_response,
-            page2_response
+            page2_response,
         ]

         slack_history = SlackHistory(token="fake_token")
@@ -38,129 +42,163 @@ class TestSlackHistoryGetAllChannels(unittest.TestCase):
         expected_channels_list = [
             {"id": "C1", "name": "general", "is_private": False, "is_member": True},
             {"id": "C0", "name": "dev", "is_private": False, "is_member": True},
-            {"id": "C2", "name": "random", "is_private": True, "is_member": True}
+            {"id": "C2", "name": "random", "is_private": True, "is_member": True},
         ]
         self.assertEqual(len(channels_list), 3)
-        self.assertListEqual(channels_list, expected_channels_list) # Assert list equality
+        self.assertListEqual(
+            channels_list, expected_channels_list
+        )  # Assert list equality

         expected_calls = [
             call(types="public_channel,private_channel", cursor=None, limit=1000),
-            call(types="public_channel,private_channel", cursor="cursor123", limit=1000)
+            call(
+                types="public_channel,private_channel", cursor="cursor123", limit=1000
+            ),
         ]
         mock_client_instance.conversations_list.assert_has_calls(expected_calls)
         self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
         mock_sleep.assert_called_once_with(3)
-        mock_logger.info.assert_called_once_with("Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123")
+        mock_logger.info.assert_called_once_with(
+            "Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123"
+        )

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_rate_limit_with_retry_after(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_rate_limit_with_retry_after(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 429
-        mock_error_response.headers = {'Retry-After': '5'}
+        mock_error_response.headers = {"Retry-After": "5"}

         successful_response = {
-            "channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
-            "response_metadata": {"next_cursor": ""}
+            "channels": [
+                {"name": "general", "id": "C1", "is_private": False, "is_member": True}
+            ],
+            "response_metadata": {"next_cursor": ""},
         }
         mock_client_instance.conversations_list.side_effect = [
             SlackApiError(message="ratelimited", response=mock_error_response),
-            successful_response
+            successful_response,
         ]

         slack_history = SlackHistory(token="fake_token")
         channels_list = slack_history.get_all_channels(include_private=True)

-        expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
+        expected_channels_list = [
+            {"id": "C1", "name": "general", "is_private": False, "is_member": True}
+        ]
         self.assertEqual(len(channels_list), 1)
         self.assertListEqual(channels_list, expected_channels_list)
         mock_sleep.assert_called_once_with(5)
-        mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None")
+        mock_logger.warning.assert_called_once_with(
+            "Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None"
+        )

         expected_calls = [
             call(types="public_channel,private_channel", cursor=None, limit=1000),
-            call(types="public_channel,private_channel", cursor=None, limit=1000)
+            call(types="public_channel,private_channel", cursor=None, limit=1000),
         ]
         mock_client_instance.conversations_list.assert_has_calls(expected_calls)
         self.assertEqual(mock_client_instance.conversations_list.call_count, 2)

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_rate_limit_no_retry_after_valid_header(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_rate_limit_no_retry_after_valid_header(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 429
-        mock_error_response.headers = {'Retry-After': 'invalid_value'}
+        mock_error_response.headers = {"Retry-After": "invalid_value"}

         successful_response = {
-            "channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
-            "response_metadata": {"next_cursor": ""}
+            "channels": [
+                {"name": "general", "id": "C1", "is_private": False, "is_member": True}
+            ],
+            "response_metadata": {"next_cursor": ""},
         }
         mock_client_instance.conversations_list.side_effect = [
             SlackApiError(message="ratelimited", response=mock_error_response),
-            successful_response
+            successful_response,
         ]

         slack_history = SlackHistory(token="fake_token")
         channels_list = slack_history.get_all_channels(include_private=True)

-        expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
+        expected_channels_list = [
+            {"id": "C1", "name": "general", "is_private": False, "is_member": True}
+        ]
         self.assertListEqual(channels_list, expected_channels_list)
         mock_sleep.assert_called_once_with(60)  # Default fallback
-        mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
+        mock_logger.warning.assert_called_once_with(
+            "Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
+        )
         self.assertEqual(mock_client_instance.conversations_list.call_count, 2)

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_rate_limit_no_retry_after_header(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_rate_limit_no_retry_after_header(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 429
         mock_error_response.headers = {}

         successful_response = {
-            "channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
-            "response_metadata": {"next_cursor": ""}
+            "channels": [
+                {"name": "general", "id": "C1", "is_private": False, "is_member": True}
+            ],
+            "response_metadata": {"next_cursor": ""},
         }
         mock_client_instance.conversations_list.side_effect = [
             SlackApiError(message="ratelimited", response=mock_error_response),
-            successful_response
+            successful_response,
         ]

         slack_history = SlackHistory(token="fake_token")
         channels_list = slack_history.get_all_channels(include_private=True)

-        expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
+        expected_channels_list = [
+            {"id": "C1", "name": "general", "is_private": False, "is_member": True}
+        ]
         self.assertListEqual(channels_list, expected_channels_list)
         mock_sleep.assert_called_once_with(60)  # Default fallback
-        mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
+        mock_logger.warning.assert_called_once_with(
+            "Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
+        )
         self.assertEqual(mock_client_instance.conversations_list.call_count, 2)

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_other_slack_api_error(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_other_slack_api_error(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 500
         mock_error_response.headers = {}
         mock_error_response.data = {"ok": False, "error": "internal_error"}
-        original_error = SlackApiError(message="server error", response=mock_error_response)
+        original_error = SlackApiError(
+            message="server error", response=mock_error_response
+        )
         mock_client_instance.conversations_list.side_effect = original_error

         slack_history = SlackHistory(token="fake_token")
@@ -171,81 +209,101 @@ class TestSlackHistoryGetAllChannels(unittest.TestCase):
         self.assertEqual(context.exception.response.status_code, 500)
         self.assertIn("server error", str(context.exception))
         mock_sleep.assert_not_called()
         mock_logger.warning.assert_not_called()  # Ensure no rate limit log
         mock_client_instance.conversations_list.assert_called_once_with(
             types="public_channel,private_channel", cursor=None, limit=1000
         )

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_handles_missing_name_id_gracefully(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_handles_missing_name_id_gracefully(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         response_with_malformed_data = {
             "channels": [
                 {"id": "C1_missing_name", "is_private": False, "is_member": True},
                 {"name": "channel_missing_id", "is_private": False, "is_member": True},
-                {"name": "general", "id": "C2_valid", "is_private": False, "is_member": True}
+                {
+                    "name": "general",
+                    "id": "C2_valid",
+                    "is_private": False,
+                    "is_member": True,
+                },
             ],
-            "response_metadata": {"next_cursor": ""}
+            "response_metadata": {"next_cursor": ""},
         }
-        mock_client_instance.conversations_list.return_value = response_with_malformed_data
+        mock_client_instance.conversations_list.return_value = (
+            response_with_malformed_data
+        )

         slack_history = SlackHistory(token="fake_token")
         channels_list = slack_history.get_all_channels(include_private=True)

         expected_channels_list = [
-            {"id": "C2_valid", "name": "general", "is_private": False, "is_member": True}
+            {
+                "id": "C2_valid",
+                "name": "general",
+                "is_private": False,
+                "is_member": True,
+            }
         ]
         self.assertEqual(len(channels_list), 1)
         self.assertListEqual(channels_list, expected_channels_list)

         self.assertEqual(mock_logger.warning.call_count, 2)
-        mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}")
-        mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}")
+        mock_logger.warning.assert_any_call(
+            "Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}"
+        )
+        mock_logger.warning.assert_any_call(
+            "Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}"
+        )
         mock_sleep.assert_not_called()
         mock_client_instance.conversations_list.assert_called_once_with(
             types="public_channel,private_channel", cursor=None, limit=1000
         )

-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

 class TestSlackHistoryGetConversationHistory(unittest.TestCase):
-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_proactive_delay_single_page(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_proactive_delay_single_page(
+        self, mock_web_client, mock_time_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_client_instance.conversations_history.return_value = {
             "messages": [{"text": "msg1"}],
-            "has_more": False
+            "has_more": False,
         }

         slack_history = SlackHistory(token="fake_token")
         slack_history.get_conversation_history(channel_id="C123")

         mock_time_sleep.assert_called_once_with(1.2)  # Proactive delay

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_proactive_delay_multiple_pages(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_proactive_delay_multiple_pages(
+        self, mock_web_client, mock_time_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_client_instance.conversations_history.side_effect = [
             {
                 "messages": [{"text": "msg1"}],
                 "has_more": True,
-                "response_metadata": {"next_cursor": "cursor1"}
+                "response_metadata": {"next_cursor": "cursor1"},
             },
-            {
-                "messages": [{"text": "msg2"}],
-                "has_more": False
-            }
+            {"messages": [{"text": "msg2"}], "has_more": False},
         ]

         slack_history = SlackHistory(token="fake_token")
@@ -255,19 +313,19 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
         self.assertEqual(mock_time_sleep.call_count, 2)
         mock_time_sleep.assert_has_calls([call(1.2), call(1.2)])

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 429
-        mock_error_response.headers = {'Retry-After': '5'}
+        mock_error_response.headers = {"Retry-After": "5"}

         mock_client_instance.conversations_history.side_effect = [
             SlackApiError(message="ratelimited", response=mock_error_response),
-            {"messages": [{"text": "msg1"}], "has_more": False}
+            {"messages": [{"text": "msg1"}], "has_more": False},
         ]

         slack_history = SlackHistory(token="fake_token")
@@ -277,23 +335,26 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
         self.assertEqual(messages[0]["text"], "msg1")

         # Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
-        mock_time_sleep.assert_has_calls([call(1.2), call(5), call(1.2)], any_order=False)
-        mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
+        mock_time_sleep.assert_has_calls(
+            [call(1.2), call(5), call(1.2)], any_order=False
+        )
+        mock_logger.warning.assert_called_once()  # Check that a warning was logged for rate limiting

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_not_in_channel_error(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_not_in_channel_error(self, mock_web_client, mock_time_sleep, mock_logger):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
-        mock_error_response.status_code = 403 # Typical for not_in_channel, but data matters more
-        mock_error_response.data = {'ok': False, 'error': 'not_in_channel'}
+        mock_error_response.status_code = (
+            403  # Typical for not_in_channel, but data matters more
+        )
+        mock_error_response.data = {"ok": False, "error": "not_in_channel"}

         # This error is now raised by the inner try-except, then caught by the outer one
         mock_client_instance.conversations_history.side_effect = SlackApiError(
-            message="not_in_channel error",
-            response=mock_error_response
+            message="not_in_channel error", response=mock_error_response
         )

         slack_history = SlackHistory(token="fake_token")
@@ -303,18 +364,24 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
         mock_logger.warning.assert_called_with(
             "Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel."
         )
-        mock_time_sleep.assert_called_once_with(1.2) # Proactive delay before the API call
+        mock_time_sleep.assert_called_once_with(
+            1.2
+        )  # Proactive delay before the API call

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_other_slack_api_error_propagates(
+        self, mock_web_client, mock_time_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 500
-        mock_error_response.data = {'ok': False, 'error': 'internal_error'}
-        original_error = SlackApiError(message="server error", response=mock_error_response)
+        mock_error_response.data = {"ok": False, "error": "internal_error"}
+        original_error = SlackApiError(
+            message="server error", response=mock_error_response
+        )
         mock_client_instance.conversations_history.side_effect = original_error
@@ -323,44 +390,52 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
         with self.assertRaises(SlackApiError) as context:
             slack_history.get_conversation_history(channel_id="C123")

-        self.assertIn("Error retrieving history for channel C123", str(context.exception))
+        self.assertIn(
+            "Error retrieving history for channel C123", str(context.exception)
+        )
         self.assertIs(context.exception.response, mock_error_response)
         mock_time_sleep.assert_called_once_with(1.2)  # Proactive delay

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_general_exception_propagates(
+        self, mock_web_client, mock_time_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         original_error = Exception("Something broke")
         mock_client_instance.conversations_history.side_effect = original_error

         slack_history = SlackHistory(token="fake_token")

         with self.assertRaises(Exception) as context:  # Check for generic Exception
             slack_history.get_conversation_history(channel_id="C123")

-        self.assertIs(context.exception, original_error) # Should re-raise the original error
-        mock_logger.error.assert_called_once_with("Unexpected error in get_conversation_history for channel C123: Something broke")
-        mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
+        self.assertIs(
+            context.exception, original_error
+        )  # Should re-raise the original error
+        mock_logger.error.assert_called_once_with(
+            "Unexpected error in get_conversation_history for channel C123: Something broke"
+        )
+        mock_time_sleep.assert_called_once_with(1.2)  # Proactive delay

 class TestSlackHistoryGetUserInfo(unittest.TestCase):
-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 429
-        mock_error_response.headers = {'Retry-After': '3'} # Using 3 seconds for test
+        mock_error_response.headers = {"Retry-After": "3"}  # Using 3 seconds for test

         successful_user_data = {"id": "U123", "name": "testuser"}
         mock_client_instance.users_info.side_effect = [
             SlackApiError(message="ratelimited_userinfo", response=mock_error_response),
-            {"user": successful_user_data}
+            {"user": successful_user_data},
         ]

         slack_history = SlackHistory(token="fake_token")
@@ -375,18 +450,26 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
         )

         # Assert users_info was called twice (original + retry)
         self.assertEqual(mock_client_instance.users_info.call_count, 2)
-        mock_client_instance.users_info.assert_has_calls([call(user="U123"), call(user="U123")])
+        mock_client_instance.users_info.assert_has_calls(
+            [call(user="U123"), call(user="U123")]
+        )

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep') # time.sleep might be called by other logic, but not expected here
-    @patch('slack_sdk.WebClient')
-    def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch(
+        "surfsense_backend.app.connectors.slack_history.time.sleep"
+    )  # time.sleep might be called by other logic, but not expected here
+    @patch("slack_sdk.WebClient")
+    def test_other_slack_api_error_propagates(
+        self, mock_web_client, mock_time_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         mock_error_response = Mock()
         mock_error_response.status_code = 500  # Some other error
-        mock_error_response.data = {'ok': False, 'error': 'internal_server_error'}
-        original_error = SlackApiError(message="internal server error", response=mock_error_response)
+        mock_error_response.data = {"ok": False, "error": "internal_server_error"}
+        original_error = SlackApiError(
+            message="internal server error", response=mock_error_response
+        )
         mock_client_instance.users_info.side_effect = original_error
@@ -398,13 +481,15 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
         # Check that the raised error is the one we expect
         self.assertIn("Error retrieving user info for U123", str(context.exception))
         self.assertIs(context.exception.response, mock_error_response)
         mock_time_sleep.assert_not_called()  # No rate limit sleep

-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_general_exception_propagates(
+        self, mock_web_client, mock_time_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
         original_error = Exception("A very generic problem")
         mock_client_instance.users_info.side_effect = original_error
@@ -413,8 +498,10 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
         with self.assertRaises(Exception) as context:
             slack_history.get_user_info(user_id="U123")

-        self.assertIs(context.exception, original_error) # Check it's the exact same exception
+        self.assertIs(
+            context.exception, original_error
+        )  # Check it's the exact same exception
         mock_logger.error.assert_called_once_with(
             "Unexpected error in get_user_info for user U123: A very generic problem"
         )
         mock_time_sleep.assert_not_called()  # No rate limit sleep
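Reviewer note: these tests patch "surfsense_backend.app.connectors.slack_history.time.sleep" rather than "time.sleep" globally, which is the standard mock guidance: patch the name where it is looked up. A minimal sketch of why that matters; the module names here are illustrative, not from the repo:

# mymodule.py (hypothetical)
import time

def slow():
    time.sleep(10)

# test_mymodule.py (hypothetical)
from unittest.mock import patch

import mymodule

def test_slow_is_fast():
    # Patching mymodule.time.sleep intercepts the exact attribute slow() resolves,
    # so the test never actually sleeps.
    with patch("mymodule.time.sleep") as mock_sleep:
        mymodule.slow()
        mock_sleep.assert_called_once_with(10)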
@@ -1,11 +1,9 @@
 from collections.abc import AsyncGenerator
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from enum import Enum

-from app.config import config
-from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
-from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
 from fastapi import Depends
+from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
 from pgvector.sqlalchemy import Vector
 from sqlalchemy import (
     ARRAY,
@@ -13,9 +11,7 @@ from sqlalchemy import (
     TIMESTAMP,
     Boolean,
     Column,
-)
-from sqlalchemy import Enum as SQLAlchemyEnum
-from sqlalchemy import (
+    Enum as SQLAlchemyEnum,
     ForeignKey,
     Integer,
     String,
@@ -26,17 +22,12 @@ from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
 from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship

+from app.config import config
+from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
+from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever

 if config.AUTH_TYPE == "GOOGLE":
-    from fastapi_users.db import (
-        SQLAlchemyBaseOAuthAccountTableUUID,
-        SQLAlchemyBaseUserTableUUID,
-        SQLAlchemyUserDatabase,
-    )
-else:
-    from fastapi_users.db import (
-        SQLAlchemyBaseUserTableUUID,
-        SQLAlchemyUserDatabase,
-    )
+    from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID

 DATABASE_URL = config.DATABASE_URL
@@ -118,11 +109,11 @@ class Base(DeclarativeBase):
 class TimestampMixin:
     @declared_attr
-    def created_at(cls):
+    def created_at(cls):  # noqa: N805
         return Column(
             TIMESTAMP(timezone=True),
             nullable=False,
-            default=lambda: datetime.now(timezone.utc),
+            default=lambda: datetime.now(UTC),
             index=True,
         )
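Reviewer note: datetime.UTC is a Python 3.11+ alias for datetime.timezone.utc, so the default factory above is behavior-preserving. A quick check, assuming Python 3.11 or newer:

from datetime import UTC, datetime, timezone

assert UTC is timezone.utc
assert datetime.now(UTC).tzinfo is timezone.utc  # still timezone-aware, same tzinfo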
@@ -1,9 +1,12 @@
+from datetime import UTC, datetime
+
 from langchain_core.prompts.prompt import PromptTemplate
-from datetime import datetime, timezone

-DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
+DATE_TODAY = "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n"

-SUMMARY_PROMPT = DATE_TODAY + """
+SUMMARY_PROMPT = (
+    DATE_TODAY
+    + """
 <INSTRUCTIONS>
     <context>
     You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
@@ -96,8 +99,8 @@ SUMMARY_PROMPT = DATE_TODAY + """
     </document_to_summarize>
 </INSTRUCTIONS>
 """
+)

 SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
-    input_variables=["document"],
-    template=SUMMARY_PROMPT
+    input_variables=["document"], template=SUMMARY_PROMPT
 )
@ -8,7 +8,13 @@ class ChucksHybridSearchRetriever:
""" """
self.db_session = db_session self.db_session = db_session
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list: async def vector_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
) -> list:
""" """
Perform vector similarity search on chunks. Perform vector similarity search on chunks.
@ -21,10 +27,11 @@ class ChucksHybridSearchRetriever:
Returns: Returns:
List of chunks sorted by vector similarity List of chunks sorted by vector similarity
""" """
from sqlalchemy import select, func from sqlalchemy import select
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace
from app.config import config from app.config import config
from app.db import Chunk, Document, SearchSpace
# Get embedding for the query # Get embedding for the query
embedding_model = config.embedding_model_instance embedding_model = config.embedding_model_instance
@ -44,11 +51,7 @@ class ChucksHybridSearchRetriever:
query = query.where(Document.search_space_id == search_space_id) query = query.where(Document.search_space_id == search_space_id)
# Add vector similarity ordering # Add vector similarity ordering
query = ( query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
query
.order_by(Chunk.embedding.op("<=>")(query_embedding))
.limit(top_k)
)
# Execute the query # Execute the query
result = await self.db_session.execute(query) result = await self.db_session.execute(query)
@ -56,7 +59,13 @@ class ChucksHybridSearchRetriever:
return chunks return chunks
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list: async def full_text_search(
self,
query_text: str,
top_k: int,
user_id: str,
search_space_id: int | None = None,
) -> list:
""" """
Perform full-text keyword search on chunks. Perform full-text keyword search on chunks.
@ -69,13 +78,14 @@ class ChucksHybridSearchRetriever:
Returns: Returns:
List of chunks sorted by text relevance List of chunks sorted by text relevance
""" """
from sqlalchemy import select, func, text from sqlalchemy import func, select
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, SearchSpace from app.db import Chunk, Document, SearchSpace
# Create tsvector and tsquery for PostgreSQL full-text search # Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector('english', Chunk.content) tsvector = func.to_tsvector("english", Chunk.content)
tsquery = func.plainto_tsquery('english', query_text) tsquery = func.plainto_tsquery("english", query_text)
# Build the base query with user ownership check # Build the base query with user ownership check
query = ( query = (
@ -84,7 +94,9 @@ class ChucksHybridSearchRetriever:
.join(Document, Chunk.document_id == Document.id) .join(Document, Chunk.document_id == Document.id)
.join(SearchSpace, Document.search_space_id == SearchSpace.id) .join(SearchSpace, Document.search_space_id == SearchSpace.id)
.where(SearchSpace.user_id == user_id) .where(SearchSpace.user_id == user_id)
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query .where(
tsvector.op("@@")(tsquery)
) # Only include results that match the query
) )
# Add search space filter if provided # Add search space filter if provided
@ -92,11 +104,7 @@ class ChucksHybridSearchRetriever:
query = query.where(Document.search_space_id == search_space_id) query = query.where(Document.search_space_id == search_space_id)
# Add text search ranking # Add text search ranking
query = ( query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
query
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
.limit(top_k)
)
        # Execute the query
        result = await self.db_session.execute(query)

@@ -104,7 +112,14 @@ class ChucksHybridSearchRetriever:
        return chunks

-    async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
+    async def hybrid_search(
+        self,
+        query_text: str,
+        top_k: int,
+        user_id: str,
+        search_space_id: int | None = None,
+        document_type: str | None = None,
+    ) -> list:
        """
        Combine vector similarity and full-text search results using Reciprocal Rank Fusion.

@@ -118,10 +133,11 @@ class ChucksHybridSearchRetriever:
        Returns:
            List of dictionaries containing chunk data and relevance scores
        """
-        from sqlalchemy import select, func, text
+        from sqlalchemy import func, select, text
        from sqlalchemy.orm import joinedload
-        from app.db import Chunk, Document, SearchSpace, DocumentType
        from app.config import config
+        from app.db import Chunk, Document, DocumentType, SearchSpace

        # Get embedding for the query
        embedding_model = config.embedding_model_instance

@@ -132,8 +148,8 @@ class ChucksHybridSearchRetriever:
        n_results = top_k * 2  # Get more results for better fusion

        # Create tsvector and tsquery for PostgreSQL full-text search
-        tsvector = func.to_tsvector('english', Chunk.content)
-        tsquery = func.plainto_tsquery('english', query_text)
+        tsvector = func.to_tsvector("english", Chunk.content)
+        tsquery = func.plainto_tsquery("english", query_text)

        # Base conditions for document filtering
        base_conditions = [SearchSpace.user_id == user_id]

@@ -159,7 +175,9 @@ class ChucksHybridSearchRetriever:
        semantic_search_cte = (
            select(
                Chunk.id,
-                func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
+                func.rank()
+                .over(order_by=Chunk.embedding.op("<=>")(query_embedding))
+                .label("rank"),
            )
            .join(Document, Chunk.document_id == Document.id)
            .join(SearchSpace, Document.search_space_id == SearchSpace.id)

@@ -167,8 +185,7 @@ class ChucksHybridSearchRetriever:
        )

        semantic_search_cte = (
-            semantic_search_cte
-            .order_by(Chunk.embedding.op("<=>")(query_embedding))
+            semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
            .limit(n_results)
            .cte("semantic_search")
        )

@@ -177,7 +194,9 @@ class ChucksHybridSearchRetriever:
        keyword_search_cte = (
            select(
                Chunk.id,
-                func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
+                func.rank()
+                .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
+                .label("rank"),
            )
            .join(Document, Chunk.document_id == Document.id)
            .join(SearchSpace, Document.search_space_id == SearchSpace.id)

@@ -186,8 +205,7 @@ class ChucksHybridSearchRetriever:
        )

        keyword_search_cte = (
-            keyword_search_cte
-            .order_by(func.ts_rank_cd(tsvector, tsquery).desc())
+            keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
            .limit(n_results)
            .cte("keyword_search")
        )

@@ -197,20 +215,21 @@ class ChucksHybridSearchRetriever:
            select(
                Chunk,
                (
-                    func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
-                    func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
-                ).label("score")
+                    func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
+                    + func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
+                ).label("score"),
            )
            .select_from(
                semantic_search_cte.outerjoin(
                    keyword_search_cte,
                    semantic_search_cte.c.id == keyword_search_cte.c.id,
-                    full=True
+                    full=True,
                )
            )
            .join(
                Chunk,
-                Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
+                Chunk.id
+                == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
            )
            .options(joinedload(Chunk.document))
            .order_by(text("score DESC"))

@@ -228,16 +247,20 @@ class ChucksHybridSearchRetriever:
        # Convert to serializable dictionaries if no reranker is available or if reranking failed
        serialized_results = []
        for chunk, score in chunks_with_scores:
-            serialized_results.append({
-                "chunk_id": chunk.id,
-                "content": chunk.content,
-                "score": float(score),  # Ensure score is a Python float
-                "document": {
-                    "id": chunk.document.id,
-                    "title": chunk.document.title,
-                    "document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
-                    "metadata": chunk.document.document_metadata
-                }
-            })
+            serialized_results.append(
+                {
+                    "chunk_id": chunk.id,
+                    "content": chunk.content,
+                    "score": float(score),  # Ensure score is a Python float
+                    "document": {
+                        "id": chunk.document.id,
+                        "title": chunk.document.title,
+                        "document_type": chunk.document.document_type.value
+                        if hasattr(chunk.document, "document_type")
+                        else None,
+                        "metadata": chunk.document.document_metadata,
+                    },
+                }
+            )

        return serialized_results
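The "score" label above is Reciprocal Rank Fusion (RRF): each ranker contributes 1 / (k + rank), and the full outer join plus coalesce(..., 0.0) lets a chunk surfaced by only one ranker still receive a score. A minimal pure-Python sketch of the same fusion, with an illustrative k = 60 and made-up rank lists (not taken from this commit):

def rrf_fuse(rankings: list[dict[int, int]], k: int = 60) -> list[tuple[int, float]]:
    """Fuse 1-based rank lists with Reciprocal Rank Fusion."""
    scores: dict[int, float] = {}
    for ranking in rankings:
        for item_id, rank in ranking.items():
            # Each ranker adds 1/(k + rank); absent items add nothing,
            # mirroring coalesce(..., 0.0) over the full outer join.
            scores[item_id] = scores.get(item_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda pair: pair[1], reverse=True)

semantic = {7: 1, 3: 2, 9: 3}  # chunk id -> rank from vector search
keyword = {3: 1, 5: 2}         # chunk id -> rank from full-text search
print(rrf_fuse([semantic, keyword]))  # chunk 3 wins: it ranks well in both lists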
@@ -8,7 +8,13 @@ class DocumentHybridSearchRetriever:
        """
        self.db_session = db_session

-    async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
+    async def vector_search(
+        self,
+        query_text: str,
+        top_k: int,
+        user_id: str,
+        search_space_id: int | None = None,
+    ) -> list:
        """
        Perform vector similarity search on documents.

@@ -21,10 +27,11 @@ class DocumentHybridSearchRetriever:
        Returns:
            List of documents sorted by vector similarity
        """
-        from sqlalchemy import select, func
+        from sqlalchemy import select
        from sqlalchemy.orm import joinedload
-        from app.db import Document, SearchSpace
        from app.config import config
+        from app.db import Document, SearchSpace

        # Get embedding for the query
        embedding_model = config.embedding_model_instance

@@ -43,10 +50,8 @@ class DocumentHybridSearchRetriever:
            query = query.where(Document.search_space_id == search_space_id)

        # Add vector similarity ordering
-        query = (
-            query
-            .order_by(Document.embedding.op("<=>")(query_embedding))
-            .limit(top_k)
+        query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit(
+            top_k
        )

        # Execute the query

@@ -55,7 +60,13 @@ class DocumentHybridSearchRetriever:
        return documents

-    async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
+    async def full_text_search(
+        self,
+        query_text: str,
+        top_k: int,
+        user_id: str,
+        search_space_id: int | None = None,
+    ) -> list:
        """
        Perform full-text keyword search on documents.

@@ -68,13 +79,14 @@ class DocumentHybridSearchRetriever:
        Returns:
            List of documents sorted by text relevance
        """
-        from sqlalchemy import select, func, text
+        from sqlalchemy import func, select
        from sqlalchemy.orm import joinedload
        from app.db import Document, SearchSpace

        # Create tsvector and tsquery for PostgreSQL full-text search
-        tsvector = func.to_tsvector('english', Document.content)
-        tsquery = func.plainto_tsquery('english', query_text)
+        tsvector = func.to_tsvector("english", Document.content)
+        tsquery = func.plainto_tsquery("english", query_text)

        # Build the base query with user ownership check
        query = (

@@ -82,7 +94,9 @@ class DocumentHybridSearchRetriever:
            .options(joinedload(Document.search_space))
            .join(SearchSpace, Document.search_space_id == SearchSpace.id)
            .where(SearchSpace.user_id == user_id)
-            .where(tsvector.op("@@")(tsquery))  # Only include results that match the query
+            .where(
+                tsvector.op("@@")(tsquery)
+            )  # Only include results that match the query
        )

        # Add search space filter if provided

@@ -90,11 +104,7 @@ class DocumentHybridSearchRetriever:
            query = query.where(Document.search_space_id == search_space_id)

        # Add text search ranking
-        query = (
-            query
-            .order_by(func.ts_rank_cd(tsvector, tsquery).desc())
-            .limit(top_k)
-        )
+        query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)

        # Execute the query
        result = await self.db_session.execute(query)

@@ -102,7 +112,14 @@ class DocumentHybridSearchRetriever:
        return documents
-    async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
+    async def hybrid_search(
+        self,
+        query_text: str,
+        top_k: int,
+        user_id: str,
+        search_space_id: int | None = None,
+        document_type: str | None = None,
+    ) -> list:
        """
        Combine vector similarity and full-text search results using Reciprocal Rank Fusion.

@@ -114,10 +131,11 @@ class DocumentHybridSearchRetriever:
            document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
        """
-        from sqlalchemy import select, func, text
+        from sqlalchemy import func, select, text
        from sqlalchemy.orm import joinedload
-        from app.db import Document, SearchSpace, DocumentType
        from app.config import config
+        from app.db import Document, DocumentType, SearchSpace

        # Get embedding for the query
        embedding_model = config.embedding_model_instance

@@ -128,8 +146,8 @@ class DocumentHybridSearchRetriever:
        n_results = top_k * 2  # Get more results for better fusion

        # Create tsvector and tsquery for PostgreSQL full-text search
-        tsvector = func.to_tsvector('english', Document.content)
-        tsquery = func.plainto_tsquery('english', query_text)
+        tsvector = func.to_tsvector("english", Document.content)
+        tsquery = func.plainto_tsquery("english", query_text)

        # Base conditions for document filtering
        base_conditions = [SearchSpace.user_id == user_id]

@@ -155,15 +173,16 @@ class DocumentHybridSearchRetriever:
        semantic_search_cte = (
            select(
                Document.id,
-                func.rank().over(order_by=Document.embedding.op("<=>")(query_embedding)).label("rank")
+                func.rank()
+                .over(order_by=Document.embedding.op("<=>")(query_embedding))
+                .label("rank"),
            )
            .join(SearchSpace, Document.search_space_id == SearchSpace.id)
            .where(*base_conditions)
        )

        semantic_search_cte = (
-            semantic_search_cte
-            .order_by(Document.embedding.op("<=>")(query_embedding))
+            semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding))
            .limit(n_results)
            .cte("semantic_search")
        )

@@ -172,7 +191,9 @@ class DocumentHybridSearchRetriever:
        keyword_search_cte = (
            select(
                Document.id,
-                func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
+                func.rank()
+                .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
+                .label("rank"),
            )
            .join(SearchSpace, Document.search_space_id == SearchSpace.id)
            .where(*base_conditions)

@@ -180,8 +201,7 @@ class DocumentHybridSearchRetriever:
        )

        keyword_search_cte = (
-            keyword_search_cte
-            .order_by(func.ts_rank_cd(tsvector, tsquery).desc())
+            keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
            .limit(n_results)
            .cte("keyword_search")
        )

@@ -191,20 +211,21 @@ class DocumentHybridSearchRetriever:
            select(
                Document,
                (
-                    func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
-                    func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
-                ).label("score")
+                    func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
+                    + func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
+                ).label("score"),
            )
            .select_from(
                semantic_search_cte.outerjoin(
                    keyword_search_cte,
                    semantic_search_cte.c.id == keyword_search_cte.c.id,
-                    full=True
+                    full=True,
                )
            )
            .join(
                Document,
-                Document.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
+                Document.id
+                == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
            )
            .options(joinedload(Document.search_space))
            .order_by(text("score DESC"))

@@ -224,24 +245,35 @@ class DocumentHybridSearchRetriever:
        for document, score in documents_with_scores:
            # Fetch associated chunks for this document
            from sqlalchemy import select

            from app.db import Chunk

-            chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
+            chunks_query = (
+                select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
+            )
            chunks_result = await self.db_session.execute(chunks_query)
            chunks = chunks_result.scalars().all()

            # Concatenate chunks content
-            concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
+            concatenated_chunks_content = (
+                " ".join([chunk.content for chunk in chunks])
+                if chunks
+                else document.content
+            )

-            serialized_results.append({
-                "document_id": document.id,
-                "title": document.title,
-                "content": document.content,
-                "chunks_content": concatenated_chunks_content,
-                "document_type": document.document_type.value if hasattr(document, 'document_type') else None,
-                "metadata": document.document_metadata,
-                "score": float(score),  # Ensure score is a Python float
-                "search_space_id": document.search_space_id
-            })
+            serialized_results.append(
+                {
+                    "document_id": document.id,
+                    "title": document.title,
+                    "content": document.content,
+                    "chunks_content": concatenated_chunks_content,
+                    "document_type": document.document_type.value
+                    if hasattr(document, "document_type")
+                    else None,
+                    "metadata": document.document_metadata,
+                    "score": float(score),  # Ensure score is a Python float
+                    "search_space_id": document.search_space_id,
+                }
+            )

        return serialized_results
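A fix repeated across both retrievers: parameters defaulting to None were annotated as bare int/str, and the lint pass rewrites them in the explicit PEP 604 optional form (int | None = None). A tiny before/after sketch (the function name is illustrative, not from this commit):

# Before: the annotation claims int, but the default is None (implicit Optional).
def get_page(page: int = None):
    ...

# After: the optional is spelled out, so type checkers accept the None default.
def get_page_fixed(page: int | None = None) -> int:
    return 1 if page is None else page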
@@ -1,11 +1,12 @@
from fastapi import APIRouter
-from .search_spaces_routes import router as search_spaces_router
-from .documents_routes import router as documents_router
-from .podcasts_routes import router as podcasts_router
from .chats_routes import router as chats_router
-from .search_source_connectors_routes import router as search_source_connectors_router
+from .documents_routes import router as documents_router
from .llm_config_routes import router as llm_config_router
from .logs_routes import router as logs_router
+from .podcasts_routes import router as podcasts_router
+from .search_source_connectors_routes import router as search_source_connectors_router
+from .search_spaces_routes import router as search_spaces_router

router = APIRouter()
@@ -1,38 +1,40 @@
-from typing import List
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi.responses import StreamingResponse
+from langchain.schema import AIMessage, HumanMessage
+from sqlalchemy.exc import IntegrityError, OperationalError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select

from app.db import Chat, SearchSpace, User, get_async_session
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
from app.tasks.stream_connector_search_results import stream_connector_search_results
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
-from fastapi import APIRouter, Depends, HTTPException
-from fastapi.responses import StreamingResponse
-from sqlalchemy.exc import IntegrityError, OperationalError
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from langchain.schema import HumanMessage, AIMessage

router = APIRouter()


@router.post("/chat")
async def handle_chat_data(
    request: AISDKChatRequest,
    session: AsyncSession = Depends(get_async_session),
-    user: User = Depends(current_active_user)
+    user: User = Depends(current_active_user),
):
    messages = request.messages
-    if messages[-1]['role'] != "user":
+    if messages[-1]["role"] != "user":
        raise HTTPException(
-            status_code=400, detail="Last message must be a user message")
+            status_code=400, detail="Last message must be a user message"
+        )

-    user_query = messages[-1]['content']
-    search_space_id = request.data.get('search_space_id')
-    research_mode: str = request.data.get('research_mode')
-    selected_connectors: List[str] = request.data.get('selected_connectors')
-    document_ids_to_add_in_context: List[int] = request.data.get('document_ids_to_add_in_context')
-    search_mode_str = request.data.get('search_mode', "CHUNKS")
+    user_query = messages[-1]["content"]
+    search_space_id = request.data.get("search_space_id")
+    research_mode: str = request.data.get("research_mode")
+    selected_connectors: list[str] = request.data.get("selected_connectors")
+    document_ids_to_add_in_context: list[int] = request.data.get(
+        "document_ids_to_add_in_context"
+    )
+    search_mode_str = request.data.get("search_mode", "CHUNKS")

    # Convert search_space_id to integer if it's a string
    if search_space_id and isinstance(search_space_id, str):

@@ -40,21 +42,23 @@ async def handle_chat_data(
            search_space_id = int(search_space_id)
        except ValueError:
            raise HTTPException(
-                status_code=400, detail="Invalid search_space_id format")
+                status_code=400, detail="Invalid search_space_id format"
+            ) from None

    # Check if the search space belongs to the current user
    try:
        await check_ownership(session, SearchSpace, search_space_id, user)
    except HTTPException:
        raise HTTPException(
-            status_code=403, detail="You don't have access to this search space")
+            status_code=403, detail="You don't have access to this search space"
+        ) from None

    langchain_chat_history = []
    for message in messages[:-1]:
-        if message['role'] == "user":
-            langchain_chat_history.append(HumanMessage(content=message['content']))
-        elif message['role'] == "assistant":
-            langchain_chat_history.append(AIMessage(content=message['content']))
+        if message["role"] == "user":
+            langchain_chat_history.append(HumanMessage(content=message["content"]))
+        elif message["role"] == "assistant":
+            langchain_chat_history.append(AIMessage(content=message["content"]))

    response = StreamingResponse(
        stream_connector_search_results(

@@ -78,7 +82,7 @@ async def handle_chat_data(
async def create_chat(
    chat: ChatCreate,
    session: AsyncSession = Depends(get_async_session),
-    user: User = Depends(current_active_user)
+    user: User = Depends(current_active_user),
):
    try:
        await check_ownership(session, SearchSpace, chat.search_space_id, user)
@@ -89,27 +93,32 @@ async def create_chat(
        return db_chat
    except HTTPException:
        raise
-    except IntegrityError as e:
+    except IntegrityError:
        await session.rollback()
        raise HTTPException(
-            status_code=400, detail="Database constraint violation. Please check your input data.")
-    except OperationalError as e:
+            status_code=400,
+            detail="Database constraint violation. Please check your input data.",
+        ) from None
+    except OperationalError:
        await session.rollback()
        raise HTTPException(
-            status_code=503, detail="Database operation failed. Please try again later.")
-    except Exception as e:
+            status_code=503, detail="Database operation failed. Please try again later."
+        ) from None
+    except Exception:
        await session.rollback()
        raise HTTPException(
-            status_code=500, detail="An unexpected error occurred while creating the chat.")
+            status_code=500,
+            detail="An unexpected error occurred while creating the chat.",
+        ) from None


-@router.get("/chats/", response_model=List[ChatRead])
+@router.get("/chats/", response_model=list[ChatRead])
async def read_chats(
    skip: int = 0,
    limit: int = 100,
-    search_space_id: int = None,
+    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
-    user: User = Depends(current_active_user)
+    user: User = Depends(current_active_user),
):
    try:
        query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id)

@@ -118,23 +127,23 @@ async def read_chats(
        if search_space_id is not None:
            query = query.filter(Chat.search_space_id == search_space_id)

-        result = await session.execute(
-            query.offset(skip).limit(limit)
-        )
+        result = await session.execute(query.offset(skip).limit(limit))
        return result.scalars().all()
    except OperationalError:
        raise HTTPException(
-            status_code=503, detail="Database operation failed. Please try again later.")
+            status_code=503, detail="Database operation failed. Please try again later."
+        ) from None
    except Exception:
        raise HTTPException(
-            status_code=500, detail="An unexpected error occurred while fetching chats.")
+            status_code=500, detail="An unexpected error occurred while fetching chats."
+        ) from None


@router.get("/chats/{chat_id}", response_model=ChatRead)
async def read_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
-    user: User = Depends(current_active_user)
+    user: User = Depends(current_active_user),
):
    try:
        result = await session.execute(

@@ -145,14 +154,19 @@ async def read_chat(
        chat = result.scalars().first()
        if not chat:
            raise HTTPException(
-                status_code=404, detail="Chat not found or you don't have permission to access it")
+                status_code=404,
+                detail="Chat not found or you don't have permission to access it",
+            )
        return chat
    except OperationalError:
        raise HTTPException(
-            status_code=503, detail="Database operation failed. Please try again later.")
+            status_code=503, detail="Database operation failed. Please try again later."
+        ) from None
    except Exception:
        raise HTTPException(
-            status_code=500, detail="An unexpected error occurred while fetching the chat.")
+            status_code=500,
+            detail="An unexpected error occurred while fetching the chat.",
+        ) from None
@router.put("/chats/{chat_id}", response_model=ChatRead) @router.put("/chats/{chat_id}", response_model=ChatRead)
@ -160,7 +174,7 @@ async def update_chat(
chat_id: int, chat_id: int,
chat_update: ChatUpdate, chat_update: ChatUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_chat = await read_chat(chat_id, session, user) db_chat = await read_chat(chat_id, session, user)
@ -175,22 +189,27 @@ async def update_chat(
except IntegrityError: except IntegrityError:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=400, detail="Database constraint violation. Please check your input data.") status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
except OperationalError: except OperationalError:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.") status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception: except Exception:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, detail="An unexpected error occurred while updating the chat.") status_code=500,
detail="An unexpected error occurred while updating the chat.",
) from None
@router.delete("/chats/{chat_id}", response_model=dict) @router.delete("/chats/{chat_id}", response_model=dict)
async def delete_chat( async def delete_chat(
chat_id: int, chat_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_chat = await read_chat(chat_id, session, user) db_chat = await read_chat(chat_id, session, user)
@ -202,81 +221,16 @@ async def delete_chat(
except IntegrityError: except IntegrityError:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=400, detail="Cannot delete chat due to existing dependencies.") status_code=400, detail="Cannot delete chat due to existing dependencies."
) from None
except OperationalError: except OperationalError:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=503, detail="Database operation failed. Please try again later.") status_code=503, detail="Database operation failed. Please try again later."
) from None
except Exception: except Exception:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, detail="An unexpected error occurred while deleting the chat.") status_code=500,
detail="An unexpected error occurred while deleting the chat.",
) from None
-# test_data = [
-#     {
-#         "type": "TERMINAL_INFO",
-#         "content": [
-#             {
-#                 "id": 1,
-#                 "text": "Starting to search for crawled URLs...",
-#                 "type": "info"
-#             },
-#             {
-#                 "id": 2,
-#                 "text": "Found 2 relevant crawled URLs",
-#                 "type": "success"
-#             }
-#         ]
-#     },
-#     {
-#         "type": "SOURCES",
-#         "content": [
-#             {
-#                 "id": 1,
-#                 "name": "Crawled URLs",
-#                 "type": "CRAWLED_URL",
-#                 "sources": [
-#                     {
-#                         "id": 1,
-#                         "title": "Webpage Title",
-#                         "description": "Webpage Dec",
-#                         "url": "https://jsoneditoronline.org/"
-#                     },
-#                     {
-#                         "id": 2,
-#                         "title": "Webpage Title",
-#                         "description": "Webpage Dec",
-#                         "url": "https://www.google.com/"
-#                     }
-#                 ]
-#             },
-#             {
-#                 "id": 2,
-#                 "name": "Files",
-#                 "type": "FILE",
-#                 "sources": [
-#                     {
-#                         "id": 3,
-#                         "title": "Webpage Title",
-#                         "description": "Webpage Dec",
-#                         "url": "https://jsoneditoronline.org/"
-#                     },
-#                     {
-#                         "id": 4,
-#                         "title": "Webpage Title",
-#                         "description": "Webpage Dec",
-#                         "url": "https://www.google.com/"
-#                     }
-#                 ]
-#             }
-#         ]
-#     },
-#     {
-#         "type": "ANSWER",
-#         "content": [
-#             "## SurfSense Introduction",
-#             "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
-#         ]
-#     }
-# ]
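Most re-raised HTTPExceptions in this file now end with "from None", which suppresses Python's implicit exception chaining: the client-facing error is raised without the internal cause attached to its traceback (the raise-from pattern flake8-bugbear's B904 rule asks for inside except blocks). A minimal sketch of the difference, with illustrative names:

class APIError(Exception):
    pass

def parse_id(raw: str) -> int:
    try:
        return int(raw)
    except ValueError:
        # A plain "raise APIError(...)" here would print the original ValueError
        # under "During handling of the above exception, another exception
        # occurred"; "from None" drops that chained cause from the traceback.
        raise APIError("Invalid id format") from None

parse_id("42")     # -> 42
# parse_id("abc")  # raises APIError with no chained ValueError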
@@ -1,23 +1,35 @@
-from litellm import atranscription
-from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from typing import List
-from app.db import Log, get_async_session, User, SearchSpace, Document, DocumentType
-from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
-from app.users import current_active_user
-from app.utils.check_ownership import check_ownership
-from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document_using_unstructured, add_crawled_url_document, add_youtube_video_document, add_received_file_document_using_llamacloud, add_received_file_document_using_docling
-from app.config import config as app_config
# Force asyncio to use standard event loop before unstructured imports
import asyncio

+from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, UploadFile
+from litellm import atranscription
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.config import config as app_config
+from app.db import Document, DocumentType, Log, SearchSpace, User, get_async_session
+from app.schemas import DocumentRead, DocumentsCreate, DocumentUpdate
from app.services.task_logging_service import TaskLoggingService
+from app.tasks.background_tasks import (
+    add_crawled_url_document,
+    add_extension_received_document,
+    add_received_file_document_using_docling,
+    add_received_file_document_using_llamacloud,
+    add_received_file_document_using_unstructured,
+    add_received_markdown_file_document,
+    add_youtube_video_document,
+)
+from app.users import current_active_user
+from app.utils.check_ownership import check_ownership

try:
    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
-except RuntimeError:
+except RuntimeError as e:
+    print("Error setting event loop policy", e)
    pass

import os

os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
@@ -29,7 +41,7 @@ async def create_documents(
    request: DocumentsCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
-    fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
+    fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
):
    try:
        # Check if the user owns the search space

@@ -41,7 +53,7 @@ async def create_documents(
                    process_extension_document_with_new_session,
                    individual_document,
                    request.search_space_id,
-                    str(user.id)
+                    str(user.id),
                )
        elif request.document_type == DocumentType.CRAWLED_URL:
            for url in request.content:

@@ -49,7 +61,7 @@ async def create_documents(
                    process_crawled_url_with_new_session,
                    url,
                    request.search_space_id,
-                    str(user.id)
+                    str(user.id),
                )
        elif request.document_type == DocumentType.YOUTUBE_VIDEO:
            for url in request.content:

@@ -57,13 +69,10 @@ async def create_documents(
                    process_youtube_video_with_new_session,
                    url,
                    request.search_space_id,
-                    str(user.id)
+                    str(user.id),
                )
        else:
-            raise HTTPException(
-                status_code=400,
-                detail="Invalid document type"
-            )
+            raise HTTPException(status_code=400, detail="Invalid document type")

        await session.commit()
        return {"message": "Documents processed successfully"}

@@ -72,18 +81,17 @@ async def create_documents(
    except Exception as e:
        await session.rollback()
        raise HTTPException(
-            status_code=500,
-            detail=f"Failed to process documents: {str(e)}"
-        )
+            status_code=500, detail=f"Failed to process documents: {e!s}"
+        ) from e


@router.post("/documents/fileupload")
-async def create_documents(
+async def create_documents_file_upload(
    files: list[UploadFile],
    search_space_id: int = Form(...),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
-    fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
+    fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
):
    try:
        await check_ownership(session, SearchSpace, search_space_id, user)

@@ -94,12 +102,13 @@ async def create_documents(
        for file in files:
            try:
                # Save file to a temporary location to avoid stream issues
-                import tempfile
-                import aiofiles
                import os
+                import tempfile

                # Create temp file
-                with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
+                with tempfile.NamedTemporaryFile(
+                    delete=False, suffix=os.path.splitext(file.filename)[1]
+                ) as temp_file:
                    temp_path = temp_file.name

                    # Write uploaded file to temp file

@@ -112,13 +121,13 @@ async def create_documents(
                    temp_path,
                    file.filename,
                    search_space_id,
-                    str(user.id)
+                    str(user.id),
                )
            except Exception as e:
                raise HTTPException(
                    status_code=422,
-                    detail=f"Failed to process file {file.filename}: {str(e)}"
-                )
+                    detail=f"Failed to process file {file.filename}: {e!s}",
+                ) from e

        await session.commit()
        return {"message": "Files uploaded for processing"}

@@ -127,9 +136,8 @@ async def create_documents(
    except Exception as e:
        await session.rollback()
        raise HTTPException(
-            status_code=500,
-            detail=f"Failed to upload files: {str(e)}"
-        )
+            status_code=500, detail=f"Failed to upload files: {e!s}"
+        ) from e
async def process_file_in_background(

@@ -139,62 +147,69 @@ async def process_file_in_background(
    user_id: str,
    session: AsyncSession,
    task_logger: TaskLoggingService,
-    log_entry: Log
+    log_entry: Log,
):
    try:
        # Check if the file is a markdown or text file
-        if filename.lower().endswith(('.md', '.markdown', '.txt')):
+        if filename.lower().endswith((".md", ".markdown", ".txt")):
            await task_logger.log_task_progress(
                log_entry,
                f"Processing markdown/text file: {filename}",
-                {"file_type": "markdown", "processing_stage": "reading_file"}
+                {"file_type": "markdown", "processing_stage": "reading_file"},
            )

            # For markdown files, read the content directly
-            with open(file_path, 'r', encoding='utf-8') as f:
+            with open(file_path, encoding="utf-8") as f:
                markdown_content = f.read()

            # Clean up the temp file
            import os
            try:
                os.unlink(file_path)
-            except:
+            except Exception as e:
+                print("Error deleting temp file", e)
                pass

            await task_logger.log_task_progress(
                log_entry,
                f"Creating document from markdown content: {filename}",
-                {"processing_stage": "creating_document", "content_length": len(markdown_content)}
+                {
+                    "processing_stage": "creating_document",
+                    "content_length": len(markdown_content),
+                },
            )

            # Process markdown directly through specialized function
            result = await add_received_markdown_file_document(
-                session,
-                filename,
-                markdown_content,
-                search_space_id,
-                user_id
+                session, filename, markdown_content, search_space_id, user_id
            )

            if result:
                await task_logger.log_task_success(
                    log_entry,
                    f"Successfully processed markdown file: {filename}",
-                    {"document_id": result.id, "content_hash": result.content_hash, "file_type": "markdown"}
+                    {
+                        "document_id": result.id,
+                        "content_hash": result.content_hash,
+                        "file_type": "markdown",
+                    },
                )
            else:
                await task_logger.log_task_success(
                    log_entry,
                    f"Markdown file already exists (duplicate): {filename}",
-                    {"duplicate_detected": True, "file_type": "markdown"}
+                    {"duplicate_detected": True, "file_type": "markdown"},
                )

        # Check if the file is an audio file
-        elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
+        elif filename.lower().endswith(
+            (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
+        ):
            await task_logger.log_task_progress(
                log_entry,
                f"Processing audio file for transcription: {filename}",
-                {"file_type": "audio", "processing_stage": "starting_transcription"}
+                {"file_type": "audio", "processing_stage": "starting_transcription"},
            )

            # Open the audio file for transcription
@@ -205,53 +220,60 @@ async def process_file_in_background(
                    model=app_config.STT_SERVICE,
                    file=audio_file,
                    api_base=app_config.STT_SERVICE_API_BASE,
-                    api_key=app_config.STT_SERVICE_API_KEY
+                    api_key=app_config.STT_SERVICE_API_KEY,
                )
            else:
                transcription_response = await atranscription(
                    model=app_config.STT_SERVICE,
                    api_key=app_config.STT_SERVICE_API_KEY,
-                    file=audio_file
+                    file=audio_file,
                )

            # Extract the transcribed text
            transcribed_text = transcription_response.get("text", "")

            # Add metadata about the transcription
-            transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}"
+            transcribed_text = (
+                f"# Transcription of {filename}\n\n{transcribed_text}"
+            )

            await task_logger.log_task_progress(
                log_entry,
                f"Transcription completed, creating document: {filename}",
-                {"processing_stage": "transcription_complete", "transcript_length": len(transcribed_text)}
+                {
+                    "processing_stage": "transcription_complete",
+                    "transcript_length": len(transcribed_text),
+                },
            )

            # Clean up the temp file
            try:
                os.unlink(file_path)
-            except:
+            except Exception as e:
+                print("Error deleting temp file", e)
                pass

            # Process transcription as markdown document
            result = await add_received_markdown_file_document(
-                session,
-                filename,
-                transcribed_text,
-                search_space_id,
-                user_id
+                session, filename, transcribed_text, search_space_id, user_id
            )

            if result:
                await task_logger.log_task_success(
                    log_entry,
                    f"Successfully transcribed and processed audio file: {filename}",
-                    {"document_id": result.id, "content_hash": result.content_hash, "file_type": "audio", "transcript_length": len(transcribed_text)}
+                    {
+                        "document_id": result.id,
+                        "content_hash": result.content_hash,
+                        "file_type": "audio",
+                        "transcript_length": len(transcribed_text),
+                    },
                )
            else:
                await task_logger.log_task_success(
                    log_entry,
                    f"Audio file transcript already exists (duplicate): {filename}",
-                    {"duplicate_detected": True, "file_type": "audio"}
+                    {"duplicate_detected": True, "file_type": "audio"},
                )

        else:
@@ -259,7 +281,11 @@ async def process_file_in_background(
            await task_logger.log_task_progress(
                log_entry,
                f"Processing file with Unstructured ETL: {filename}",
-                {"file_type": "document", "etl_service": "UNSTRUCTURED", "processing_stage": "loading"}
+                {
+                    "file_type": "document",
+                    "etl_service": "UNSTRUCTURED",
+                    "processing_stage": "loading",
+                },
            )

            from langchain_unstructured import UnstructuredLoader

@@ -280,56 +306,66 @@ async def process_file_in_background(
            await task_logger.log_task_progress(
                log_entry,
                f"Unstructured ETL completed, creating document: {filename}",
-                {"processing_stage": "etl_complete", "elements_count": len(docs)}
+                {"processing_stage": "etl_complete", "elements_count": len(docs)},
            )

            # Clean up the temp file
            import os
            try:
                os.unlink(file_path)
-            except:
+            except Exception as e:
+                print("Error deleting temp file", e)
                pass

            # Pass the documents to the existing background task
            result = await add_received_file_document_using_unstructured(
-                session,
-                filename,
-                docs,
-                search_space_id,
-                user_id
+                session, filename, docs, search_space_id, user_id
            )

            if result:
                await task_logger.log_task_success(
                    log_entry,
                    f"Successfully processed file with Unstructured: {filename}",
-                    {"document_id": result.id, "content_hash": result.content_hash, "file_type": "document", "etl_service": "UNSTRUCTURED"}
+                    {
+                        "document_id": result.id,
+                        "content_hash": result.content_hash,
+                        "file_type": "document",
+                        "etl_service": "UNSTRUCTURED",
+                    },
                )
            else:
                await task_logger.log_task_success(
                    log_entry,
                    f"Document already exists (duplicate): {filename}",
-                    {"duplicate_detected": True, "file_type": "document", "etl_service": "UNSTRUCTURED"}
+                    {
+                        "duplicate_detected": True,
+                        "file_type": "document",
+                        "etl_service": "UNSTRUCTURED",
+                    },
                )

        elif app_config.ETL_SERVICE == "LLAMACLOUD":
            await task_logger.log_task_progress(
                log_entry,
                f"Processing file with LlamaCloud ETL: {filename}",
-                {"file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing"}
+                {
+                    "file_type": "document",
+                    "etl_service": "LLAMACLOUD",
+                    "processing_stage": "parsing",
+                },
            )

            from llama_cloud_services import LlamaParse
            from llama_cloud_services.parse.utils import ResultType

            # Create LlamaParse parser instance
            parser = LlamaParse(
                api_key=app_config.LLAMA_CLOUD_API_KEY,
                num_workers=1,  # Use single worker for file processing
                verbose=True,
                language="en",
-                result_type=ResultType.MD
+                result_type=ResultType.MD,
            )

            # Parse the file asynchronously
@@ -337,18 +373,25 @@ async def process_file_in_background(
            # Clean up the temp file
            import os
            try:
                os.unlink(file_path)
-            except:
+            except Exception as e:
+                print("Error deleting temp file", e)
                pass

            # Get markdown documents from the result
-            markdown_documents = await result.aget_markdown_documents(split_by_page=False)
+            markdown_documents = await result.aget_markdown_documents(
+                split_by_page=False
+            )

            await task_logger.log_task_progress(
                log_entry,
                f"LlamaCloud parsing completed, creating documents: {filename}",
-                {"processing_stage": "parsing_complete", "documents_count": len(markdown_documents)}
+                {
+                    "processing_stage": "parsing_complete",
+                    "documents_count": len(markdown_documents),
+                },
            )

            for doc in markdown_documents:

@@ -361,27 +404,40 @@ async def process_file_in_background(
                    filename,
                    llamacloud_markdown_document=markdown_content,
                    search_space_id=search_space_id,
-                    user_id=user_id
+                    user_id=user_id,
                )

                if doc_result:
                    await task_logger.log_task_success(
                        log_entry,
                        f"Successfully processed file with LlamaCloud: {filename}",
-                        {"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "LLAMACLOUD"}
+                        {
+                            "document_id": doc_result.id,
+                            "content_hash": doc_result.content_hash,
+                            "file_type": "document",
+                            "etl_service": "LLAMACLOUD",
+                        },
                    )
                else:
                    await task_logger.log_task_success(
                        log_entry,
                        f"Document already exists (duplicate): {filename}",
-                        {"duplicate_detected": True, "file_type": "document", "etl_service": "LLAMACLOUD"}
+                        {
+                            "duplicate_detected": True,
+                            "file_type": "document",
+                            "etl_service": "LLAMACLOUD",
+                        },
                    )

        elif app_config.ETL_SERVICE == "DOCLING":
            await task_logger.log_task_progress(
                log_entry,
                f"Processing file with Docling ETL: {filename}",
-                {"file_type": "document", "etl_service": "DOCLING", "processing_stage": "parsing"}
+                {
+                    "file_type": "document",
+                    "etl_service": "DOCLING",
+                    "processing_stage": "parsing",
+                },
            )

            # Use Docling service for document processing
@@ -395,97 +451,112 @@ async def process_file_in_background(
            # Clean up the temp file
            import os
            try:
                os.unlink(file_path)
-            except:
+            except Exception as e:
+                print("Error deleting temp file", e)
                pass

            await task_logger.log_task_progress(
                log_entry,
                f"Docling parsing completed, creating document: {filename}",
-                {"processing_stage": "parsing_complete", "content_length": len(result['content'])}
+                {
+                    "processing_stage": "parsing_complete",
+                    "content_length": len(result["content"]),
+                },
            )

            # Process the document using our Docling background task
            doc_result = await add_received_file_document_using_docling(
                session,
                filename,
-                docling_markdown_document=result['content'],
+                docling_markdown_document=result["content"],
                search_space_id=search_space_id,
-                user_id=user_id
+                user_id=user_id,
            )

            if doc_result:
                await task_logger.log_task_success(
                    log_entry,
                    f"Successfully processed file with Docling: {filename}",
-                    {"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "DOCLING"}
+                    {
+                        "document_id": doc_result.id,
+                        "content_hash": doc_result.content_hash,
+                        "file_type": "document",
+                        "etl_service": "DOCLING",
+                    },
                )
            else:
                await task_logger.log_task_success(
                    log_entry,
                    f"Document already exists (duplicate): {filename}",
-                    {"duplicate_detected": True, "file_type": "document", "etl_service": "DOCLING"}
+                    {
+                        "duplicate_detected": True,
+                        "file_type": "document",
+                        "etl_service": "DOCLING",
+                    },
                )

    except Exception as e:
        await task_logger.log_task_failure(
            log_entry,
            f"Failed to process file: {filename}",
            str(e),
-            {"error_type": type(e).__name__, "filename": filename}
+            {"error_type": type(e).__name__, "filename": filename},
        )
        import logging
-        logging.error(f"Error processing file in background: {str(e)}")
+        logging.error(f"Error processing file in background: {e!s}")
        raise  # Re-raise so the wrapper can also handle it
@router.get("/documents/", response_model=List[DocumentRead]) @router.get("/documents/", response_model=list[DocumentRead])
async def read_documents( async def read_documents(
skip: int = 0, skip: int = 0,
limit: int = 300, limit: int = 300,
search_space_id: int = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
query = select(Document).join(SearchSpace).filter( query = (
SearchSpace.user_id == user.id) select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id)
)
# Filter by search_space_id if provided # Filter by search_space_id if provided
if search_space_id is not None: if search_space_id is not None:
query = query.filter(Document.search_space_id == search_space_id) query = query.filter(Document.search_space_id == search_space_id)
result = await session.execute( result = await session.execute(query.offset(skip).limit(limit))
query.offset(skip).limit(limit)
)
db_documents = result.scalars().all() db_documents = result.scalars().all()
# Convert database objects to API-friendly format # Convert database objects to API-friendly format
api_documents = [] api_documents = []
for doc in db_documents: for doc in db_documents:
api_documents.append(DocumentRead( api_documents.append(
id=doc.id, DocumentRead(
title=doc.title, id=doc.id,
document_type=doc.document_type, title=doc.title,
document_metadata=doc.document_metadata, document_type=doc.document_type,
content=doc.content, document_metadata=doc.document_metadata,
created_at=doc.created_at, content=doc.content,
search_space_id=doc.search_space_id created_at=doc.created_at,
)) search_space_id=doc.search_space_id,
)
)
return api_documents return api_documents
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch documents: {e!s}"
detail=f"Failed to fetch documents: {str(e)}" ) from e
)
@router.get("/documents/{document_id}", response_model=DocumentRead) @router.get("/documents/{document_id}", response_model=DocumentRead)
async def read_document( async def read_document(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
result = await session.execute( result = await session.execute(
@ -497,8 +568,7 @@ async def read_document(
if not document: if not document:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Document with id {document_id} not found"
detail=f"Document with id {document_id} not found"
) )
# Convert database object to API-friendly format # Convert database object to API-friendly format
@ -509,13 +579,12 @@ async def read_document(
document_metadata=document.document_metadata, document_metadata=document.document_metadata,
content=document.content, content=document.content,
created_at=document.created_at, created_at=document.created_at,
search_space_id=document.search_space_id search_space_id=document.search_space_id,
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch document: {e!s}"
detail=f"Failed to fetch document: {str(e)}" ) from e
)
@router.put("/documents/{document_id}", response_model=DocumentRead) @router.put("/documents/{document_id}", response_model=DocumentRead)
@ -523,7 +592,7 @@ async def update_document(
document_id: int, document_id: int,
document_update: DocumentUpdate, document_update: DocumentUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
# Query the document directly instead of using read_document function # Query the document directly instead of using read_document function
@ -536,8 +605,7 @@ async def update_document(
if not db_document: if not db_document:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Document with id {document_id} not found"
detail=f"Document with id {document_id} not found"
) )
update_data = document_update.model_dump(exclude_unset=True) update_data = document_update.model_dump(exclude_unset=True)
@ -554,23 +622,22 @@ async def update_document(
document_metadata=db_document.document_metadata, document_metadata=db_document.document_metadata,
content=db_document.content, content=db_document.content,
created_at=db_document.created_at, created_at=db_document.created_at,
search_space_id=db_document.search_space_id search_space_id=db_document.search_space_id,
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to update document: {e!s}"
detail=f"Failed to update document: {str(e)}" ) from e
)
@router.delete("/documents/{document_id}", response_model=dict) @router.delete("/documents/{document_id}", response_model=dict)
async def delete_document( async def delete_document(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
# Query the document directly instead of using read_document function # Query the document directly instead of using read_document function
@ -583,8 +650,7 @@ async def delete_document(
if not document: if not document:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Document with id {document_id} not found"
detail=f"Document with id {document_id} not found"
) )
await session.delete(document) await session.delete(document)
@ -595,15 +661,12 @@ async def delete_document(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to delete document: {e!s}"
detail=f"Failed to delete document: {str(e)}" ) from e
)
async def process_extension_document_with_new_session(
-    individual_document,
-    search_space_id: int,
-    user_id: str
+    individual_document, search_space_id: int, user_id: str
):
    """Create a new session and process extension document."""
    from app.db import async_session_maker

@@ -622,40 +685,41 @@ async def process_extension_document_with_new_session(
            "document_type": "EXTENSION",
            "url": individual_document.metadata.VisitedWebPageURL,
            "title": individual_document.metadata.VisitedWebPageTitle,
-            "user_id": user_id
-        }
+            "user_id": user_id,
+        },
    )

    try:
-        result = await add_extension_received_document(session, individual_document, search_space_id, user_id)
+        result = await add_extension_received_document(
+            session, individual_document, search_space_id, user_id
+        )

        if result:
            await task_logger.log_task_success(
                log_entry,
                f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
-                {"document_id": result.id, "content_hash": result.content_hash}
+                {"document_id": result.id, "content_hash": result.content_hash},
            )
        else:
            await task_logger.log_task_success(
                log_entry,
                f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
-                {"duplicate_detected": True}
+                {"duplicate_detected": True},
            )
    except Exception as e:
        await task_logger.log_task_failure(
            log_entry,
            f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}",
            str(e),
-            {"error_type": type(e).__name__}
+            {"error_type": type(e).__name__},
        )
        import logging
-        logging.error(f"Error processing extension document: {str(e)}")
+        logging.error(f"Error processing extension document: {e!s}")


async def process_crawled_url_with_new_session(
-    url: str,
-    search_space_id: int,
-    user_id: str
+    url: str, search_space_id: int, user_id: str
):
    """Create a new session and process crawled URL."""
    from app.db import async_session_maker
@@ -670,44 +734,44 @@ async def process_crawled_url_with_new_session(
task_name="process_crawled_url", task_name="process_crawled_url",
source="document_processor", source="document_processor",
message=f"Starting URL crawling and processing for: {url}", message=f"Starting URL crawling and processing for: {url}",
metadata={ metadata={"document_type": "CRAWLED_URL", "url": url, "user_id": user_id},
"document_type": "CRAWLED_URL",
"url": url,
"user_id": user_id
}
) )
try: try:
result = await add_crawled_url_document(session, url, search_space_id, user_id) result = await add_crawled_url_document(
session, url, search_space_id, user_id
)
if result: if result:
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully crawled and processed URL: {url}", f"Successfully crawled and processed URL: {url}",
{"document_id": result.id, "title": result.title, "content_hash": result.content_hash} {
"document_id": result.id,
"title": result.title,
"content_hash": result.content_hash,
},
) )
else: else:
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"URL document already exists (duplicate): {url}", f"URL document already exists (duplicate): {url}",
{"duplicate_detected": True} {"duplicate_detected": True},
) )
except Exception as e: except Exception as e:
await task_logger.log_task_failure( await task_logger.log_task_failure(
log_entry, log_entry,
f"Failed to crawl URL: {url}", f"Failed to crawl URL: {url}",
str(e), str(e),
{"error_type": type(e).__name__} {"error_type": type(e).__name__},
) )
import logging import logging
logging.error(f"Error processing crawled URL: {str(e)}")
logging.error(f"Error processing crawled URL: {e!s}")
async def process_file_in_background_with_new_session( async def process_file_in_background_with_new_session(
file_path: str, file_path: str, filename: str, search_space_id: int, user_id: str
filename: str,
search_space_id: int,
user_id: str
): ):
"""Create a new session and process file.""" """Create a new session and process file."""
from app.db import async_session_maker from app.db import async_session_maker
@@ -726,12 +790,20 @@ async def process_file_in_background_with_new_session(
"document_type": "FILE", "document_type": "FILE",
"filename": filename, "filename": filename,
"file_path": file_path, "file_path": file_path,
"user_id": user_id "user_id": user_id,
} },
) )
try: try:
await process_file_in_background(file_path, filename, search_space_id, user_id, session, task_logger, log_entry) await process_file_in_background(
file_path,
filename,
search_space_id,
user_id,
session,
task_logger,
log_entry,
)
# Note: success/failure logging is handled within process_file_in_background # Note: success/failure logging is handled within process_file_in_background
except Exception as e: except Exception as e:
@@ -739,16 +811,15 @@ async def process_file_in_background_with_new_session(
log_entry, log_entry,
f"Failed to process file: {filename}", f"Failed to process file: {filename}",
str(e), str(e),
{"error_type": type(e).__name__} {"error_type": type(e).__name__},
) )
import logging import logging
logging.error(f"Error processing file: {str(e)}")
logging.error(f"Error processing file: {e!s}")
async def process_youtube_video_with_new_session( async def process_youtube_video_with_new_session(
url: str, url: str, search_space_id: int, user_id: str
search_space_id: int,
user_id: str
): ):
"""Create a new session and process YouTube video.""" """Create a new session and process YouTube video."""
from app.db import async_session_maker from app.db import async_session_maker
@@ -763,36 +834,37 @@ async def process_youtube_video_with_new_session(
task_name="process_youtube_video", task_name="process_youtube_video",
source="document_processor", source="document_processor",
message=f"Starting YouTube video processing for: {url}", message=f"Starting YouTube video processing for: {url}",
metadata={ metadata={"document_type": "YOUTUBE_VIDEO", "url": url, "user_id": user_id},
"document_type": "YOUTUBE_VIDEO",
"url": url,
"user_id": user_id
}
) )
try: try:
result = await add_youtube_video_document(session, url, search_space_id, user_id) result = await add_youtube_video_document(
session, url, search_space_id, user_id
)
if result: if result:
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully processed YouTube video: {result.title}", f"Successfully processed YouTube video: {result.title}",
{"document_id": result.id, "video_id": result.document_metadata.get("video_id"), "content_hash": result.content_hash} {
"document_id": result.id,
"video_id": result.document_metadata.get("video_id"),
"content_hash": result.content_hash,
},
) )
else: else:
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"YouTube video document already exists (duplicate): {url}", f"YouTube video document already exists (duplicate): {url}",
{"duplicate_detected": True} {"duplicate_detected": True},
) )
except Exception as e: except Exception as e:
await task_logger.log_task_failure( await task_logger.log_task_failure(
log_entry, log_entry,
f"Failed to process YouTube video: {url}", f"Failed to process YouTube video: {url}",
str(e), str(e),
{"error_type": type(e).__name__} {"error_type": type(e).__name__},
) )
import logging import logging
logging.error(f"Error processing YouTube video: {str(e)}")
logging.error(f"Error processing YouTube video: {e!s}")

View file

@@ -1,35 +1,40 @@
 from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from typing import List, Optional
-from pydantic import BaseModel
-from app.db import get_async_session, User, LLMConfig
-from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
+
+from app.db import LLMConfig, User, get_async_session
+from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership

 router = APIRouter()


 class LLMPreferencesUpdate(BaseModel):
     """Schema for updating user LLM preferences"""
-    long_context_llm_id: Optional[int] = None
-    fast_llm_id: Optional[int] = None
-    strategic_llm_id: Optional[int] = None
+
+    long_context_llm_id: int | None = None
+    fast_llm_id: int | None = None
+    strategic_llm_id: int | None = None


 class LLMPreferencesRead(BaseModel):
     """Schema for reading user LLM preferences"""
-    long_context_llm_id: Optional[int] = None
-    fast_llm_id: Optional[int] = None
-    strategic_llm_id: Optional[int] = None
-    long_context_llm: Optional[LLMConfigRead] = None
-    fast_llm: Optional[LLMConfigRead] = None
-    strategic_llm: Optional[LLMConfigRead] = None
+
+    long_context_llm_id: int | None = None
+    fast_llm_id: int | None = None
+    strategic_llm_id: int | None = None
+    long_context_llm: LLMConfigRead | None = None
+    fast_llm: LLMConfigRead | None = None
+    strategic_llm: LLMConfigRead | None = None
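The rewritten annotations use PEP 604 unions (`int | None`) in place of `typing.Optional[int]`, and this commit's other files likewise swap `typing.List[x]` for `list[x]` (PEP 585). Both spellings are equivalent at runtime on Python 3.10+; a quick Pydantic check:

```python
from pydantic import BaseModel


class Preferences(BaseModel):
    # X | None is the same annotation Optional[X] produces; Pydantic v2
    # treats both identically.
    fast_llm_id: int | None = None
    tags: list[str] = []  # builtin generic instead of typing.List[str]


print(Preferences().model_dump())  # {'fast_llm_id': None, 'tags': []}
```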
@router.post("/llm-configs/", response_model=LLMConfigRead) @router.post("/llm-configs/", response_model=LLMConfigRead)
async def create_llm_config( async def create_llm_config(
llm_config: LLMConfigCreate, llm_config: LLMConfigCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Create a new LLM configuration for the authenticated user""" """Create a new LLM configuration for the authenticated user"""
try: try:
@@ -43,16 +48,16 @@ async def create_llm_config(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
detail=f"Failed to create LLM configuration: {str(e)}" ) from e
)
@router.get("/llm-configs/", response_model=List[LLMConfigRead])
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
async def read_llm_configs( async def read_llm_configs(
skip: int = 0, skip: int = 0,
limit: int = 200, limit: int = 200,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Get all LLM configurations for the authenticated user""" """Get all LLM configurations for the authenticated user"""
try: try:
@@ -65,15 +70,15 @@ async def read_llm_configs(
return result.scalars().all() return result.scalars().all()
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
detail=f"Failed to fetch LLM configurations: {str(e)}" ) from e
)
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead) @router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def read_llm_config( async def read_llm_config(
llm_config_id: int, llm_config_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Get a specific LLM configuration by ID""" """Get a specific LLM configuration by ID"""
try: try:
@@ -83,16 +88,16 @@ async def read_llm_config(
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
detail=f"Failed to fetch LLM configuration: {str(e)}" ) from e
)
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead) @router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def update_llm_config( async def update_llm_config(
llm_config_id: int, llm_config_id: int,
llm_config_update: LLMConfigUpdate, llm_config_update: LLMConfigUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Update an existing LLM configuration""" """Update an existing LLM configuration"""
try: try:
@@ -110,15 +115,15 @@ async def update_llm_config(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
detail=f"Failed to update LLM configuration: {str(e)}" ) from e
)
@router.delete("/llm-configs/{llm_config_id}", response_model=dict) @router.delete("/llm-configs/{llm_config_id}", response_model=dict)
async def delete_llm_config( async def delete_llm_config(
llm_config_id: int, llm_config_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Delete an LLM configuration""" """Delete an LLM configuration"""
try: try:
@@ -131,16 +136,17 @@ async def delete_llm_config(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
detail=f"Failed to delete LLM configuration: {str(e)}" ) from e
)
# User LLM Preferences endpoints # User LLM Preferences endpoints
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead) @router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def get_user_llm_preferences( async def get_user_llm_preferences(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Get the current user's LLM preferences""" """Get the current user's LLM preferences"""
try: try:
@@ -161,7 +167,7 @@ async def get_user_llm_preferences(
long_context_llm = await session.execute( long_context_llm = await session.execute(
select(LLMConfig).filter( select(LLMConfig).filter(
LLMConfig.id == user.long_context_llm_id, LLMConfig.id == user.long_context_llm_id,
LLMConfig.user_id == user.id LLMConfig.user_id == user.id,
) )
) )
llm_config = long_context_llm.scalars().first() llm_config = long_context_llm.scalars().first()
@@ -171,8 +177,7 @@ async def get_user_llm_preferences(
if user.fast_llm_id: if user.fast_llm_id:
fast_llm = await session.execute( fast_llm = await session.execute(
select(LLMConfig).filter( select(LLMConfig).filter(
LLMConfig.id == user.fast_llm_id, LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
LLMConfig.user_id == user.id
) )
) )
llm_config = fast_llm.scalars().first() llm_config = fast_llm.scalars().first()
@@ -182,8 +187,7 @@ async def get_user_llm_preferences(
if user.strategic_llm_id: if user.strategic_llm_id:
strategic_llm = await session.execute( strategic_llm = await session.execute(
select(LLMConfig).filter( select(LLMConfig).filter(
LLMConfig.id == user.strategic_llm_id, LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
LLMConfig.user_id == user.id
) )
) )
llm_config = strategic_llm.scalars().first() llm_config = strategic_llm.scalars().first()
@@ -193,35 +197,34 @@ async def get_user_llm_preferences(
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
detail=f"Failed to fetch LLM preferences: {str(e)}" ) from e
)
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead) @router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def update_user_llm_preferences( async def update_user_llm_preferences(
preferences: LLMPreferencesUpdate, preferences: LLMPreferencesUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Update the current user's LLM preferences""" """Update the current user's LLM preferences"""
try: try:
# Validate that all provided LLM config IDs belong to the user # Validate that all provided LLM config IDs belong to the user
update_data = preferences.model_dump(exclude_unset=True) update_data = preferences.model_dump(exclude_unset=True)
for key, llm_config_id in update_data.items(): for _key, llm_config_id in update_data.items():
if llm_config_id is not None: if llm_config_id is not None:
# Verify ownership of the LLM config # Verify ownership of the LLM config
result = await session.execute( result = await session.execute(
select(LLMConfig).filter( select(LLMConfig).filter(
LLMConfig.id == llm_config_id, LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
LLMConfig.user_id == user.id
) )
) )
llm_config = result.scalars().first() llm_config = result.scalars().first()
if not llm_config: if not llm_config:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it" detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
) )
# Update user preferences # Update user preferences
@@ -238,6 +241,5 @@ async def update_user_llm_preferences(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
detail=f"Failed to update LLM preferences: {str(e)}" ) from e
)

View file

@@ -1,22 +1,23 @@
-from fastapi import APIRouter, Depends, HTTPException, Query
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from sqlalchemy import and_, desc
-from typing import List, Optional
 from datetime import datetime, timedelta
-from app.db import get_async_session, User, SearchSpace, Log, LogLevel, LogStatus
-from app.schemas import LogCreate, LogUpdate, LogRead, LogFilter
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import and_, desc
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session
+from app.schemas import LogCreate, LogRead, LogUpdate
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership

 router = APIRouter()
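All of the reordered headers in this commit follow one grouping: standard library, then third-party, then first-party `app.*`, each group alphabetised and blank-line separated. An illustrative snippet (the first-party line is commented so it runs standalone):

```python
from datetime import datetime, timedelta  # standard library

from fastapi import APIRouter             # third-party
from sqlalchemy import and_, desc

# from app.db import Log                  # first-party group would follow

print(datetime.now(), timedelta(hours=24), APIRouter, and_, desc)
```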
@router.post("/logs/", response_model=LogRead) @router.post("/logs/", response_model=LogRead)
async def create_log( async def create_log(
log: LogCreate, log: LogCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Create a new log entry.""" """Create a new log entry."""
try: try:
@@ -33,22 +34,22 @@ async def create_log(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to create log: {e!s}"
detail=f"Failed to create log: {str(e)}" ) from e
)
@router.get("/logs/", response_model=List[LogRead])
@router.get("/logs/", response_model=list[LogRead])
async def read_logs( async def read_logs(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
search_space_id: Optional[int] = None, search_space_id: int | None = None,
level: Optional[LogLevel] = None, level: LogLevel | None = None,
status: Optional[LogStatus] = None, status: LogStatus | None = None,
source: Optional[str] = None, source: str | None = None,
start_date: Optional[datetime] = None, start_date: datetime | None = None,
end_date: Optional[datetime] = None, end_date: datetime | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Get logs with optional filtering.""" """Get logs with optional filtering."""
try: try:
@@ -93,15 +94,15 @@ async def read_logs(
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch logs: {e!s}"
detail=f"Failed to fetch logs: {str(e)}" ) from e
)
@router.get("/logs/{log_id}", response_model=LogRead) @router.get("/logs/{log_id}", response_model=LogRead)
async def read_log( async def read_log(
log_id: int, log_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Get a specific log by ID.""" """Get a specific log by ID."""
try: try:
@@ -121,16 +122,16 @@ async def read_log(
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch log: {e!s}"
detail=f"Failed to fetch log: {str(e)}" ) from e
)
@router.put("/logs/{log_id}", response_model=LogRead) @router.put("/logs/{log_id}", response_model=LogRead)
async def update_log( async def update_log(
log_id: int, log_id: int,
log_update: LogUpdate, log_update: LogUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Update a log entry.""" """Update a log entry."""
try: try:
@@ -158,15 +159,15 @@ async def update_log(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to update log: {e!s}"
detail=f"Failed to update log: {str(e)}" ) from e
)
@router.delete("/logs/{log_id}") @router.delete("/logs/{log_id}")
async def delete_log( async def delete_log(
log_id: int, log_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Delete a log entry.""" """Delete a log entry."""
try: try:
@@ -189,16 +190,16 @@ async def delete_log(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to delete log: {e!s}"
detail=f"Failed to delete log: {str(e)}" ) from e
)
@router.get("/logs/search-space/{search_space_id}/summary") @router.get("/logs/search-space/{search_space_id}/summary")
async def get_logs_summary( async def get_logs_summary(
search_space_id: int, search_space_id: int,
hours: int = 24, hours: int = 24,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Get a summary of logs for a search space in the last X hours.""" """Get a summary of logs for a search space in the last X hours."""
try: try:
@@ -212,10 +213,7 @@ async def get_logs_summary(
         result = await session.execute(
             select(Log)
             .filter(
-                and_(
-                    Log.search_space_id == search_space_id,
-                    Log.created_at >= since
-                )
+                and_(Log.search_space_id == search_space_id, Log.created_at >= since)
             )
             .order_by(desc(Log.created_at))
         )
@@ -229,14 +227,16 @@ async def get_logs_summary(
"by_level": {}, "by_level": {},
"by_source": {}, "by_source": {},
"active_tasks": [], "active_tasks": [],
"recent_failures": [] "recent_failures": [],
} }
# Count by status and level # Count by status and level
for log in logs: for log in logs:
# Status counts # Status counts
status_str = log.status.value status_str = log.status.value
summary["by_status"][status_str] = summary["by_status"].get(status_str, 0) + 1 summary["by_status"][status_str] = (
summary["by_status"].get(status_str, 0) + 1
)
# Level counts # Level counts
level_str = log.level.value level_str = log.level.value
@@ -244,30 +244,46 @@ async def get_logs_summary(
             # Source counts
             if log.source:
-                summary["by_source"][log.source] = summary["by_source"].get(log.source, 0) + 1
+                summary["by_source"][log.source] = (
+                    summary["by_source"].get(log.source, 0) + 1
+                )

             # Active tasks (IN_PROGRESS)
             if log.status == LogStatus.IN_PROGRESS:
-                task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
-                summary["active_tasks"].append({
-                    "id": log.id,
-                    "task_name": task_name,
-                    "message": log.message,
-                    "started_at": log.created_at,
-                    "source": log.source
-                })
+                task_name = (
+                    log.log_metadata.get("task_name", "Unknown")
+                    if log.log_metadata
+                    else "Unknown"
+                )
+                summary["active_tasks"].append(
+                    {
+                        "id": log.id,
+                        "task_name": task_name,
+                        "message": log.message,
+                        "started_at": log.created_at,
+                        "source": log.source,
+                    }
+                )

             # Recent failures
             if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10:
-                task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
-                summary["recent_failures"].append({
-                    "id": log.id,
-                    "task_name": task_name,
-                    "message": log.message,
-                    "failed_at": log.created_at,
-                    "source": log.source,
-                    "error_details": log.log_metadata.get("error_details") if log.log_metadata else None
-                })
+                task_name = (
+                    log.log_metadata.get("task_name", "Unknown")
+                    if log.log_metadata
+                    else "Unknown"
+                )
+                summary["recent_failures"].append(
+                    {
+                        "id": log.id,
+                        "task_name": task_name,
+                        "message": log.message,
+                        "failed_at": log.created_at,
+                        "source": log.source,
+                        "error_details": log.log_metadata.get("error_details")
+                        if log.log_metadata
+                        else None,
+                    }
+                )
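The counters above keep the `dict.get(key, 0) + 1` idiom, merely rewrapped for line length. For comparison, `collections.Counter` expresses the same aggregation; the status/level tuples below are stand-ins for the ORM rows:

```python
from collections import Counter

rows = [("SUCCESS", "INFO"), ("FAILED", "ERROR"), ("SUCCESS", "INFO")]
by_status = Counter(status for status, _ in rows)
by_level = Counter(level for _, level in rows)
print(dict(by_status))  # {'SUCCESS': 2, 'FAILED': 1}
print(dict(by_level))   # {'INFO': 2, 'ERROR': 1}
```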
return summary return summary
@@ -275,6 +291,5 @@ async def get_logs_summary(
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to generate logs summary: {e!s}"
detail=f"Failed to generate logs summary: {str(e)}" ) from e
)

View file

@@ -1,24 +1,31 @@
-from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from sqlalchemy.exc import IntegrityError, SQLAlchemyError
-from typing import List
-from app.db import get_async_session, User, SearchSpace, Podcast, Chat
-from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
-from app.users import current_active_user
-from app.utils.check_ownership import check_ownership
-from app.tasks.podcast_tasks import generate_chat_podcast
-from fastapi.responses import StreamingResponse
 import os
 from pathlib import Path
+
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
+from fastapi.responses import StreamingResponse
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.db import Chat, Podcast, SearchSpace, User, get_async_session
+from app.schemas import (
+    PodcastCreate,
+    PodcastGenerateRequest,
+    PodcastRead,
+    PodcastUpdate,
+)
+from app.tasks.podcast_tasks import generate_chat_podcast
+from app.users import current_active_user
+from app.utils.check_ownership import check_ownership

 router = APIRouter()
@router.post("/podcasts/", response_model=PodcastRead) @router.post("/podcasts/", response_model=PodcastRead)
async def create_podcast( async def create_podcast(
podcast: PodcastCreate, podcast: PodcastCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
await check_ownership(session, SearchSpace, podcast.search_space_id, user) await check_ownership(session, SearchSpace, podcast.search_space_id, user)
@@ -29,22 +36,30 @@ async def create_podcast(
return db_podcast return db_podcast
except HTTPException as he: except HTTPException as he:
raise he raise he
-    except IntegrityError as e:
+    except IntegrityError:
         await session.rollback()
-        raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
-    except SQLAlchemyError as e:
+        raise HTTPException(
+            status_code=400,
+            detail="Podcast creation failed due to constraint violation",
+        ) from None
+    except SQLAlchemyError:
         await session.rollback()
-        raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
-    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail="Database error occurred while creating podcast"
+        ) from None
+    except Exception:
         await session.rollback()
-        raise HTTPException(status_code=500, detail="An unexpected error occurred")
+        raise HTTPException(
+            status_code=500, detail="An unexpected error occurred"
+        ) from None
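Note the asymmetry these handlers adopt: database failures re-raise with `from None`, hiding the internal exception from the chained traceback, while branches that surface the message use `from e`. A small demonstration:

```python
def sanitized():
    try:
        {}["missing"]
    except KeyError:
        # __cause__ stays None and __suppress_context__ becomes True,
        # so the KeyError never leaks into the client-facing traceback.
        raise ValueError("sanitized message") from None


try:
    sanitized()
except ValueError as err:
    print(err.__cause__, err.__suppress_context__)  # None True
```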
@router.get("/podcasts/", response_model=List[PodcastRead])
@router.get("/podcasts/", response_model=list[PodcastRead])
async def read_podcasts( async def read_podcasts(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
if skip < 0 or limit < 1: if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters") raise HTTPException(status_code=400, detail="Invalid pagination parameters")
@@ -58,13 +73,16 @@ async def read_podcasts(
) )
return result.scalars().all() return result.scalars().all()
except SQLAlchemyError: except SQLAlchemyError:
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts") raise HTTPException(
status_code=500, detail="Database error occurred while fetching podcasts"
) from None
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead) @router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
async def read_podcast( async def read_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
result = await session.execute( result = await session.execute(
@@ -76,20 +94,23 @@ async def read_podcast(
if not podcast: if not podcast:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Podcast not found or you don't have permission to access it" detail="Podcast not found or you don't have permission to access it",
) )
return podcast return podcast
except HTTPException as he: except HTTPException as he:
raise he raise he
except SQLAlchemyError: except SQLAlchemyError:
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast") raise HTTPException(
status_code=500, detail="Database error occurred while fetching podcast"
) from None
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead) @router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
async def update_podcast( async def update_podcast(
podcast_id: int, podcast_id: int,
podcast_update: PodcastUpdate, podcast_update: PodcastUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_podcast = await read_podcast(podcast_id, session, user) db_podcast = await read_podcast(podcast_id, session, user)
@@ -103,16 +124,21 @@ async def update_podcast(
raise he raise he
except IntegrityError: except IntegrityError:
await session.rollback() await session.rollback()
raise HTTPException(status_code=400, detail="Update failed due to constraint violation") raise HTTPException(
status_code=400, detail="Update failed due to constraint violation"
) from None
except SQLAlchemyError: except SQLAlchemyError:
await session.rollback() await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast") raise HTTPException(
status_code=500, detail="Database error occurred while updating podcast"
) from None
@router.delete("/podcasts/{podcast_id}", response_model=dict) @router.delete("/podcasts/{podcast_id}", response_model=dict)
async def delete_podcast( async def delete_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_podcast = await read_podcast(podcast_id, session, user) db_podcast = await read_podcast(podcast_id, session, user)
@@ -123,30 +149,34 @@ async def delete_podcast(
raise he raise he
except SQLAlchemyError: except SQLAlchemyError:
await session.rollback() await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast") raise HTTPException(
status_code=500, detail="Database error occurred while deleting podcast"
) from None
async def generate_chat_podcast_with_new_session( async def generate_chat_podcast_with_new_session(
chat_id: int, chat_id: int, search_space_id: int, podcast_title: str, user_id: int
search_space_id: int,
podcast_title: str,
user_id: int
): ):
"""Create a new session and process chat podcast generation.""" """Create a new session and process chat podcast generation."""
from app.db import async_session_maker from app.db import async_session_maker
async with async_session_maker() as session: async with async_session_maker() as session:
try: try:
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id) await generate_chat_podcast(
session, chat_id, search_space_id, podcast_title, user_id
)
except Exception as e: except Exception as e:
import logging import logging
logging.error(f"Error generating podcast from chat: {str(e)}")
logging.error(f"Error generating podcast from chat: {e!s}")
@router.post("/podcasts/generate/") @router.post("/podcasts/generate/")
async def generate_podcast( async def generate_podcast(
request: PodcastGenerateRequest, request: PodcastGenerateRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
fastapi_background_tasks: BackgroundTasks = BackgroundTasks() fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
): ):
try: try:
# Check if the user owns the search space # Check if the user owns the search space
@@ -154,10 +184,15 @@ async def generate_podcast(
         if request.type == "CHAT":
             # Verify that all chat IDs belong to this user and search space
-            query = select(Chat).filter(
-                Chat.id.in_(request.ids),
-                Chat.search_space_id == request.search_space_id
-            ).join(SearchSpace).filter(SearchSpace.user_id == user.id)
+            query = (
+                select(Chat)
+                .filter(
+                    Chat.id.in_(request.ids),
+                    Chat.search_space_id == request.search_space_id,
+                )
+                .join(SearchSpace)
+                .filter(SearchSpace.user_id == user.id)
+            )
result = await session.execute(query) result = await session.execute(query)
valid_chats = result.scalars().all() valid_chats = result.scalars().all()
@@ -167,7 +202,7 @@ async def generate_podcast(
if len(valid_chat_ids) != len(request.ids): if len(valid_chat_ids) != len(request.ids):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="One or more chat IDs do not belong to this user or search space" detail="One or more chat IDs do not belong to this user or search space",
) )
# Only add a single task with the first chat ID # Only add a single task with the first chat ID
@@ -177,7 +212,7 @@ async def generate_podcast(
chat_id, chat_id,
request.search_space_id, request.search_space_id,
request.podcast_title, request.podcast_title,
user.id user.id,
) )
return { return {
@@ -185,21 +220,29 @@ async def generate_podcast(
} }
except HTTPException as he: except HTTPException as he:
raise he raise he
-    except IntegrityError as e:
+    except IntegrityError:
         await session.rollback()
-        raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation")
-    except SQLAlchemyError as e:
+        raise HTTPException(
+            status_code=400,
+            detail="Podcast generation failed due to constraint violation",
+        ) from None
+    except SQLAlchemyError:
         await session.rollback()
-        raise HTTPException(status_code=500, detail="Database error occurred while generating podcast")
+        raise HTTPException(
+            status_code=500, detail="Database error occurred while generating podcast"
+        ) from None
     except Exception as e:
         await session.rollback()
-        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"An unexpected error occurred: {e!s}"
+        ) from e
@router.get("/podcasts/{podcast_id}/stream") @router.get("/podcasts/{podcast_id}/stream")
async def stream_podcast( async def stream_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
"""Stream a podcast audio file.""" """Stream a podcast audio file."""
try: try:
@@ -214,7 +257,7 @@ async def stream_podcast(
if not podcast: if not podcast:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Podcast not found or you don't have permission to access it" detail="Podcast not found or you don't have permission to access it",
) )
# Get the file path # Get the file path
@@ -235,11 +278,13 @@ async def stream_podcast(
media_type="audio/mpeg", media_type="audio/mpeg",
headers={ headers={
"Accept-Ranges": "bytes", "Accept-Ranges": "bytes",
"Content-Disposition": f"inline; filename={Path(file_path).name}" "Content-Disposition": f"inline; filename={Path(file_path).name}",
} },
) )
except HTTPException as he: except HTTPException as he:
raise he raise he
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}") raise HTTPException(
status_code=500, detail=f"Error streaming podcast: {e!s}"
) from e
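The streaming endpoint reduces to a chunked file iterator wrapped in `StreamingResponse` with an inline content disposition. A self-contained sketch; the path and chunk size are placeholders, and `FileResponse` would be a reasonable alternative:

```python
from pathlib import Path

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/demo-stream")
async def demo_stream():
    file_path = "episode.mp3"  # placeholder path

    def iter_file(chunk_size: int = 64 * 1024):
        with open(file_path, "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        iter_file(),
        media_type="audio/mpeg",
        headers={
            "Accept-Ranges": "bytes",
            "Content-Disposition": f"inline; filename={Path(file_path).name}",
        },
    )
```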

View file

@@ -12,7 +12,13 @@ Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List
+from typing import Any
+
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
+from pydantic import BaseModel, Field, ValidationError
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select

 from app.connectors.github_connector import GitHubConnector
 from app.db import (
@@ -39,11 +45,6 @@ from app.tasks.connectors_indexing_tasks import (
 )
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership
-from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
-from pydantic import BaseModel, Field, ValidationError
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select

 # Set up logging
 logger = logging.getLogger(__name__)
@@ -57,7 +58,7 @@ class GitHubPATRequest(BaseModel):
# --- New Endpoint to list GitHub Repositories --- # --- New Endpoint to list GitHub Repositories ---
@router.post("/github/repositories/", response_model=List[Dict[str, Any]]) @router.post("/github/repositories/", response_model=list[dict[str, Any]])
async def list_github_repositories( async def list_github_repositories(
pat_request: GitHubPATRequest, pat_request: GitHubPATRequest,
user: User = Depends(current_active_user), # Ensure the user is logged in user: User = Depends(current_active_user), # Ensure the user is logged in
@@ -74,15 +75,13 @@ async def list_github_repositories(
return repositories return repositories
except ValueError as e: except ValueError as e:
# Handle invalid token error specifically # Handle invalid token error specifically
logger.error(f"GitHub PAT validation failed for user {user.id}: {str(e)}") logger.error(f"GitHub PAT validation failed for user {user.id}: {e!s}")
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {str(e)}") raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {e!s}") from e
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {e!s}")
f"Failed to fetch GitHub repositories for user {user.id}: {str(e)}"
)
raise HTTPException( raise HTTPException(
status_code=500, detail="Failed to fetch GitHub repositories." status_code=500, detail="Failed to fetch GitHub repositories."
) ) from e
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead) @router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
@@ -118,32 +117,32 @@ async def create_search_source_connector(
return db_connector return db_connector
except ValidationError as e: except ValidationError as e:
await session.rollback() await session.rollback()
raise HTTPException(status_code=422, detail=f"Validation error: {str(e)}") raise HTTPException(status_code=422, detail=f"Validation error: {e!s}") from e
except IntegrityError as e: except IntegrityError as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail=f"Integrity error: A connector with this type already exists. {str(e)}", detail=f"Integrity error: A connector with this type already exists. {e!s}",
) ) from e
except HTTPException: except HTTPException:
await session.rollback() await session.rollback()
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to create search source connector: {str(e)}") logger.error(f"Failed to create search source connector: {e!s}")
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"Failed to create search source connector: {str(e)}", detail=f"Failed to create search source connector: {e!s}",
) ) from e
@router.get( @router.get(
"/search-source-connectors/", response_model=List[SearchSourceConnectorRead] "/search-source-connectors/", response_model=list[SearchSourceConnectorRead]
) )
async def read_search_source_connectors( async def read_search_source_connectors(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
search_space_id: int = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
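`search_space_id: int = None` was an implicit optional: the annotation promised `int` while the default was `None`. The new `int | None = None` makes the contract explicit, as in this sketch:

```python
def read_connectors(
    skip: int = 0,
    limit: int = 100,
    search_space_id: int | None = None,  # explicit, instead of `int = None`
):
    return {"skip": skip, "limit": limit, "search_space_id": search_space_id}


print(read_connectors())
```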
@@ -160,8 +159,8 @@ async def read_search_source_connectors(
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"Failed to fetch search source connectors: {str(e)}", detail=f"Failed to fetch search source connectors: {e!s}",
) ) from e
@router.get( @router.get(
@@ -179,8 +178,8 @@ async def read_search_source_connector(
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to fetch search source connector: {str(e)}" status_code=500, detail=f"Failed to fetch search source connector: {e!s}"
) ) from e
@router.put( @router.put(
@@ -238,8 +237,8 @@ async def update_search_source_connector(
except ValidationError as e: except ValidationError as e:
# Raise specific validation error for the merged config # Raise specific validation error for the merged config
raise HTTPException( raise HTTPException(
status_code=422, detail=f"Validation error for merged config: {str(e)}" status_code=422, detail=f"Validation error for merged config: {e!s}"
) ) from e
# If validation passes, update the main update_data dict with the merged config # If validation passes, update the main update_data dict with the merged config
update_data["config"] = merged_config update_data["config"] = merged_config
@@ -272,8 +271,8 @@ async def update_search_source_connector(
await session.rollback() await session.rollback()
# This might occur if connector_type constraint is violated somehow after the check # This might occur if connector_type constraint is violated somehow after the check
raise HTTPException( raise HTTPException(
status_code=409, detail=f"Database integrity error during update: {str(e)}" status_code=409, detail=f"Database integrity error during update: {e!s}"
) ) from e
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
logger.error( logger.error(
@@ -282,8 +281,8 @@ async def update_search_source_connector(
) )
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"Failed to update search source connector: {str(e)}", detail=f"Failed to update search source connector: {e!s}",
) ) from e
@router.delete("/search-source-connectors/{connector_id}", response_model=dict) @router.delete("/search-source-connectors/{connector_id}", response_model=dict)
@@ -306,12 +305,12 @@ async def delete_search_source_connector(
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"Failed to delete search source connector: {str(e)}", detail=f"Failed to delete search source connector: {e!s}",
) ) from e
@router.post( @router.post(
"/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any] "/search-source-connectors/{connector_id}/index", response_model=dict[str, Any]
) )
async def index_connector_content( async def index_connector_content(
connector_id: int, connector_id: int,
@@ -356,7 +355,7 @@ async def index_connector_content(
) )
# Check if the search space belongs to the user # Check if the search space belongs to the user
search_space = await check_ownership( _search_space = await check_ownership(
session, SearchSpace, search_space_id, user session, SearchSpace, search_space_id, user
) )
@@ -381,10 +380,7 @@ async def index_connector_content(
         else:
             indexing_from = start_date

-        if end_date is None:
-            indexing_to = today_str
-        else:
-            indexing_to = end_date
+        indexing_to = end_date if end_date else today_str
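One subtlety in this simplification: the original tested `end_date is None`, while the ternary tests truthiness, so an empty string now also falls back to today's date. For ISO date strings or `None` the behaviour is identical:

```python
today_str = "2025-07-26"
for end_date in (None, "", "2025-01-01"):
    indexing_to = end_date if end_date else today_str
    print(repr(end_date), "->", indexing_to)
```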
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR: if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
# Run indexing in background # Run indexing in background
@@ -497,8 +493,8 @@ async def index_connector_content(
exc_info=True, exc_info=True,
) )
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to initiate indexing: {str(e)}" status_code=500, detail=f"Failed to initiate indexing: {e!s}"
) ) from e
async def update_connector_last_indexed(session: AsyncSession, connector_id: int): async def update_connector_last_indexed(session: AsyncSession, connector_id: int):
@@ -523,7 +519,7 @@ async def update_connector_last_indexed(session: AsyncSession, connector_id: int
logger.info(f"Updated last_indexed_at for connector {connector_id}") logger.info(f"Updated last_indexed_at for connector {connector_id}")
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}" f"Failed to update last_indexed_at for connector {connector_id}: {e!s}"
) )
await session.rollback() await session.rollback()
@@ -587,7 +583,7 @@ async def run_slack_indexing(
f"Slack indexing failed or no documents processed: {error_or_warning}" f"Slack indexing failed or no documents processed: {error_or_warning}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error in background Slack indexing task: {str(e)}") logger.error(f"Error in background Slack indexing task: {e!s}")
async def run_notion_indexing_with_new_session( async def run_notion_indexing_with_new_session(
@@ -649,7 +645,7 @@ async def run_notion_indexing(
f"Notion indexing failed or no documents processed: {error_or_warning}" f"Notion indexing failed or no documents processed: {error_or_warning}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error in background Notion indexing task: {str(e)}") logger.error(f"Error in background Notion indexing task: {e!s}")
# Add new helper functions for GitHub indexing # Add new helper functions for GitHub indexing
@@ -829,7 +825,7 @@ async def run_discord_indexing(
f"Discord indexing failed or no documents processed: {error_or_warning}" f"Discord indexing failed or no documents processed: {error_or_warning}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error in background Discord indexing task: {str(e)}") logger.error(f"Error in background Discord indexing task: {e!s}")
# Add new helper functions for Jira indexing # Add new helper functions for Jira indexing

View file

@@ -1,20 +1,20 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from typing import List
-from app.db import get_async_session, User, SearchSpace
-from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
+
+from app.db import SearchSpace, User, get_async_session
+from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership
-from fastapi import HTTPException

 router = APIRouter()
@router.post("/searchspaces/", response_model=SearchSpaceRead) @router.post("/searchspaces/", response_model=SearchSpaceRead)
async def create_search_space( async def create_search_space(
search_space: SearchSpaceCreate, search_space: SearchSpaceCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id) db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
@@ -27,16 +27,16 @@ async def create_search_space(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to create search space: {e!s}"
detail=f"Failed to create search space: {str(e)}" ) from e
)
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
@router.get("/searchspaces/", response_model=list[SearchSpaceRead])
async def read_search_spaces( async def read_search_spaces(
skip: int = 0, skip: int = 0,
limit: int = 200, limit: int = 200,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
result = await session.execute( result = await session.execute(
@@ -48,37 +48,41 @@ async def read_search_spaces(
return result.scalars().all() return result.scalars().all()
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch search spaces: {e!s}"
detail=f"Failed to fetch search spaces: {str(e)}" ) from e
)
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead) @router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def read_search_space( async def read_search_space(
search_space_id: int, search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
search_space = await check_ownership(session, SearchSpace, search_space_id, user) search_space = await check_ownership(
session, SearchSpace, search_space_id, user
)
return search_space return search_space
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to fetch search space: {e!s}"
detail=f"Failed to fetch search space: {str(e)}" ) from e
)
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead) @router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def update_search_space( async def update_search_space(
search_space_id: int, search_space_id: int,
search_space_update: SearchSpaceUpdate, search_space_update: SearchSpaceUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user) db_search_space = await check_ownership(
session, SearchSpace, search_space_id, user
)
update_data = search_space_update.model_dump(exclude_unset=True) update_data = search_space_update.model_dump(exclude_unset=True)
for key, value in update_data.items(): for key, value in update_data.items():
setattr(db_search_space, key, value) setattr(db_search_space, key, value)
@@ -90,18 +94,20 @@ async def update_search_space(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to update search space: {e!s}"
detail=f"Failed to update search space: {str(e)}" ) from e
)
@router.delete("/searchspaces/{search_space_id}", response_model=dict) @router.delete("/searchspaces/{search_space_id}", response_model=dict)
async def delete_search_space( async def delete_search_space(
search_space_id: int, search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user) user: User = Depends(current_active_user),
): ):
try: try:
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user) db_search_space = await check_ownership(
session, SearchSpace, search_space_id, user
)
await session.delete(db_search_space) await session.delete(db_search_space)
await session.commit() await session.commit()
return {"message": "Search space deleted successfully"} return {"message": "Search space deleted successfully"}
@@ -110,6 +116,5 @@ async def delete_search_space(
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Failed to delete search space: {e!s}"
detail=f"Failed to delete search space: {str(e)}" ) from e
)

View file

@@ -1,62 +1,78 @@
from .base import IDModel, TimestampModel
from .chats import AISDKChatRequest, ChatBase, ChatCreate, ChatRead, ChatUpdate
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import (
    DocumentBase,
    DocumentRead,
    DocumentsCreate,
    DocumentUpdate,
    ExtensionDocumentContent,
    ExtensionDocumentMetadata,
)
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
from .podcasts import (
    PodcastBase,
    PodcastCreate,
    PodcastGenerateRequest,
    PodcastRead,
    PodcastUpdate,
)
from .search_source_connector import (
    SearchSourceConnectorBase,
    SearchSourceConnectorCreate,
    SearchSourceConnectorRead,
    SearchSourceConnectorUpdate,
)
from .search_space import (
    SearchSpaceBase,
    SearchSpaceCreate,
    SearchSpaceRead,
    SearchSpaceUpdate,
)
from .users import UserCreate, UserRead, UserUpdate

__all__ = [
    "AISDKChatRequest",
    "ChatBase",
    "ChatCreate",
    "ChatRead",
    "ChatUpdate",
    "ChunkBase",
    "ChunkCreate",
    "ChunkRead",
    "ChunkUpdate",
    "DocumentBase",
    "DocumentRead",
    "DocumentUpdate",
    "DocumentsCreate",
    "ExtensionDocumentContent",
    "ExtensionDocumentMetadata",
    "IDModel",
    "LLMConfigBase",
    "LLMConfigCreate",
    "LLMConfigRead",
    "LLMConfigUpdate",
    "LogBase",
    "LogCreate",
    "LogFilter",
    "LogRead",
    "LogUpdate",
    "PodcastBase",
    "PodcastCreate",
    "PodcastGenerateRequest",
    "PodcastRead",
    "PodcastUpdate",
    "SearchSourceConnectorBase",
    "SearchSourceConnectorCreate",
    "SearchSourceConnectorRead",
    "SearchSourceConnectorUpdate",
    "SearchSpaceBase",
    "SearchSpaceCreate",
    "SearchSpaceRead",
    "SearchSpaceUpdate",
    "TimestampModel",
    "UserCreate",
    "UserRead",
    "UserUpdate",
]
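
The alphabetized `__all__` above comes from Ruff's RUF022 rule; the ordering is purely cosmetic, since `__all__` only controls what a star-import exposes. A self-contained illustration:

```python
# __all__ gates `from module import *`; its ordering never affects lookup.
__all__ = ["ChatBase", "ChatCreate"]  # alphabetized, as RUF022 enforces


class ChatCreate:  # definition order is independent of __all__ order
    pass


class ChatBase:
    pass


print(sorted(__all__) == __all__)  # True
```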
@@ -1,10 +1,13 @@
from datetime import datetime

from pydantic import BaseModel, ConfigDict


class TimestampModel(BaseModel):
    created_at: datetime

    model_config = ConfigDict(from_attributes=True)


class IDModel(BaseModel):
    id: int

    model_config = ConfigDict(from_attributes=True)
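
A minimal sketch of why `from_attributes=True` matters for these base models: it lets the schemas be built straight from ORM row objects rather than dicts (the `FakeRow` stand-in here is illustrative):

```python
from datetime import datetime

from pydantic import BaseModel, ConfigDict


class TimestampModel(BaseModel):
    created_at: datetime
    model_config = ConfigDict(from_attributes=True)


class FakeRow:  # stand-in for a SQLAlchemy row
    def __init__(self) -> None:
        self.created_at = datetime.utcnow()


print(TimestampModel.model_validate(FakeRow()))  # created_at=...
```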
@@ -1,7 +1,8 @@
from typing import Any

from pydantic import BaseModel, ConfigDict

from app.db import ChatType

from .base import IDModel, TimestampModel

@@ -9,20 +10,20 @@ from .base import IDModel, TimestampModel
class ChatBase(BaseModel):
    type: ChatType
    title: str
    initial_connectors: list[str] | None = None
    messages: list[Any]
    search_space_id: int


class ClientAttachment(BaseModel):
    name: str
    content_type: str
    url: str


class ToolInvocation(BaseModel):
    tool_call_id: str
    tool_name: str
    args: dict
    result: dict

@@ -33,15 +34,19 @@ class ToolInvocation(BaseModel):
# experimental_attachments: Optional[List[ClientAttachment]] = None
# toolInvocations: Optional[List[ToolInvocation]] = None


class AISDKChatRequest(BaseModel):
    messages: list[Any]
    data: dict[str, Any] | None = None


class ChatCreate(ChatBase):
    pass


class ChatUpdate(ChatBase):
    pass


class ChatRead(ChatBase, IDModel, TimestampModel):
    model_config = ConfigDict(from_attributes=True)
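
The recurring type-hint change throughout this PR, `Optional[List[str]]` becoming `list[str] | None`, is the PEP 585/604 style that Ruff's `UP` (pyupgrade) rules enforce on Python 3.10+. It behaves identically under Pydantic; a quick sketch (this model is illustrative, not from the codebase):

```python
from pydantic import BaseModel


class ChatFilter(BaseModel):  # hypothetical model for demonstration
    # pyupgrade rewrites Optional[List[str]] to the PEP 604 union below:
    initial_connectors: list[str] | None = None


print(ChatFilter().initial_connectors)            # None
print(ChatFilter(initial_connectors=["tavily"]))  # initial_connectors=['tavily']
```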
@@ -1,15 +1,20 @@
from pydantic import BaseModel, ConfigDict

from .base import IDModel, TimestampModel


class ChunkBase(BaseModel):
    content: str
    document_id: int


class ChunkCreate(ChunkBase):
    pass


class ChunkUpdate(ChunkBase):
    pass


class ChunkRead(ChunkBase, IDModel, TimestampModel):
    model_config = ConfigDict(from_attributes=True)
@@ -1,8 +1,10 @@
from datetime import datetime

from pydantic import BaseModel, ConfigDict

from app.db import DocumentType


class ExtensionDocumentMetadata(BaseModel):
    BrowsingSessionId: str
    VisitedWebPageURL: str
@@ -11,21 +13,28 @@ class ExtensionDocumentMetadata(BaseModel):
    VisitedWebPageReffererURL: str
    VisitedWebPageVisitDurationInMilliseconds: str


class ExtensionDocumentContent(BaseModel):
    metadata: ExtensionDocumentMetadata
    pageContent: str  # noqa: N815


class DocumentBase(BaseModel):
    document_type: DocumentType
    content: (
        list[ExtensionDocumentContent] | list[str] | str
    )  # Updated to allow string content
    search_space_id: int


class DocumentsCreate(DocumentBase):
    pass


class DocumentUpdate(DocumentBase):
    pass


class DocumentRead(BaseModel):
    id: int
    title: str
@@ -36,4 +45,3 @@ class DocumentRead(BaseModel):
    search_space_id: int

    model_config = ConfigDict(from_attributes=True)
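
`pageContent` keeps its mixedCase spelling because the browser extension presumably sends JSON with that exact key, so the `# noqa: N815` suppresses pep8-naming's mixedCase warning rather than breaking the wire format. A sketch of the trade-off (the payload is illustrative):

```python
from pydantic import BaseModel


class ExtensionDocumentContent(BaseModel):
    pageContent: str  # noqa: N815 - must match the JSON key the extension sends


# The incoming payload uses the camelCase key; renaming the field to
# snake_case would make validation fail unless an alias were added.
doc = ExtensionDocumentContent.model_validate({"pageContent": "<html>..."})
print(doc.pageContent)
```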
@@ -1,30 +1,57 @@
import uuid
from datetime import datetime
from typing import Any

from pydantic import BaseModel, ConfigDict, Field

from app.db import LiteLLMProvider

from .base import IDModel, TimestampModel


class LLMConfigBase(BaseModel):
    name: str = Field(
        ..., max_length=100, description="User-friendly name for the LLM configuration"
    )
    provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
    custom_provider: str | None = Field(
        None, max_length=100, description="Custom provider name when provider is CUSTOM"
    )
    model_name: str = Field(
        ..., max_length=100, description="Model name without provider prefix"
    )
    api_key: str = Field(..., description="API key for the provider")
    api_base: str | None = Field(
        None, max_length=500, description="Optional API base URL"
    )
    litellm_params: dict[str, Any] | None = Field(
        default=None, description="Additional LiteLLM parameters"
    )


class LLMConfigCreate(LLMConfigBase):
    pass


class LLMConfigUpdate(BaseModel):
    name: str | None = Field(
        None, max_length=100, description="User-friendly name for the LLM configuration"
    )
    provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type")
    custom_provider: str | None = Field(
        None, max_length=100, description="Custom provider name when provider is CUSTOM"
    )
    model_name: str | None = Field(
        None, max_length=100, description="Model name without provider prefix"
    )
    api_key: str | None = Field(None, description="API key for the provider")
    api_base: str | None = Field(
        None, max_length=500, description="Optional API base URL"
    )
    litellm_params: dict[str, Any] | None = Field(
        None, description="Additional LiteLLM parameters"
    )


class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
    id: int
@@ -1,30 +1,37 @@
from datetime import datetime
from typing import Any

from pydantic import BaseModel, ConfigDict

from app.db import LogLevel, LogStatus

from .base import IDModel, TimestampModel


class LogBase(BaseModel):
    level: LogLevel
    status: LogStatus
    message: str
    source: str | None = None
    log_metadata: dict[str, Any] | None = None


class LogCreate(BaseModel):
    level: LogLevel
    status: LogStatus
    message: str
    source: str | None = None
    log_metadata: dict[str, Any] | None = None
    search_space_id: int


class LogUpdate(BaseModel):
    level: LogLevel | None = None
    status: LogStatus | None = None
    message: str | None = None
    source: str | None = None
    log_metadata: dict[str, Any] | None = None


class LogRead(LogBase, IDModel, TimestampModel):
    id: int
@@ -33,12 +40,13 @@ class LogRead(LogBase, IDModel, TimestampModel):
    model_config = ConfigDict(from_attributes=True)


class LogFilter(BaseModel):
    level: LogLevel | None = None
    status: LogStatus | None = None
    source: str | None = None
    search_space_id: int | None = None
    start_date: datetime | None = None
    end_date: datetime | None = None

    model_config = ConfigDict(from_attributes=True)
@@ -1,24 +1,31 @@
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict

from .base import IDModel, TimestampModel


class PodcastBase(BaseModel):
    title: str
    podcast_transcript: list[Any]
    file_location: str = ""
    search_space_id: int


class PodcastCreate(PodcastBase):
    pass


class PodcastUpdate(PodcastBase):
    pass


class PodcastRead(PodcastBase, IDModel, TimestampModel):
    model_config = ConfigDict(from_attributes=True)


class PodcastGenerateRequest(BaseModel):
    type: Literal["DOCUMENT", "CHAT"]
    ids: list[int]
    search_space_id: int
    podcast_title: str = "SurfSense Podcast"
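
`Literal["DOCUMENT", "CHAT"]` makes Pydantic reject any other `type` value at validation time, so the route never sees an unexpected mode. A quick standalone sketch:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class PodcastGenerateRequest(BaseModel):
    type: Literal["DOCUMENT", "CHAT"]
    ids: list[int]


PodcastGenerateRequest(type="CHAT", ids=[1, 2])  # accepted

try:
    PodcastGenerateRequest(type="VIDEO", ids=[])
except ValidationError as e:
    print(e.error_count(), "validation error")  # rejected
```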
@@ -1,9 +1,10 @@
import uuid
from datetime import datetime
from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator

from app.db import SearchSourceConnectorType

from .base import IDModel, TimestampModel

@@ -12,14 +13,14 @@ class SearchSourceConnectorBase(BaseModel):
    name: str
    connector_type: SearchSourceConnectorType
    is_indexable: bool
    last_indexed_at: datetime | None = None
    config: dict[str, Any]

    @field_validator("config")
    @classmethod
    def validate_config_for_connector_type(
        cls, config: dict[str, Any], values: dict[str, Any]
    ) -> dict[str, Any]:
        connector_type = values.data.get("connector_type")

        if connector_type == SearchSourceConnectorType.SERPER_API:
@@ -150,11 +151,11 @@ class SearchSourceConnectorCreate(SearchSourceConnectorBase):
class SearchSourceConnectorUpdate(BaseModel):
    name: str | None = None
    connector_type: SearchSourceConnectorType | None = None
    is_indexable: bool | None = None
    last_indexed_at: datetime | None = None
    config: dict[str, Any] | None = None


class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
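
The `@field_validator("config")` above checks `config` against the previously validated `connector_type`; in Pydantic v2 the second validator argument is really a `ValidationInfo` whose `.data` holds the fields validated so far. A minimal standalone sketch (the enum value and required key are illustrative):

```python
from enum import Enum
from typing import Any

from pydantic import BaseModel, ValidationInfo, field_validator


class ConnectorType(str, Enum):
    SERPER_API = "SERPER_API"


class Connector(BaseModel):
    connector_type: ConnectorType
    config: dict[str, Any]

    @field_validator("config")
    @classmethod
    def check_config(cls, config: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        # Fields validate in declaration order, so connector_type is in info.data.
        if info.data.get("connector_type") == ConnectorType.SERPER_API:
            if "SERPER_API_KEY" not in config:
                raise ValueError("SERPER_API connectors need a SERPER_API_KEY")
        return config


Connector(connector_type="SERPER_API", config={"SERPER_API_KEY": "..."})
```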
@@ -1,19 +1,24 @@
import uuid
from datetime import datetime

from pydantic import BaseModel, ConfigDict

from .base import IDModel, TimestampModel


class SearchSpaceBase(BaseModel):
    name: str
    description: str | None = None


class SearchSpaceCreate(SearchSpaceBase):
    pass


class SearchSpaceUpdate(SearchSpaceBase):
    pass


class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
    id: int
    created_at: datetime
@@ -1,11 +1,15 @@
import uuid

from fastapi_users import schemas


class UserRead(schemas.BaseUser[uuid.UUID]):
    pass


class UserCreate(schemas.BaseUserCreate):
    pass


class UserUpdate(schemas.BaseUserUpdate):
    pass
@@ -1,5 +1,11 @@
import asyncio
from typing import Any

from linkup import LinkupClient
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from tavily import TavilyClient

from app.agents.researcher.configuration import SearchMode
from app.db import (
@@ -11,15 +17,10 @@ from app.db import (
)
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever


class ConnectorService:
    def __init__(self, session: AsyncSession, user_id: str | None = None):
        self.session = session
        self.chunk_retriever = ChucksHybridSearchRetriever(session)
        self.document_retriever = DocumentHybridSearchRetriever(session)
@@ -52,7 +53,7 @@ class ConnectorService:
                f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}"
            )
        except Exception as e:
            print(f"Error initializing source_id_counter: {e!s}")
            # Fallback to default value
            self.source_id_counter = 1
@@ -204,7 +205,9 @@ class ConnectorService:
        return result_object, files_chunks

    def _transform_document_results(
        self, document_results: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Transform results from document_retriever.hybrid_search() to match the format
        expected by the processing code.
@@ -233,7 +236,7 @@ class ConnectorService:
    async def get_connector_by_type(
        self, user_id: str, connector_type: SearchSourceConnectorType
    ) -> SearchSourceConnector | None:
        """
        Get a connector by type for a specific user
@@ -350,7 +353,7 @@ class ConnectorService:
        except Exception as e:
            # Log the error and return empty results
            print(f"Error searching with Tavily: {e!s}")
            return {
                "id": 3,
                "name": "Tavily Search",
@@ -596,7 +599,7 @@ class ConnectorService:
        # Process each chunk and create sources directly without deduplication
        sources_list = []
        async with self.counter_lock:
            for _, chunk in enumerate(extension_chunks):
                # Extract document metadata
                document = chunk.get("document", {})
                metadata = document.get("metadata", {})
@@ -608,7 +611,7 @@ class ConnectorService:
                visit_duration = metadata.get(
                    "VisitedWebPageVisitDurationInMilliseconds", ""
                )
                _browsing_session_id = metadata.get("BrowsingSessionId", "")

                # Create a more descriptive title for extension data
                title = webpage_title
@@ -622,7 +625,7 @@ class ConnectorService:
                            else visit_date
                        )
                        title += f" (visited: {formatted_date})"
                    except Exception:
                        # Fallback if date parsing fails
                        title += f" (visited: {visit_date})"
@@ -642,7 +645,7 @@ class ConnectorService:
                        if description:
                            description += f" | Duration: {duration_text}"
                    except Exception:
                        # Fallback if duration parsing fails
                        pass
@@ -1180,7 +1183,7 @@ class ConnectorService:
        except Exception as e:
            # Log the error and return empty results
            print(f"Error searching with Linkup: {e!s}")
            return {
                "id": 10,
                "name": "Linkup Search",
@@ -1239,7 +1242,7 @@ class ConnectorService:
        # Process each chunk and create sources directly without deduplication
        sources_list = []
        async with self.counter_lock:
            for _, chunk in enumerate(discord_chunks):
                # Extract document metadata
                document = chunk.get("document", {})
                metadata = document.get("metadata", {})
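
Two lint fixes repeat throughout this file: `f"{str(e)}"` becomes `f"{e!s}"` (Ruff RUF010, the explicit conversion flag instead of a function call) and bare `except:` becomes `except Exception:` (E722), which stops `KeyboardInterrupt` and `SystemExit` from being swallowed. A tiny sketch of both:

```python
def parse_duration(raw: str) -> str:
    try:
        millis = int(raw)
        return f"{millis / 1000:.1f}s"
    except Exception as e:  # E722 fix: never use a bare `except:`
        # RUF010 fix: !s applies str() inside the f-string without a call.
        print(f"Error parsing duration: {e!s}")
        return raw


print(parse_duration("1500"))   # 1.5s
print(parse_duration("oops"))   # prints the error, returns "oops"
```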
@@ -5,12 +5,13 @@ SSL-safe implementation with pre-downloaded models
"""

import logging
import os
import ssl
from typing import Any

logger = logging.getLogger(__name__)


class DoclingService:
    """Docling service for enhanced document processing with SSL fixes."""
@@ -29,11 +30,12 @@ class DoclingService:
            ssl._create_default_https_context = ssl._create_unverified_context

        # Set SSL environment variables if not already set
        if not os.environ.get("SSL_CERT_FILE"):
            try:
                import certifi

                os.environ["SSL_CERT_FILE"] = certifi.where()
                os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
            except ImportError:
                pass
@@ -45,6 +47,7 @@ class DoclingService:
        """Check and configure GPU support for WSL2 environment."""
        try:
            import torch

            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
@@ -64,10 +67,10 @@ class DoclingService:
    def _initialize_docling(self):
        """Initialize Docling with version-safe configuration."""
        try:
            from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
            from docling.datamodel.base_models import InputFormat
            from docling.datamodel.pipeline_options import PdfPipelineOptions
            from docling.document_converter import DocumentConverter, PdfFormatOption

            logger.info("🔧 Initializing Docling with version-safe configuration...")
@@ -75,19 +78,19 @@ class DoclingService:
            pipeline_options = PdfPipelineOptions()

            # Disable OCR (user request)
            if hasattr(pipeline_options, "do_ocr"):
                pipeline_options.do_ocr = False
                logger.info("⚠️ OCR disabled by user request")
            else:
                logger.warning("⚠️ OCR attribute not available in this Docling version")

            # Enable table structure if available
            if hasattr(pipeline_options, "do_table_structure"):
                pipeline_options.do_table_structure = True
                logger.info("✅ Table structure detection enabled")

            # Configure GPU acceleration for WSL2 if available
            if hasattr(pipeline_options, "accelerator_device"):
                if self.use_gpu:
                    try:
                        pipeline_options.accelerator_device = "cuda"
@@ -99,98 +102,112 @@ class DoclingService:
                        pipeline_options.accelerator_device = "cpu"
                        logger.info("🖥️ Using CPU acceleration")
            else:
                logger.info(
                    "⚠️ Accelerator device attribute not available in this Docling version"
                )

            # Create PDF format option with backend
            pdf_format_option = PdfFormatOption(
                pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
            )

            # Initialize DocumentConverter
            self.converter = DocumentConverter(
                format_options={InputFormat.PDF: pdf_format_option}
            )

            acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
            logger.info(
                f"✅ Docling initialized successfully with {acceleration_type} acceleration"
            )

        except ImportError as e:
            logger.error(f"❌ Docling not installed: {e}")
            raise RuntimeError(f"Docling not available: {e}") from e
        except Exception as e:
            logger.error(f"❌ Docling initialization failed: {e}")
            raise RuntimeError(f"Docling initialization failed: {e}") from e

    def _configure_easyocr_local_models(self):
        """Configure EasyOCR to use pre-downloaded local models."""
        try:
            import os

            import easyocr

            # Set SSL environment for EasyOCR downloads
            os.environ["CURL_CA_BUNDLE"] = ""
            os.environ["REQUESTS_CA_BUNDLE"] = ""

            # Try to use local models first, fallback to download if needed
            try:
                reader = easyocr.Reader(
                    ["en"],
                    download_enabled=False,
                    model_storage_directory="/root/.EasyOCR/model",
                )
                logger.info("✅ EasyOCR configured for local models")
                return reader
            except Exception:
                # If local models fail, allow download with SSL bypass
                logger.info(
                    "🔄 Local models failed, attempting download with SSL bypass..."
                )
                reader = easyocr.Reader(
                    ["en"],
                    download_enabled=True,
                    model_storage_directory="/root/.EasyOCR/model",
                )
                logger.info("✅ EasyOCR configured with downloaded models")
                return reader

        except Exception as e:
            logger.warning(f"⚠️ EasyOCR configuration failed: {e}")
            return None

    async def process_document(
        self, file_path: str, filename: str | None = None
    ) -> dict[str, Any]:
        """Process document with Docling using pre-downloaded models."""
        if self.converter is None:
            raise RuntimeError("Docling converter not initialized")

        try:
            logger.info(
                f"🔄 Processing {filename} with Docling (using local models)..."
            )

            # Process document with local models
            result = self.converter.convert(file_path)

            # Extract content using version-safe methods
            content = None
            if hasattr(result, "document") and result.document:
                # Try different export methods (version compatibility)
                if hasattr(result.document, "export_to_markdown"):
                    content = result.document.export_to_markdown()
                    logger.info("📄 Used export_to_markdown method")
                elif hasattr(result.document, "to_markdown"):
                    content = result.document.to_markdown()
                    logger.info("📄 Used to_markdown method")
                elif hasattr(result.document, "text"):
                    content = result.document.text
                    logger.info("📄 Used text property")
                elif hasattr(result.document, "__str__"):
                    content = str(result.document)
                    logger.info("📄 Used string conversion")

            if content:
                logger.info(
                    f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)"
                )
                return {
                    "content": content,
                    "full_text": content,
                    "service_used": "docling",
                    "status": "success",
                    "processing_notes": "Processed with Docling using pre-downloaded models",
                }
            else:
                raise ValueError("No content could be extracted from document")
@@ -201,14 +218,12 @@ class DoclingService:
            logger.error(f"❌ Docling processing failed for {filename}: {e}")
            # Log the full error for debugging
            import traceback

            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise RuntimeError(f"Docling processing failed: {e}") from e

    async def process_large_document_summary(
        self, content: str, llm, document_title: str = "Document"
    ) -> str:
        """
        Process large documents using chunked LLM summarization.
@@ -222,24 +237,28 @@ class DoclingService:
            Final summary of the document
        """
        # Large document threshold (100K characters ≈ 25K tokens)
        large_document_threshold = 100_000

        if len(content) <= large_document_threshold:
            # For smaller documents, use direct processing
            logger.info(
                f"📄 Document size: {len(content)} chars - using direct processing"
            )
            from app.prompts import SUMMARY_PROMPT_TEMPLATE

            summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
            result = await summary_chain.ainvoke({"document": content})
            return result.content

        logger.info(
            f"📚 Large document detected: {len(content)} chars - using chunked processing"
        )

        # Create LLM-optimized chunks (8K tokens max for safety)
        from chonkie import OverlapRefinery, RecursiveChunker
        from langchain_core.prompts import PromptTemplate

        llm_chunker = RecursiveChunker(
            chunk_size=8000  # Conservative for most LLMs
        )
@@ -247,7 +266,7 @@ class DoclingService:
        # Apply overlap refinery for context preservation (10% overlap = 800 tokens)
        overlap_refinery = OverlapRefinery(
            context_size=0.1,  # 10% overlap for context preservation
            method="suffix",  # Add next chunk context to current chunk
        )

        # First chunk the content, then apply overlap refinery
@@ -274,21 +293,25 @@ Chunk {chunk_number}/{total_chunks}:
<document_chunk>
{chunk}
</document_chunk>
</INSTRUCTIONS>""",
        )

        # Process each chunk individually
        chunk_summaries = []
        for i, chunk in enumerate(chunks, 1):
            try:
                logger.info(
                    f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)"
                )

                chunk_chain = chunk_template | llm
                chunk_result = await chunk_chain.ainvoke(
                    {
                        "chunk": chunk.text,
                        "chunk_number": i,
                        "total_chunks": total_chunks,
                    }
                )

                chunk_summary = chunk_result.content
                chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}")
@@ -318,19 +341,20 @@ Ensure:
<section_summaries>
{summaries}
</section_summaries>
</INSTRUCTIONS>""",
        )

        combined_summaries = "\n\n".join(chunk_summaries)
        combine_chain = combine_template | llm
        final_result = await combine_chain.ainvoke(
            {"summaries": combined_summaries, "document_title": document_title}
        )

        final_summary = final_result.content
        logger.info(
            f"✅ Large document processing complete: {len(final_summary)} chars summary"
        )

        return final_summary
@@ -341,6 +365,7 @@ Ensure:
            logger.warning("⚠️ Using fallback combined summary")
            return fallback_summary


def create_docling_service() -> DoclingService:
    """Create a Docling service instance."""
    return DoclingService()
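
The large-document path above is a map-reduce summarization: chunk the text, summarize each chunk, then combine the section summaries in a final pass. A minimal shape of that loop, with a stub in place of the real LLM call:

```python
# Map-reduce summarization in miniature; summarize() stands in for an LLM call.
def summarize(text: str) -> str:
    return text[:60]  # stub: a real implementation would invoke the model


def summarize_large(content: str, chunk_size: int = 8000) -> str:
    # "Map" step: summarize each fixed-size slice independently.
    chunks = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]
    section_summaries = [
        f"=== Section {i} ===\n{summarize(c)}" for i, c in enumerate(chunks, 1)
    ]
    # "Reduce" step: combine the per-section summaries into one final summary.
    return summarize("\n\n".join(section_summaries))


print(summarize_large("lorem ipsum " * 2000)[:40])
```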
@@ -1,23 +1,23 @@
import logging

from langchain_community.chat_models import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import LLMConfig, User

logger = logging.getLogger(__name__)


class LLMRole:
    LONG_CONTEXT = "long_context"
    FAST = "fast"
    STRATEGIC = "strategic"


async def get_user_llm_instance(
    session: AsyncSession, user_id: str, role: str
) -> ChatLiteLLM | None:
    """
    Get a ChatLiteLLM instance for a specific user and role.
@@ -31,9 +31,7 @@ async def get_user_llm_instance(
    """
    try:
        # Get user with their LLM preferences
        result = await session.execute(select(User).where(User.id == user_id))
        user = result.scalars().first()

        if not user:
@@ -59,8 +57,7 @@ async def get_user_llm_instance(
        # Get the LLM configuration
        result = await session.execute(
            select(LLMConfig).where(
                LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
            )
        )
        llm_config = result.scalars().first()
@@ -84,7 +81,9 @@ async def get_user_llm_instance(
            "MISTRAL": "mistral",
            # Add more mappings as needed
        }
        provider_prefix = provider_map.get(
            llm_config.provider.value, llm_config.provider.value.lower()
        )
        model_string = f"{provider_prefix}/{llm_config.model_name}"

        # Create ChatLiteLLM instance
@@ -104,17 +103,26 @@ async def get_user_llm_instance(
        return ChatLiteLLM(**litellm_kwargs)

    except Exception as e:
        logger.error(
            f"Error getting LLM instance for user {user_id}, role {role}: {e!s}"
        )
        return None


async def get_user_long_context_llm(
    session: AsyncSession, user_id: str
) -> ChatLiteLLM | None:
    """Get user's long context LLM instance."""
    return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)


async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
    """Get user's fast LLM instance."""
    return await get_user_llm_instance(session, user_id, LLMRole.FAST)


async def get_user_strategic_llm(
    session: AsyncSession, user_id: str
) -> ChatLiteLLM | None:
    """Get user's strategic LLM instance."""
    return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
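
The provider-prefix mapping above produces LiteLLM's `provider/model` string. A standalone sketch of the same idea (the mapping entries and model names are illustrative):

```python
provider_map = {"OPENAI": "openai", "ANTHROPIC": "anthropic", "MISTRAL": "mistral"}


def build_model_string(provider: str, model_name: str) -> str:
    # Fall back to a lowercased provider for values missing from the map.
    prefix = provider_map.get(provider, provider.lower())
    return f"{prefix}/{model_name}"


print(build_model_string("ANTHROPIC", "claude-3-5-sonnet"))  # anthropic/claude-3-5-sonnet
print(build_model_string("GROQ", "llama-3.1-70b"))           # groq/llama-3.1-70b
```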
@@ -1,9 +1,10 @@
import datetime
from typing import Any

from langchain.schema import AIMessage, HumanMessage, SystemMessage
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.llm_service import get_user_strategic_llm


class QueryService:
@@ -16,7 +17,7 @@ class QueryService:
        user_query: str,
        session: AsyncSession,
        user_id: str,
        chat_history_str: str | None = None,
    ) -> str:
        """
        Reformulate the user query using the user's strategic LLM to make it more
@@ -38,7 +39,9 @@ class QueryService:
        # Get the user's strategic LLM instance
        llm = await get_user_strategic_llm(session, user_id)
        if not llm:
            print(
                f"Warning: No strategic LLM configured for user {user_id}. Using original query."
            )
            return user_query

        # Create system message with instructions
@@ -92,9 +95,8 @@ class QueryService:
            print(f"Error reformulating query: {e}")
            return user_query

    @staticmethod
    async def langchain_chat_history_to_str(chat_history: list[Any]) -> str:
        """
        Convert a list of chat history messages to a string.
        """
@@ -1,7 +1,9 @@
import logging
from typing import Any, Optional

from rerankers import Document as RerankerDocument


class RerankerService:
    """
    Service for reranking documents using a configured reranker
@@ -16,7 +18,9 @@ class RerankerService:
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(
        self, query_text: str, documents: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Rerank documents using the configured reranker
@@ -44,18 +48,17 @@ class RerankerService:
                        text=content,
                        doc_id=chunk_id,
                        metadata={
                            "document_id": document_info.get("id", ""),
                            "document_title": document_info.get("title", ""),
                            "document_type": document_info.get("document_type", ""),
                            "rrf_score": score,
                        },
                    )
                )

            # Rerank using the configured reranker
            reranking_results = self.reranker_instance.rank(
                query=query_text, docs=reranker_docs
            )

            # Process the results from the reranker
@@ -63,7 +66,14 @@ class RerankerService:
            serialized_results = []
            for result in reranking_results.results:
                # Find the original document by id
                original_doc = next(
                    (
                        doc
                        for doc in documents
                        if doc.get("chunk_id") == result.document.doc_id
                    ),
                    None,
                )
                if original_doc:
                    # Create a new document with the reranked score
                    reranked_doc = original_doc.copy()
@@ -75,12 +85,12 @@ class RerankerService:

        except Exception as e:
            # Log the error
            logging.error(f"Error during reranking: {e!s}")
            # Fall back to original documents without reranking
            return documents

    @staticmethod
    def get_reranker_instance() -> Optional["RerankerService"]:
        """
        Get a reranker service instance from the global configuration.
@@ -89,7 +99,6 @@ class RerankerService:
        """
        from app.config import config

        if hasattr(config, "reranker_instance") and config.reranker_instance:
            return RerankerService(config.reranker_instance)
        return None
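
The `next((... for ...), None)` expression used to match reranked results back to their source documents is the standard "first match or None" idiom; a tiny sketch:

```python
documents = [{"chunk_id": 1, "text": "a"}, {"chunk_id": 2, "text": "b"}]

# First document whose chunk_id matches, or None if nothing matches.
match = next((d for d in documents if d["chunk_id"] == 2), None)
print(match)  # {'chunk_id': 2, 'text': 'b'}

missing = next((d for d in documents if d["chunk_id"] == 99), None)
print(missing)  # None
```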
@@ -1,27 +1,15 @@
import json
from typing import Any


class StreamingService:
    def __init__(self):
        self.terminal_idx = 1
        self.message_annotations = [
            {"type": "TERMINAL_INFO", "content": []},
            {"type": "SOURCES", "content": []},
            {"type": "ANSWER", "content": []},
            {"type": "FURTHER_QUESTIONS", "content": []},
        ]

    # DEPRECATED: This sends the full annotation array every time (inefficient)
@@ -35,7 +23,7 @@ class StreamingService:
        Returns:
            str: The formatted annotations string
        """
        return f"8:{json.dumps(self.message_annotations)}\n"

    def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
        """
@@ -58,7 +46,7 @@ class StreamingService:
        annotation = {"type": "TERMINAL_INFO", "content": [message]}
        return f"8:[{json.dumps(annotation)}]\n"

    def format_sources_delta(self, sources: list[dict[str, Any]]) -> str:
        """
        Format sources as a delta annotation
@@ -95,7 +83,7 @@ class StreamingService:
        annotation = {"type": "ANSWER", "content": [answer_chunk]}
        return f"8:[{json.dumps(annotation)}]\n"

    def format_answer_annotation(self, answer_lines: list[str]) -> str:
        """
        Format the complete answer as a replacement annotation
@@ -113,7 +101,7 @@ class StreamingService:
        return f"8:[{json.dumps(annotation)}]\n"

    def format_further_questions_delta(
        self, further_questions: list[dict[str, Any]]
    ) -> str:
        """
        Format further questions as a delta annotation
@@ -155,7 +143,9 @@ class StreamingService:
        """
        return f"3:{json.dumps(error_message)}\n"

    def format_completion(
        self, prompt_tokens: int = 156, completion_tokens: int = 204
    ) -> str:
        """
        Format a completion message
@@ -172,7 +162,7 @@ class StreamingService:
            "usage": {
                "promptTokens": prompt_tokens,
                "completionTokens": completion_tokens,
                "totalTokens": total_tokens,
            },
        }
        return f"d:{json.dumps(completion_data)}\n"
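
The `8:`, `3:`, and `d:` prefixes appear to follow the Vercel AI SDK data-stream framing (message annotations, error, and finish message respectively), which matches the `AISDKChatRequest` schema elsewhere in this PR. A sketch of what one streamed response could look like on the wire (frame contents are illustrative):

```python
import json

# Each frame is "<type-code>:<json>\n": 8 = message annotation,
# 3 = error, d = finish message with token usage.
frames = [
    f"8:[{json.dumps({'type': 'ANSWER', 'content': ['Hello']})}]\n",
    f"d:{json.dumps({'finishReason': 'stop', 'usage': {'promptTokens': 10, 'completionTokens': 2, 'totalTokens': 12}})}\n",
]
print("".join(frames), end="")
```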
@@ -1,12 +1,14 @@
import logging
from datetime import datetime
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Log, LogLevel, LogStatus

logger = logging.getLogger(__name__)


class TaskLoggingService:
    """Service for logging background tasks using the database Log model"""
@@ -19,7 +21,7 @@ class TaskLoggingService:
        task_name: str,
        source: str,
        message: str,
        metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Log the start of a task with IN_PROGRESS status
@@ -34,10 +36,9 @@ class TaskLoggingService:
            Log: The created log entry
        """
        log_metadata = metadata or {}
        log_metadata.update(
            {"task_name": task_name, "started_at": datetime.utcnow().isoformat()}
        )

        log_entry = Log(
            level=LogLevel.INFO,
@@ -45,7 +46,7 @@ class TaskLoggingService:
            message=message,
            source=source,
            log_metadata=log_metadata,
            search_space_id=self.search_space_id,
        )

        self.session.add(log_entry)
@@ -59,7 +60,7 @@ class TaskLoggingService:
        self,
        log_entry: Log,
        message: str,
        additional_metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Update a log entry to SUCCESS status
@@ -86,7 +87,11 @@ class TaskLoggingService:
        await self.session.commit()
        await self.session.refresh(log_entry)

        task_name = (
            log_entry.log_metadata.get("task_name", "unknown")
            if log_entry.log_metadata
            else "unknown"
        )
        logger.info(f"Completed task {task_name}: {message}")

        return log_entry
@@ -94,8 +99,8 @@ class TaskLoggingService:
        self,
        log_entry: Log,
        error_message: str,
        error_details: str | None = None,
        additional_metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Update a log entry to FAILED status
@@ -118,10 +123,9 @@ class TaskLoggingService:
        if log_entry.log_metadata is None:
            log_entry.log_metadata = {}

        log_entry.log_metadata.update(
            {"failed_at": datetime.utcnow().isoformat(), "error_details": error_details}
        )

        if additional_metadata:
            log_entry.log_metadata.update(additional_metadata)
@@ -129,7 +133,11 @@ class TaskLoggingService:
        await self.session.commit()
        await self.session.refresh(log_entry)

        task_name = (
            log_entry.log_metadata.get("task_name", "unknown")
            if log_entry.log_metadata
            else "unknown"
        )
        logger.error(f"Failed task {task_name}: {error_message}")
        if error_details:
            logger.error(f"Error details: {error_details}")
@@ -140,7 +148,7 @@ class TaskLoggingService:
        self,
        log_entry: Log,
        progress_message: str,
        progress_metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Update a log entry with progress information while keeping IN_PROGRESS status
@@ -159,12 +167,18 @@ class TaskLoggingService:
        if log_entry.log_metadata is None:
            log_entry.log_metadata = {}
        log_entry.log_metadata.update(progress_metadata)
        log_entry.log_metadata["last_progress_update"] = (
            datetime.utcnow().isoformat()
        )

        await self.session.commit()
        await self.session.refresh(log_entry)

        task_name = (
            log_entry.log_metadata.get("task_name", "unknown")
            if log_entry.log_metadata
            else "unknown"
        )
        logger.info(f"Progress update for task {task_name}: {progress_message}")

        return log_entry
@@ -173,7 +187,7 @@ class TaskLoggingService:
        level: LogLevel,
        source: str,
        message: str,
        metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Log a simple event (not a long-running task)
@@ -193,7 +207,7 @@ class TaskLoggingService:
            message=message,
            source=source,
            log_metadata=metadata or {},
            search_space_id=self.search_space_id,
        )

        self.session.add(log_entry)
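
Putting the lifecycle together: start, progress, then success or failure. A sketch of the intended call pattern; `log_task_start`, `log_task_success`, and `log_task_failure` are assumed names inferred from the docstrings (only `log_task_progress` is visible in this diff), and the session and ids are assumed context:

```python
async def run_indexing_task(session, search_space_id: int) -> None:
    task_logger = TaskLoggingService(session, search_space_id)
    log_entry = await task_logger.log_task_start(
        task_name="index_documents",
        source="background_task",
        message="Indexing started",
    )
    try:
        await task_logger.log_task_progress(log_entry, "Indexed 50/100 documents")
        await task_logger.log_task_success(log_entry, "Indexed 100 documents")
    except Exception as e:
        await task_logger.log_task_failure(log_entry, "Indexing failed", str(e))
```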
View file
@@ -1,28 +1,33 @@
-from typing import Optional, List
-from sqlalchemy.ext.asyncio import AsyncSession
+import logging
+from urllib.parse import parse_qs, urlparse
+
+import aiohttp
+import validators
+from langchain_community.document_loaders import AsyncChromiumLoader, FireCrawlLoader
+from langchain_community.document_transformers import MarkdownifyTransformer
+from langchain_core.documents import Document as LangChainDocument
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from app.db import Document, DocumentType, Chunk
-from app.schemas import ExtensionDocumentContent
+from youtube_transcript_api import YouTubeTranscriptApi
+
 from app.config import config
+from app.db import Chunk, Document, DocumentType
 from app.prompts import SUMMARY_PROMPT_TEMPLATE
-from app.utils.document_converters import convert_document_to_markdown, generate_content_hash
+from app.schemas import ExtensionDocumentContent
 from app.services.llm_service import get_user_long_context_llm
 from app.services.task_logging_service import TaskLoggingService
-from langchain_core.documents import Document as LangChainDocument
-from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
-from langchain_community.document_transformers import MarkdownifyTransformer
-import validators
-from youtube_transcript_api import YouTubeTranscriptApi
-from urllib.parse import urlparse, parse_qs
-import aiohttp
-import logging
+from app.utils.document_converters import (
+    convert_document_to_markdown,
+    generate_content_hash,
+)

 md = MarkdownifyTransformer()


 async def add_crawled_url_document(
     session: AsyncSession, url: str, search_space_id: int, user_id: str
-) -> Optional[Document]:
+) -> Document | None:
     task_logger = TaskLoggingService(session, search_space_id)

     # Log task start
@@ -30,15 +35,13 @@ async def add_crawled_url_document(
         task_name="crawl_url_document",
         source="background_task",
         message=f"Starting URL crawling process for: {url}",
-        metadata={"url": url, "user_id": str(user_id)}
+        metadata={"url": url, "user_id": str(user_id)},
     )

     try:
         # URL validation step
         await task_logger.log_task_progress(
-            log_entry,
-            f"Validating URL: {url}",
-            {"stage": "validation"}
+            log_entry, f"Validating URL: {url}", {"stage": "validation"}
         )

         if not validators.url(url):
@@ -48,7 +51,10 @@ async def add_crawled_url_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Setting up crawler for URL: {url}",
-            {"stage": "crawler_setup", "firecrawl_available": bool(config.FIRECRAWL_API_KEY)}
+            {
+                "stage": "crawler_setup",
+                "firecrawl_available": bool(config.FIRECRAWL_API_KEY),
+            },
         )

         if config.FIRECRAWL_API_KEY:
@@ -68,21 +74,21 @@ async def add_crawled_url_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Crawling URL content: {url}",
-            {"stage": "crawling", "crawler_type": type(crawl_loader).__name__}
+            {"stage": "crawling", "crawler_type": type(crawl_loader).__name__},
         )

         url_crawled = await crawl_loader.aload()

-        if type(crawl_loader) == FireCrawlLoader:
+        if isinstance(crawl_loader, FireCrawlLoader):
             content_in_markdown = url_crawled[0].page_content
-        elif type(crawl_loader) == AsyncChromiumLoader:
+        elif isinstance(crawl_loader, AsyncChromiumLoader):
             content_in_markdown = md.transform_documents(url_crawled)[0].page_content

         # Format document
         await task_logger.log_task_progress(
             log_entry,
             f"Processing crawled content from: {url}",
-            {"stage": "content_processing", "content_length": len(content_in_markdown)}
+            {"stage": "content_processing", "content_length": len(content_in_markdown)},
         )

         # Format document metadata in a more maintainable way
@@ -117,7 +123,7 @@ async def add_crawled_url_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Checking for duplicate content: {url}",
-            {"stage": "duplicate_check", "content_hash": content_hash}
+            {"stage": "duplicate_check", "content_hash": content_hash},
         )

         # Check if document with this content hash already exists
@@ -130,16 +136,21 @@ async def add_crawled_url_document(
             await task_logger.log_task_success(
                 log_entry,
                 f"Document already exists for URL: {url}",
-                {"duplicate_detected": True, "existing_document_id": existing_document.id}
+                {
+                    "duplicate_detected": True,
+                    "existing_document_id": existing_document.id,
+                },
             )
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # Get LLM for summary generation
         await task_logger.log_task_progress(
             log_entry,
             f"Preparing for summary generation: {url}",
-            {"stage": "llm_setup"}
+            {"stage": "llm_setup"},
         )

         # Get user's long context LLM
@@ -151,7 +162,7 @@ async def add_crawled_url_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Generating summary for URL content: {url}",
-            {"stage": "summary_generation"}
+            {"stage": "summary_generation"},
         )

         summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
@@ -165,7 +176,7 @@ async def add_crawled_url_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Processing content chunks for URL: {url}",
-            {"stage": "chunk_processing"}
+            {"stage": "chunk_processing"},
         )

         chunks = [
@@ -180,13 +191,13 @@ async def add_crawled_url_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Creating document in database for URL: {url}",
-            {"stage": "document_creation", "chunks_count": len(chunks)}
+            {"stage": "document_creation", "chunks_count": len(chunks)},
         )

         document = Document(
             search_space_id=search_space_id,
             title=url_crawled[0].metadata["title"]
-            if type(crawl_loader) == FireCrawlLoader
+            if isinstance(crawl_loader, FireCrawlLoader)
             else url_crawled[0].metadata["source"],
             document_type=DocumentType.CRAWLED_URL,
             document_metadata=url_crawled[0].metadata,
@@ -209,8 +220,8 @@ async def add_crawled_url_document(
                 "title": document.title,
                 "content_hash": content_hash,
                 "chunks_count": len(chunks),
-                "summary_length": len(summary_content)
-            }
+                "summary_length": len(summary_content),
+            },
         )

         return document
@@ -221,7 +232,7 @@ async def add_crawled_url_document(
             log_entry,
             f"Database error while processing URL: {url}",
             str(db_error),
-            {"error_type": "SQLAlchemyError"}
+            {"error_type": "SQLAlchemyError"},
         )
         raise db_error
     except Exception as e:
@@ -230,14 +241,17 @@ async def add_crawled_url_document(
             log_entry,
             f"Failed to crawl URL: {url}",
             str(e),
-            {"error_type": type(e).__name__}
+            {"error_type": type(e).__name__},
         )
-        raise RuntimeError(f"Failed to crawl URL: {str(e)}")
+        raise RuntimeError(f"Failed to crawl URL: {e!s}") from e


 async def add_extension_received_document(
-    session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str
-) -> Optional[Document]:
+    session: AsyncSession,
+    content: ExtensionDocumentContent,
+    search_space_id: int,
+    user_id: str,
+) -> Document | None:
     """
     Process and store document content received from the SurfSense Extension.
@@ -259,8 +273,8 @@ async def add_extension_received_document(
         metadata={
             "url": content.metadata.VisitedWebPageURL,
             "title": content.metadata.VisitedWebPageTitle,
-            "user_id": str(user_id)
-        }
+            "user_id": str(user_id),
+        },
     )

     try:
@@ -306,9 +320,14 @@ async def add_extension_received_document(
             await task_logger.log_task_success(
                 log_entry,
                 f"Extension document already exists: {content.metadata.VisitedWebPageTitle}",
-                {"duplicate_detected": True, "existing_document_id": existing_document.id}
+                {
+                    "duplicate_detected": True,
+                    "existing_document_id": existing_document.id,
+                },
             )
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # Get user's long context LLM
@@ -356,8 +375,8 @@ async def add_extension_received_document(
             {
                 "document_id": document.id,
                 "content_hash": content_hash,
-                "url": content.metadata.VisitedWebPageURL
-            }
+                "url": content.metadata.VisitedWebPageURL,
+            },
         )

         return document
@@ -368,7 +387,7 @@ async def add_extension_received_document(
             log_entry,
             f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}",
             str(db_error),
-            {"error_type": "SQLAlchemyError"}
+            {"error_type": "SQLAlchemyError"},
         )
         raise db_error
     except Exception as e:
@@ -377,14 +396,18 @@ async def add_extension_received_document(
             log_entry,
             f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}",
             str(e),
-            {"error_type": type(e).__name__}
+            {"error_type": type(e).__name__},
         )
-        raise RuntimeError(f"Failed to process extension document: {str(e)}")
+        raise RuntimeError(f"Failed to process extension document: {e!s}") from e


 async def add_received_markdown_file_document(
-    session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str
-) -> Optional[Document]:
+    session: AsyncSession,
+    file_name: str,
+    file_in_markdown: str,
+    search_space_id: int,
+    user_id: str,
+) -> Document | None:
     task_logger = TaskLoggingService(session, search_space_id)

     # Log task start
@@ -392,7 +415,11 @@ async def add_received_markdown_file_document(
         task_name="markdown_file_document",
         source="background_task",
         message=f"Processing markdown file: {file_name}",
-        metadata={"filename": file_name, "user_id": str(user_id), "content_length": len(file_in_markdown)}
+        metadata={
+            "filename": file_name,
+            "user_id": str(user_id),
+            "content_length": len(file_in_markdown),
+        },
     )

     try:
@@ -408,9 +435,14 @@ async def add_received_markdown_file_document(
             await task_logger.log_task_success(
                 log_entry,
                 f"Markdown file document already exists: {file_name}",
-                {"duplicate_detected": True, "existing_document_id": existing_document.id}
+                {
+                    "duplicate_detected": True,
+                    "existing_document_id": existing_document.id,
+                },
             )
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # Get user's long context LLM
@@ -459,8 +491,8 @@ async def add_received_markdown_file_document(
                 "document_id": document.id,
                 "content_hash": content_hash,
                 "chunks_count": len(chunks),
-                "summary_length": len(summary_content)
-            }
+                "summary_length": len(summary_content),
+            },
         )

         return document
@@ -470,7 +502,7 @@ async def add_received_markdown_file_document(
             log_entry,
             f"Database error processing markdown file: {file_name}",
             str(db_error),
-            {"error_type": "SQLAlchemyError"}
+            {"error_type": "SQLAlchemyError"},
         )
         raise db_error
     except Exception as e:
@@ -479,18 +511,18 @@ async def add_received_markdown_file_document(
             log_entry,
             f"Failed to process markdown file: {file_name}",
             str(e),
-            {"error_type": type(e).__name__}
+            {"error_type": type(e).__name__},
        )
-        raise RuntimeError(f"Failed to process file document: {str(e)}")
+        raise RuntimeError(f"Failed to process file document: {e!s}") from e


 async def add_received_file_document_using_unstructured(
     session: AsyncSession,
     file_name: str,
-    unstructured_processed_elements: List[LangChainDocument],
+    unstructured_processed_elements: list[LangChainDocument],
     search_space_id: int,
     user_id: str,
-) -> Optional[Document]:
+) -> Document | None:
     try:
         file_in_markdown = await convert_document_to_markdown(
             unstructured_processed_elements
@@ -505,7 +537,9 @@ async def add_received_file_document_using_unstructured(
         existing_document = existing_doc_result.scalars().first()

         if existing_document:
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # TODO: Check if file_markdown exceeds token limit of embedding model
@@ -555,7 +589,7 @@ async def add_received_file_document_using_unstructured(
         raise db_error
     except Exception as e:
         await session.rollback()
-        raise RuntimeError(f"Failed to process file document: {str(e)}")
+        raise RuntimeError(f"Failed to process file document: {e!s}") from e


 async def add_received_file_document_using_llamacloud(
@@ -564,7 +598,7 @@ async def add_received_file_document_using_llamacloud(
     llamacloud_markdown_document: str,
     search_space_id: int,
     user_id: str,
-) -> Optional[Document]:
+) -> Document | None:
     """
     Process and store document content parsed by LlamaCloud.
@@ -590,7 +624,9 @@ async def add_received_file_document_using_llamacloud(
         existing_document = existing_doc_result.scalars().first()

         if existing_document:
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # Get user's long context LLM
@@ -638,7 +674,9 @@ async def add_received_file_document_using_llamacloud(
         raise db_error
     except Exception as e:
         await session.rollback()
-        raise RuntimeError(f"Failed to process file document using LlamaCloud: {str(e)}")
+        raise RuntimeError(
+            f"Failed to process file document using LlamaCloud: {e!s}"
+        ) from e


 async def add_received_file_document_using_docling(
@@ -647,7 +685,7 @@ async def add_received_file_document_using_docling(
     docling_markdown_document: str,
     search_space_id: int,
     user_id: str,
-) -> Optional[Document]:
+) -> Document | None:
     """
     Process and store document content parsed by Docling.
@@ -673,7 +711,9 @@ async def add_received_file_document_using_docling(
         existing_document = existing_doc_result.scalars().first()

         if existing_document:
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # Get user's long context LLM
@@ -683,12 +723,11 @@ async def add_received_file_document_using_docling(
         # Generate summary using chunked processing for large documents
         from app.services.docling_service import create_docling_service

         docling_service = create_docling_service()

         summary_content = await docling_service.process_large_document_summary(
-            content=file_in_markdown,
-            llm=user_llm,
-            document_title=file_name
+            content=file_in_markdown, llm=user_llm, document_title=file_name
         )

         summary_embedding = config.embedding_model_instance.embed(summary_content)
@@ -726,7 +765,9 @@ async def add_received_file_document_using_docling(
         raise db_error
     except Exception as e:
         await session.rollback()
-        raise RuntimeError(f"Failed to process file document using Docling: {str(e)}")
+        raise RuntimeError(
+            f"Failed to process file document using Docling: {e!s}"
+        ) from e


 async def add_youtube_video_document(
@@ -755,7 +796,7 @@ async def add_youtube_video_document(
         task_name="youtube_video_document",
         source="background_task",
         message=f"Starting YouTube video processing for: {url}",
-        metadata={"url": url, "user_id": str(user_id)}
+        metadata={"url": url, "user_id": str(user_id)},
     )

     try:
@@ -763,7 +804,7 @@ async def add_youtube_video_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Extracting video ID from URL: {url}",
-            {"stage": "video_id_extraction"}
+            {"stage": "video_id_extraction"},
         )

         def get_youtube_video_id(url: str):
@@ -790,14 +831,14 @@ async def add_youtube_video_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Video ID extracted: {video_id}",
-            {"stage": "video_id_extracted", "video_id": video_id}
+            {"stage": "video_id_extracted", "video_id": video_id},
         )

         # Get video metadata
         await task_logger.log_task_progress(
             log_entry,
             f"Fetching video metadata for: {video_id}",
-            {"stage": "metadata_fetch"}
+            {"stage": "metadata_fetch"},
         )

         params = {
@@ -806,21 +847,27 @@ async def add_youtube_video_document(
         }
         oembed_url = "https://www.youtube.com/oembed"

-        async with aiohttp.ClientSession() as http_session:
-            async with http_session.get(oembed_url, params=params) as response:
-                video_data = await response.json()
+        async with (
+            aiohttp.ClientSession() as http_session,
+            http_session.get(oembed_url, params=params) as response,
+        ):
+            video_data = await response.json()

         await task_logger.log_task_progress(
             log_entry,
             f"Video metadata fetched: {video_data.get('title', 'Unknown')}",
-            {"stage": "metadata_fetched", "title": video_data.get('title'), "author": video_data.get('author_name')}
+            {
+                "stage": "metadata_fetched",
+                "title": video_data.get("title"),
+                "author": video_data.get("author_name"),
+            },
         )

         # Get video transcript
         await task_logger.log_task_progress(
             log_entry,
             f"Fetching transcript for video: {video_id}",
-            {"stage": "transcript_fetch"}
+            {"stage": "transcript_fetch"},
         )

         try:
@@ -838,21 +885,25 @@ async def add_youtube_video_document(
             await task_logger.log_task_progress(
                 log_entry,
                 f"Transcript fetched successfully: {len(captions)} segments",
-                {"stage": "transcript_fetched", "segments_count": len(captions), "transcript_length": len(transcript_text)}
+                {
+                    "stage": "transcript_fetched",
+                    "segments_count": len(captions),
+                    "transcript_length": len(transcript_text),
+                },
             )
         except Exception as e:
-            transcript_text = f"No captions available for this video. Error: {str(e)}"
+            transcript_text = f"No captions available for this video. Error: {e!s}"
             await task_logger.log_task_progress(
                 log_entry,
                 f"No transcript available for video: {video_id}",
-                {"stage": "transcript_unavailable", "error": str(e)}
+                {"stage": "transcript_unavailable", "error": str(e)},
             )

         # Format document
         await task_logger.log_task_progress(
             log_entry,
             f"Processing video content: {video_data.get('title', 'YouTube Video')}",
-            {"stage": "content_processing"}
+            {"stage": "content_processing"},
         )

         # Format document metadata in a more maintainable way
@@ -890,7 +941,7 @@ async def add_youtube_video_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Checking for duplicate video content: {video_id}",
-            {"stage": "duplicate_check", "content_hash": content_hash}
+            {"stage": "duplicate_check", "content_hash": content_hash},
         )

         # Check if document with this content hash already exists
@@ -903,16 +954,22 @@ async def add_youtube_video_document(
             await task_logger.log_task_success(
                 log_entry,
                 f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}",
-                {"duplicate_detected": True, "existing_document_id": existing_document.id, "video_id": video_id}
+                {
+                    "duplicate_detected": True,
+                    "existing_document_id": existing_document.id,
+                    "video_id": video_id,
+                },
             )
-            logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
+            logging.info(
+                f"Document with content hash {content_hash} already exists. Skipping processing."
+            )
             return existing_document

         # Get LLM for summary generation
         await task_logger.log_task_progress(
             log_entry,
             f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}",
-            {"stage": "llm_setup"}
+            {"stage": "llm_setup"},
         )

         # Get user's long context LLM
@@ -924,7 +981,7 @@ async def add_youtube_video_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Generating summary for video: {video_data.get('title', 'YouTube Video')}",
-            {"stage": "summary_generation"}
+            {"stage": "summary_generation"},
         )

         summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
@@ -938,7 +995,7 @@ async def add_youtube_video_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}",
-            {"stage": "chunk_processing"}
+            {"stage": "chunk_processing"},
         )

         chunks = [
@@ -953,7 +1010,7 @@ async def add_youtube_video_document(
         await task_logger.log_task_progress(
             log_entry,
             f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}",
-            {"stage": "document_creation", "chunks_count": len(chunks)}
+            {"stage": "document_creation", "chunks_count": len(chunks)},
         )

         document = Document(
@@ -988,8 +1045,8 @@ async def add_youtube_video_document(
                 "content_hash": content_hash,
                 "chunks_count": len(chunks),
                 "summary_length": len(summary_content),
-                "has_transcript": "No captions available" not in transcript_text
-            }
+                "has_transcript": "No captions available" not in transcript_text,
+            },
         )

         return document
@@ -999,7 +1056,10 @@ async def add_youtube_video_document(
             log_entry,
             f"Database error while processing YouTube video: {url}",
             str(db_error),
-            {"error_type": "SQLAlchemyError", "video_id": video_id if 'video_id' in locals() else None}
+            {
+                "error_type": "SQLAlchemyError",
+                "video_id": video_id if "video_id" in locals() else None,
+            },
         )
         raise db_error
     except Exception as e:
@@ -1008,7 +1068,10 @@ async def add_youtube_video_document(
             log_entry,
             f"Failed to process YouTube video: {url}",
             str(e),
-            {"error_type": type(e).__name__, "video_id": video_id if 'video_id' in locals() else None}
+            {
+                "error_type": type(e).__name__,
+                "video_id": video_id if "video_id" in locals() else None,
+            },
         )
-        logging.error(f"Failed to process YouTube video: {str(e)}")
+        logging.error(f"Failed to process YouTube video: {e!s}")
         raise
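Two non-cosmetic changes in this file are worth calling out: `type(crawl_loader) == FireCrawlLoader` becomes `isinstance(crawl_loader, FireCrawlLoader)` (the fix pycodestyle's E721 asks for, which also matches subclasses), and the nested `async with` blocks in the oEmbed fetch are merged into one parenthesized statement, a form Python supports since 3.10. A small sketch with stand-in classes and context managers (not the real langchain or aiohttp types):

import asyncio
from contextlib import asynccontextmanager

class BaseLoader: ...
class FireCrawlLoader(BaseLoader): ...
class StealthFireCrawlLoader(FireCrawlLoader): ...

loader = StealthFireCrawlLoader()
print(type(loader) == FireCrawlLoader)      # False: exact-type check misses subclasses
print(isinstance(loader, FireCrawlLoader))  # True: isinstance follows the class hierarchy

@asynccontextmanager
async def resource(name: str):
    yield name

async def main() -> None:
    # One parenthesized async with instead of two nested ones (Python 3.10+).
    async with (
        resource("session") as session,
        resource(f"response via {session}") as response,
    ):
        print(response)

asyncio.run(main())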
View file
@@ -1,7 +1,11 @@
 import asyncio
 import logging
-from datetime import datetime, timedelta, timezone
-from typing import Optional, Tuple
+from datetime import UTC, datetime, timedelta
+
+from slack_sdk.errors import SlackApiError
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select

 from app.config import config
 from app.connectors.discord_connector import DiscordConnector
@@ -21,10 +25,6 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
 from app.services.llm_service import get_user_long_context_llm
 from app.services.task_logging_service import TaskLoggingService
 from app.utils.document_converters import generate_content_hash
-from slack_sdk.errors import SlackApiError
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select

 # Set up logging
 logger = logging.getLogger(__name__)
@@ -35,10 +35,10 @@ async def index_slack_messages(
     connector_id: int,
     search_space_id: int,
     user_id: str,
-    start_date: str = None,
-    end_date: str = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
     update_last_indexed: bool = True,
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, str | None]:
     """
     Index Slack messages from all accessible channels.
@@ -192,7 +192,7 @@ async def index_slack_messages(
             str(e),
             {"error_type": "ChannelFetchError"},
         )
-        return 0, f"Failed to get Slack channels: {str(e)}"
+        return 0, f"Failed to get Slack channels: {e!s}"

     if not channels:
         await task_logger.log_task_success(
@@ -400,13 +400,13 @@ async def index_slack_messages(
             except SlackApiError as slack_error:
                 logger.error(
-                    f"Slack API error for channel {channel_name}: {str(slack_error)}"
+                    f"Slack API error for channel {channel_name}: {slack_error!s}"
                 )
                 skipped_channels.append(f"{channel_name} (Slack API error)")
                 documents_skipped += 1
                 continue  # Skip this channel and continue with others
             except Exception as e:
-                logger.error(f"Error processing channel {channel_name}: {str(e)}")
+                logger.error(f"Error processing channel {channel_name}: {e!s}")
                 skipped_channels.append(f"{channel_name} (processing error)")
                 documents_skipped += 1
                 continue  # Skip this channel and continue with others
@@ -453,8 +453,8 @@ async def index_slack_messages(
             str(db_error),
             {"error_type": "SQLAlchemyError"},
         )
-        logger.error(f"Database error: {str(db_error)}")
-        return 0, f"Database error: {str(db_error)}"
+        logger.error(f"Database error: {db_error!s}")
+        return 0, f"Database error: {db_error!s}"
     except Exception as e:
         await session.rollback()
         await task_logger.log_task_failure(
@@ -463,8 +463,8 @@ async def index_slack_messages(
             str(e),
             {"error_type": type(e).__name__},
         )
-        logger.error(f"Failed to index Slack messages: {str(e)}")
-        return 0, f"Failed to index Slack messages: {str(e)}"
+        logger.error(f"Failed to index Slack messages: {e!s}")
+        return 0, f"Failed to index Slack messages: {e!s}"


 async def index_notion_pages(
@@ -472,10 +472,10 @@ async def index_notion_pages(
     connector_id: int,
     search_space_id: int,
     user_id: str,
-    start_date: str = None,
-    end_date: str = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
     update_last_indexed: bool = True,
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, str | None]:
     """
     Index Notion pages from all accessible pages.
@@ -611,8 +611,8 @@ async def index_notion_pages(
             str(e),
             {"error_type": "PageFetchError"},
         )
-        logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
-        return 0, f"Failed to get Notion pages: {str(e)}"
+        logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True)
+        return 0, f"Failed to get Notion pages: {e!s}"

     if not pages:
         await task_logger.log_task_success(
@@ -799,7 +799,7 @@ async def index_notion_pages(
             except Exception as e:
                 logger.error(
-                    f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}",
+                    f"Error processing Notion page {page.get('title', 'Unknown')}: {e!s}",
                     exc_info=True,
                 )
                 skipped_pages.append(
@@ -852,9 +852,9 @@ async def index_notion_pages(
             {"error_type": "SQLAlchemyError"},
         )
         logger.error(
-            f"Database error during Notion indexing: {str(db_error)}", exc_info=True
+            f"Database error during Notion indexing: {db_error!s}", exc_info=True
         )
-        return 0, f"Database error: {str(db_error)}"
+        return 0, f"Database error: {db_error!s}"
     except Exception as e:
         await session.rollback()
         await task_logger.log_task_failure(
@@ -863,8 +863,8 @@ async def index_notion_pages(
             str(e),
             {"error_type": type(e).__name__},
         )
-        logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
-        return 0, f"Failed to index Notion pages: {str(e)}"
+        logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
+        return 0, f"Failed to index Notion pages: {e!s}"


 async def index_github_repos(
@@ -872,10 +872,10 @@ async def index_github_repos(
     connector_id: int,
     search_space_id: int,
     user_id: str,
-    start_date: str = None,
-    end_date: str = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
     update_last_indexed: bool = True,
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, str | None]:
     """
     Index code and documentation files from accessible GitHub repositories.
@@ -978,7 +978,7 @@ async def index_github_repos(
             str(e),
             {"error_type": "ClientInitializationError"},
         )
-        return 0, f"Failed to initialize GitHub client: {str(e)}"
+        return 0, f"Failed to initialize GitHub client: {e!s}"

     # 4. Validate selected repositories
     # For simplicity, we'll proceed with the list provided.
@@ -1097,7 +1097,7 @@ async def index_github_repos(
                         "url": file_url,
                         "sha": file_sha,
                         "type": file_type,
-                        "indexed_at": datetime.now(timezone.utc).isoformat(),
+                        "indexed_at": datetime.now(UTC).isoformat(),
                     }

                     # Create new document
@@ -1175,10 +1175,10 @@ async def index_linear_issues(
     connector_id: int,
     search_space_id: int,
     user_id: str,
-    start_date: str = None,
-    end_date: str = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
     update_last_indexed: bool = True,
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, str | None]:
     """
     Index Linear issues and comments.
@@ -1339,8 +1339,8 @@ async def index_linear_issues(
         logger.info(f"Retrieved {len(issues)} issues from Linear API")
     except Exception as e:
-        logger.error(f"Exception when calling Linear API: {str(e)}", exc_info=True)
-        return 0, f"Failed to get Linear issues: {str(e)}"
+        logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
+        return 0, f"Failed to get Linear issues: {e!s}"

     if not issues:
         logger.info("No Linear issues found for the specified date range")
@@ -1481,7 +1481,7 @@ async def index_linear_issues(
             except Exception as e:
                 logger.error(
-                    f"Error processing issue {issue.get('identifier', 'Unknown')}: {str(e)}",
+                    f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}",
                     exc_info=True,
                 )
                 skipped_issues.append(
@@ -1528,8 +1528,8 @@ async def index_linear_issues(
             str(db_error),
             {"error_type": "SQLAlchemyError"},
         )
-        logger.error(f"Database error: {str(db_error)}", exc_info=True)
-        return 0, f"Database error: {str(db_error)}"
+        logger.error(f"Database error: {db_error!s}", exc_info=True)
+        return 0, f"Database error: {db_error!s}"
     except Exception as e:
         await session.rollback()
         await task_logger.log_task_failure(
@@ -1538,8 +1538,8 @@ async def index_linear_issues(
             str(e),
             {"error_type": type(e).__name__},
         )
-        logger.error(f"Failed to index Linear issues: {str(e)}", exc_info=True)
-        return 0, f"Failed to index Linear issues: {str(e)}"
+        logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
+        return 0, f"Failed to index Linear issues: {e!s}"


 async def index_discord_messages(
@@ -1547,10 +1547,10 @@ async def index_discord_messages(
     connector_id: int,
     search_space_id: int,
     user_id: str,
-    start_date: str = None,
-    end_date: str = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
     update_last_indexed: bool = True,
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, str | None]:
     """
     Index Discord messages from all accessible channels.
@@ -1632,13 +1632,11 @@ async def index_discord_messages(
         # Calculate date range
         if start_date is None or end_date is None:
             # Fall back to calculating dates based on last_indexed_at
-            calculated_end_date = datetime.now(timezone.utc)
+            calculated_end_date = datetime.now(UTC)

             # Use last_indexed_at as start date if available, otherwise use 365 days ago
             if connector.last_indexed_at:
-                calculated_start_date = connector.last_indexed_at.replace(
-                    tzinfo=timezone.utc
-                )
+                calculated_start_date = connector.last_indexed_at.replace(tzinfo=UTC)
                 logger.info(
                     f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
                 )
@@ -1655,7 +1653,7 @@ async def index_discord_messages(
                 # Convert YYYY-MM-DD to ISO format
                 start_date_iso = (
                     datetime.strptime(start_date, "%Y-%m-%d")
-                    .replace(tzinfo=timezone.utc)
+                    .replace(tzinfo=UTC)
                     .isoformat()
                 )
@@ -1665,20 +1663,18 @@ async def index_discord_messages(
                 # Convert YYYY-MM-DD to ISO format
                 end_date_iso = (
                     datetime.strptime(end_date, "%Y-%m-%d")
-                    .replace(tzinfo=timezone.utc)
+                    .replace(tzinfo=UTC)
                     .isoformat()
                 )
         else:
             # Convert provided dates to ISO format for Discord API
             start_date_iso = (
                 datetime.strptime(start_date, "%Y-%m-%d")
-                .replace(tzinfo=timezone.utc)
+                .replace(tzinfo=UTC)
                 .isoformat()
             )
             end_date_iso = (
-                datetime.strptime(end_date, "%Y-%m-%d")
-                .replace(tzinfo=timezone.utc)
-                .isoformat()
+                datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat()
             )

         logger.info(
@@ -1710,9 +1706,9 @@ async def index_discord_messages(
             str(e),
             {"error_type": "GuildFetchError"},
         )
-        logger.error(f"Failed to get Discord guilds: {str(e)}", exc_info=True)
+        logger.error(f"Failed to get Discord guilds: {e!s}", exc_info=True)
         await discord_client.close_bot()
-        return 0, f"Failed to get Discord guilds: {str(e)}"
+        return 0, f"Failed to get Discord guilds: {e!s}"
     if not guilds:
         await task_logger.log_task_success(
             log_entry,
@@ -1754,7 +1750,7 @@ async def index_discord_messages(
                     )
                 except Exception as e:
                     logger.error(
-                        f"Failed to get messages for channel {channel_name}: {str(e)}"
+                        f"Failed to get messages for channel {channel_name}: {e!s}"
                     )
                     skipped_channels.append(
                         f"{guild_name}#{channel_name} (fetch error)"
@@ -1886,7 +1882,9 @@ async def index_discord_messages(
                         chunks = [
                             Chunk(content=raw_chunk.text, embedding=embedding)
-                            for raw_chunk, embedding in zip(raw_chunks, chunk_embeddings)
+                            for raw_chunk, embedding in zip(
+                                raw_chunks, chunk_embeddings, strict=False
+                            )
                         ]

                         # Create and store new document
@@ -1902,7 +1900,7 @@ async def index_discord_messages(
                                 "message_count": len(formatted_messages),
                                 "start_date": start_date_iso,
                                 "end_date": end_date_iso,
-                                "indexed_at": datetime.now(timezone.utc).strftime(
+                                "indexed_at": datetime.now(UTC).strftime(
                                     "%Y-%m-%d %H:%M:%S"
                                 ),
                             },
@@ -1920,14 +1918,14 @@ async def index_discord_messages(
             except Exception as e:
                 logger.error(
-                    f"Error processing guild {guild_name}: {str(e)}", exc_info=True
+                    f"Error processing guild {guild_name}: {e!s}", exc_info=True
                 )
                 skipped_channels.append(f"{guild_name} (processing error)")
                 documents_skipped += 1
                 continue

         if update_last_indexed and documents_indexed > 0:
-            connector.last_indexed_at = datetime.now(timezone.utc)
+            connector.last_indexed_at = datetime.now(UTC)
             logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")

         await session.commit()
@@ -1968,9 +1966,9 @@ async def index_discord_messages(
             {"error_type": "SQLAlchemyError"},
         )
         logger.error(
-            f"Database error during Discord indexing: {str(db_error)}", exc_info=True
+            f"Database error during Discord indexing: {db_error!s}", exc_info=True
         )
-        return 0, f"Database error: {str(db_error)}"
+        return 0, f"Database error: {db_error!s}"
     except Exception as e:
         await session.rollback()
         await task_logger.log_task_failure(
@@ -1979,8 +1977,8 @@ async def index_discord_messages(
             str(e),
             {"error_type": type(e).__name__},
         )
-        logger.error(f"Failed to index Discord messages: {str(e)}", exc_info=True)
-        return 0, f"Failed to index Discord messages: {str(e)}"
+        logger.error(f"Failed to index Discord messages: {e!s}", exc_info=True)
+        return 0, f"Failed to index Discord messages: {e!s}"


 async def index_jira_issues(
@@ -1988,10 +1986,10 @@ async def index_jira_issues(
     connector_id: int,
     search_space_id: int,
     user_id: str,
-    start_date: str = None,
-    end_date: str = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
     update_last_indexed: bool = True,
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, str | None]:
     """
     Index Jira issues and comments.
@@ -2161,8 +2159,8 @@ async def index_jira_issues(
         logger.info(f"Retrieved {len(issues)} issues from Jira API")
     except Exception as e:
-        logger.error(f"Error fetching Jira issues: {str(e)}", exc_info=True)
-        return 0, f"Error fetching Jira issues: {str(e)}"
+        logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
+        return 0, f"Error fetching Jira issues: {e!s}"

     # Process and index each issue
     documents_indexed = 0
@@ -2272,7 +2270,7 @@ async def index_jira_issues(
             except Exception as e:
                 logger.error(
-                    f"Error processing issue {issue.get('identifier', 'Unknown')}: {str(e)}",
+                    f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}",
                     exc_info=True,
                 )
                 skipped_issues.append(
@@ -2319,8 +2317,8 @@ async def index_jira_issues(
             str(db_error),
             {"error_type": "SQLAlchemyError"},
         )
-        logger.error(f"Database error: {str(db_error)}", exc_info=True)
-        return 0, f"Database error: {str(db_error)}"
+        logger.error(f"Database error: {db_error!s}", exc_info=True)
+        return 0, f"Database error: {db_error!s}"
     except Exception as e:
         await session.rollback()
         await task_logger.log_task_failure(
@@ -2329,5 +2327,5 @@ async def index_jira_issues(
             str(e),
             {"error_type": type(e).__name__},
         )
-        logger.error(f"Failed to index JIRA issues: {str(e)}", exc_info=True)
-        return 0, f"Failed to index JIRA issues: {str(e)}"
+        logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
+        return 0, f"Failed to index JIRA issues: {e!s}"
View file
@@ -1,30 +1,26 @@
+from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncSession
+
 from app.agents.podcaster.graph import graph as podcaster_graph
 from app.agents.podcaster.state import State
 from app.db import Chat, Podcast
 from app.services.task_logging_service import TaskLoggingService
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.exc import SQLAlchemyError


 async def generate_document_podcast(
-    session: AsyncSession,
-    document_id: int,
-    search_space_id: int,
-    user_id: int
+    session: AsyncSession, document_id: int, search_space_id: int, user_id: int
 ):
     # TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
     pass


 async def generate_chat_podcast(
     session: AsyncSession,
     chat_id: int,
     search_space_id: int,
     podcast_title: str,
-    user_id: int
+    user_id: int,
 ):
     task_logger = TaskLoggingService(session, search_space_id)
@@ -37,21 +33,18 @@ async def generate_chat_podcast(
             "chat_id": chat_id,
             "search_space_id": search_space_id,
             "podcast_title": podcast_title,
-            "user_id": str(user_id)
-        }
+            "user_id": str(user_id),
+        },
     )

     try:
         # Fetch the chat with the specified ID
         await task_logger.log_task_progress(
-            log_entry,
-            f"Fetching chat {chat_id} from database",
-            {"stage": "fetch_chat"}
+            log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
         )

         query = select(Chat).filter(
-            Chat.id == chat_id,
-            Chat.search_space_id == search_space_id
+            Chat.id == chat_id, Chat.search_space_id == search_space_id
         )

         result = await session.execute(query)
@@ -62,15 +55,17 @@ async def generate_chat_podcast(
                 log_entry,
                 f"Chat with id {chat_id} not found in search space {search_space_id}",
                 "Chat not found",
-                {"error_type": "ChatNotFound"}
+                {"error_type": "ChatNotFound"},
             )
-            raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}")
+            raise ValueError(
+                f"Chat with id {chat_id} not found in search space {search_space_id}"
+            )

         # Create chat history structure
         await task_logger.log_task_progress(
             log_entry,
             f"Processing chat history for chat {chat_id}",
-            {"stage": "process_chat_history", "message_count": len(chat.messages)}
+            {"stage": "process_chat_history", "message_count": len(chat.messages)},
         )

         chat_history_str = "<chat_history>"
@@ -89,7 +84,9 @@ async def generate_chat_podcast(
                 # If content is a list, join it into a single string
                 if isinstance(answer_text, list):
                     answer_text = "\n".join(answer_text)
-                chat_history_str += f"<assistant_message>{answer_text}</assistant_message>"
+                chat_history_str += (
+                    f"<assistant_message>{answer_text}</assistant_message>"
+                )
                 processed_messages += 1

         chat_history_str += "</chat_history>"
@@ -98,7 +95,11 @@ async def generate_chat_podcast(
         await task_logger.log_task_progress(
             log_entry,
             f"Initializing podcast generation for chat {chat_id}",
-            {"stage": "initialize_podcast_generation", "processed_messages": processed_messages, "content_length": len(chat_history_str)}
+            {
+                "stage": "initialize_podcast_generation",
+                "processed_messages": processed_messages,
+                "content_length": len(chat_history_str),
+            },
         )

         config = {
@@ -108,16 +109,13 @@ async def generate_chat_podcast(
             }
         }

         # Initialize state with database session and streaming service
-        initial_state = State(
-            source_content=chat_history_str,
-            db_session=session
-        )
+        initial_state = State(source_content=chat_history_str, db_session=session)

         # Run the graph directly
         await task_logger.log_task_progress(
             log_entry,
             f"Running podcast generation graph for chat {chat_id}",
-            {"stage": "run_podcast_graph"}
+            {"stage": "run_podcast_graph"},
         )

         result = await podcaster_graph.ainvoke(initial_state, config=config)
@@ -126,28 +124,33 @@ async def generate_chat_podcast(
         await task_logger.log_task_progress(
             log_entry,
             f"Processing podcast transcript for chat {chat_id}",
-            {"stage": "process_transcript", "transcript_entries": len(result["podcast_transcript"])}
+            {
+                "stage": "process_transcript",
+                "transcript_entries": len(result["podcast_transcript"]),
+            },
         )

         serializable_transcript = []
         for entry in result["podcast_transcript"]:
-            serializable_transcript.append({
-                "speaker_id": entry.speaker_id,
-                "dialog": entry.dialog
-            })
+            serializable_transcript.append(
+                {"speaker_id": entry.speaker_id, "dialog": entry.dialog}
+            )

         # Create a new podcast entry
         await task_logger.log_task_progress(
             log_entry,
             f"Creating podcast database entry for chat {chat_id}",
-            {"stage": "create_podcast_entry", "file_location": result.get("final_podcast_file_path")}
+            {
+                "stage": "create_podcast_entry",
+                "file_location": result.get("final_podcast_file_path"),
+            },
         )

         podcast = Podcast(
             title=f"{podcast_title}",
             podcast_transcript=serializable_transcript,
             file_location=result["final_podcast_file_path"],
-            search_space_id=search_space_id
+            search_space_id=search_space_id,
         )

         # Add to session and commit
@@ -165,8 +168,8 @@ async def generate_chat_podcast(
                 "transcript_entries": len(serializable_transcript),
                 "file_location": result.get("final_podcast_file_path"),
                 "processed_messages": processed_messages,
-                "content_length": len(chat_history_str)
-            }
+                "content_length": len(chat_history_str),
+            },
         )

         return podcast
@@ -178,7 +181,7 @@ async def generate_chat_podcast(
             log_entry,
             f"Value error during podcast generation for chat {chat_id}",
             str(ve),
-            {"error_type": "ValueError"}
+            {"error_type": "ValueError"},
         )
         raise ve
     except SQLAlchemyError as db_error:
@@ -187,7 +190,7 @@ async def generate_chat_podcast(
             log_entry,
             f"Database error during podcast generation for chat {chat_id}",
             str(db_error),
-            {"error_type": "SQLAlchemyError"}
+            {"error_type": "SQLAlchemyError"},
         )
         raise db_error
     except Exception as e:
@@ -196,7 +199,8 @@ async def generate_chat_podcast(
             log_entry,
             f"Unexpected error during podcast generation for chat {chat_id}",
             str(e),
-            {"error_type": type(e).__name__}
+            {"error_type": type(e).__name__},
         )
-        raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {str(e)}")
+        raise RuntimeError(
+            f"Failed to generate podcast for chat {chat_id}: {e!s}"
+        ) from e
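The recurring `raise RuntimeError(f"... {e!s}") from e` pattern in these hunks covers two rules at once: flake8-bugbear's B904 wants re-raises inside an `except` block to chain explicitly, and `!s` is the explicit str() conversion flag in f-strings, so `{e!s}` renders the same text as `{str(e)}`. A toy example of what the chaining preserves:

def generate_podcast(chat_id: int) -> None:
    try:
        raise ValueError(f"Chat with id {chat_id} not found")
    except Exception as e:
        # "from e" stores the original exception as __cause__, so the traceback
        # reads "The above exception was the direct cause of the following ...".
        raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {e!s}") from e

try:
    generate_podcast(42)
except RuntimeError as err:
    assert isinstance(err.__cause__, ValueError)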
View file
@@ -1,24 +1,25 @@
-from typing import Any, AsyncGenerator, List, Union
+from collections.abc import AsyncGenerator
+from typing import Any
 from uuid import UUID

-from app.agents.researcher.graph import graph as researcher_graph
-from app.agents.researcher.state import State
-from app.services.streaming_service import StreamingService
 from sqlalchemy.ext.asyncio import AsyncSession

 from app.agents.researcher.configuration import SearchMode
+from app.agents.researcher.graph import graph as researcher_graph
+from app.agents.researcher.state import State
+from app.services.streaming_service import StreamingService


 async def stream_connector_search_results(
     user_query: str,
-    user_id: Union[str, UUID],
+    user_id: str | UUID,
     search_space_id: int,
     session: AsyncSession,
     research_mode: str,
-    selected_connectors: List[str],
-    langchain_chat_history: List[Any],
+    selected_connectors: list[str],
+    langchain_chat_history: list[Any],
     search_mode_str: str,
-    document_ids_to_add_in_context: List[int]
+    document_ids_to_add_in_context: list[int],
 ) -> AsyncGenerator[str, None]:
     """
     Stream connector search results to the client
@@ -37,14 +38,14 @@ async def stream_connector_search_results(
     streaming_service = StreamingService()

     if research_mode == "REPORT_GENERAL":
-        NUM_SECTIONS = 1
+        num_sections = 1
     elif research_mode == "REPORT_DEEP":
-        NUM_SECTIONS = 3
+        num_sections = 3
     elif research_mode == "REPORT_DEEPER":
-        NUM_SECTIONS = 6
+        num_sections = 6
     else:
         # Default fallback
-        NUM_SECTIONS = 1
+        num_sections = 1

     # Convert UUID to string if needed
     user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
@@ -58,20 +59,20 @@ async def stream_connector_search_results(
     config = {
         "configurable": {
             "user_query": user_query,
-            "num_sections": NUM_SECTIONS,
+            "num_sections": num_sections,
             "connectors_to_search": selected_connectors,
             "user_id": user_id_str,
             "search_space_id": search_space_id,
             "search_mode": search_mode,
             "research_mode": research_mode,
-            "document_ids_to_add_in_context": document_ids_to_add_in_context
+            "document_ids_to_add_in_context": document_ids_to_add_in_context,
         }
     }

     # Initialize state with database session and streaming service
     initial_state = State(
         db_session=session,
         streaming_service=streaming_service,
-        chat_history=langchain_chat_history
+        chat_history=langchain_chat_history,
     )

     # Run the graph directly
@@ -83,8 +84,7 @@ async def stream_connector_search_results(
         config=config,
         stream_mode="custom",
     ):
-        if isinstance(chunk, dict):
-            if "yield_value" in chunk:
-                yield chunk["yield_value"]
+        if isinstance(chunk, dict) and "yield_value" in chunk:
+            yield chunk["yield_value"]

     yield streaming_service.format_completion()

View file

@@ -1,8 +1,7 @@
-from typing import Optional
 import uuid
 from fastapi import Depends, Request, Response
-from fastapi.responses import RedirectResponse
+from fastapi.responses import JSONResponse, RedirectResponse
 from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
 from fastapi_users.authentication import (
     AuthenticationBackend,
@@ -10,16 +9,18 @@ from fastapi_users.authentication import (
     JWTStrategy,
 )
 from fastapi_users.db import SQLAlchemyUserDatabase
-from fastapi.responses import JSONResponse
 from fastapi_users.schemas import model_dump
+from pydantic import BaseModel
 from app.config import config
 from app.db import User, get_user_db
-from pydantic import BaseModel

 class BearerResponse(BaseModel):
     access_token: str
     token_type: str

 SECRET = config.SECRET_KEY

 if config.AUTH_TYPE == "GOOGLE":
@@ -35,19 +36,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
     reset_password_token_secret = SECRET
     verification_token_secret = SECRET

-    async def on_after_register(self, user: User, request: Optional[Request] = None):
+    async def on_after_register(self, user: User, request: Request | None = None):
         print(f"User {user.id} has registered.")

     async def on_after_forgot_password(
-        self, user: User, token: str, request: Optional[Request] = None
+        self, user: User, token: str, request: Request | None = None
     ):
         print(f"User {user.id} has forgot their password. Reset token: {token}")

     async def on_after_request_verify(
-        self, user: User, token: str, request: Optional[Request] = None
+        self, user: User, token: str, request: Request | None = None
     ):
-        print(
-            f"Verification requested for user {user.id}. Verification token: {token}")
+        print(f"Verification requested for user {user.id}. Verification token: {token}")

 async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
@@ -55,7 +55,7 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db

 def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
-    return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24)
+    return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)

 # # COOKIE AUTH | Uncomment if you want to use cookie auth.
@@ -77,6 +77,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
 # get_strategy=get_jwt_strategy,
 # )

 # BEARER AUTH CODE.
 class CustomBearerTransport(BearerTransport):
     async def get_login_response(self, token: str) -> Response:
@@ -87,6 +88,7 @@ class CustomBearerTransport(BearerTransport):
         else:
             return JSONResponse(model_dump(bearer_response))

 bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")

View file

@@ -1,12 +1,19 @@
 from fastapi import HTTPException
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from app.db import User

 # Helper function to check user ownership
 async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
-    item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
+    item = await session.execute(
+        select(model).filter(model.id == item_id, model.user_id == user.id)
+    )
     item = item.scalars().first()
     if not item:
-        raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
+        raise HTTPException(
+            status_code=404,
+            detail="Item not found or you don't have permission to access it",
+        )
     return item

View file

@@ -32,7 +32,7 @@ async def convert_element_to_markdown(element) -> str:
         "Footer": lambda x: f"*{x}*\n\n",
         "CodeSnippet": lambda x: f"```\n{x}\n```",
         "PageNumber": lambda x: f"*Page {x}*\n\n",
-        "UncategorizedText": lambda x: f"{x}\n\n"
+        "UncategorizedText": lambda x: f"{x}\n\n",
     }
     converter = markdown_mapping.get(element_category, lambda x: x)
@@ -74,7 +74,7 @@ def convert_chunks_to_langchain_documents(chunks):
     except ImportError:
         raise ImportError(
             "LangChain is not installed. Please install it with `pip install langchain langchain-core`"
-        )
+        ) from None

     langchain_docs = []
@@ -92,17 +92,20 @@ def convert_chunks_to_langchain_documents(chunks):
         # Add document information to metadata
         if "document" in chunk:
             doc = chunk["document"]
-            metadata.update({
-                "document_id": doc.get("id"),
-                "document_title": doc.get("title"),
-                "document_type": doc.get("document_type"),
-            })
+            metadata.update(
+                {
+                    "document_id": doc.get("id"),
+                    "document_title": doc.get("title"),
+                    "document_type": doc.get("document_type"),
+                }
+            )

             # Add document metadata if available
             if "metadata" in doc:
                 # Prefix document metadata keys to avoid conflicts
-                doc_metadata = {f"doc_meta_{k}": v for k,
-                                v in doc.get("metadata", {}).items()}
+                doc_metadata = {
+                    f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()
+                }
                 metadata.update(doc_metadata)

         # Add source URL if available in metadata
@@ -131,10 +134,7 @@ def convert_chunks_to_langchain_documents(chunks):
         """
         # Create LangChain Document
-        langchain_doc = LangChainDocument(
-            page_content=new_content,
-            metadata=metadata
-        )
+        langchain_doc = LangChainDocument(page_content=new_content, metadata=metadata)

         langchain_docs.append(langchain_doc)
@@ -144,4 +144,4 @@ def convert_chunks_to_langchain_documents(chunks):
 def generate_content_hash(content: str, search_space_id: int) -> str:
     """Generate SHA-256 hash for the given content combined with search space ID."""
     combined_data = f"{search_space_id}:{content}"
-    return hashlib.sha256(combined_data.encode('utf-8')).hexdigest()
+    return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()

View file

@@ -1,7 +1,9 @@
-import uvicorn
 import argparse
 import logging
+import uvicorn
 from dotenv import load_dotenv
 from app.config.uvicorn import load_uvicorn_config

 logging.basicConfig(

View file

@@ -36,3 +36,97 @@ dependencies = [
    "validators>=0.34.0",
    "youtube-transcript-api>=1.0.3",
]
[dependency-groups]
dev = [
"ruff>=0.12.5",
]
[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
line-length = 88
indent-width = 4
# Python 3.12
target-version = "py312"
[tool.ruff.lint]
select = [
"E4", # pycodestyle errors
"E7", # pycodestyle errors
"E9", # pycodestyle errors
"F", # Pyflakes
"I", # isort
"N", # pep8-naming
"UP", # pyupgrade
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"T20", # flake8-print
"SIM", # flake8-simplify
"RUF", # Ruff-specific rules
]
ignore = [
"E501", # Line too long (handled by formatter)
"B008", # Do not perform function calls in argument defaults
"T201", # Print found (allow print statements)
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar`
]
extend-select = ["I"]
# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = []
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.format]
# Use double quotes for strings.
quote-style = "double"
# Indent with spaces, rather than tabs.
indent-style = "space"
# Respect magic trailing commas.
skip-magic-trailing-comma = false
# Automatically detect the appropriate line ending.
line-ending = "auto"
[tool.ruff.lint.isort]
# Group imports by type
known-first-party = ["app"]
force-single-line = false
combine-as-imports = true

4507
surfsense_backend/uv.lock generated

File diff suppressed because it is too large