Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-09 07:42:39 +02:00)
Commit 617a7a34b5: 92 changed files with 6274 additions and 5163 deletions

224  .github/workflows/code-quality.yml (vendored, new file)
@@ -0,0 +1,224 @@
name: Code Quality Checks

on:
  pull_request:
    branches: [main, dev]
    types: [opened, synchronize, reopened, ready_for_review]

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  file-quality:
    name: File Quality Checks
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Fetch base branch
        run: |
          # Ensure we have the base branch reference for comparison
          git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install pre-commit
        run: pip install pre-commit

      - name: Cache pre-commit hooks
        uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
          restore-keys: |
            pre-commit-

      - name: Install hook environments (cache)
        run: pre-commit install-hooks

      - name: Run file quality checks on changed files
        run: |
          # Get list of changed files and run specific hooks on them
          if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
            BASE_REF="${{ github.base_ref }}"
          elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
            BASE_REF="origin/${{ github.base_ref }}"
          else
            echo "Base branch reference not found, running file quality hooks on all files"
            pre-commit run --all-files check-yaml check-json check-toml check-merge-conflict check-added-large-files debug-statements check-case-conflict
            exit 0
          fi

          echo "Running file quality hooks on changed files against $BASE_REF"

          # Run each hook individually on changed files
          SKIP=detect-secrets,bandit,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
          pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?

          # Exit with the same code as pre-commit
          exit ${exit_code:-0}

  security-scan:
    name: Security Scan
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Fetch base branch
        run: |
          git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install pre-commit
        run: pip install pre-commit

      - name: Cache pre-commit hooks
        uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
          restore-keys: |
            pre-commit-security-

      - name: Install hook environments (cache)
        run: pre-commit install-hooks

      - name: Run security scans on changed files
        run: |
          # Get base ref for comparison
          if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
            BASE_REF="${{ github.base_ref }}"
          elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
            BASE_REF="origin/${{ github.base_ref }}"
          else
            echo "Base branch reference not found, running security scans on all files"
            echo "⚠️ This may take longer than normal"
            pre-commit run --all-files detect-secrets bandit
            exit 0
          fi

          echo "Running security scans on changed files against $BASE_REF"

          # Run only security hooks on changed files
          SKIP=check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
          pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?

          # Exit with the same code as pre-commit
          exit ${exit_code:-0}

  python-backend:
    name: Python Backend Quality
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install UV
        uses: astral-sh/setup-uv@v3

      - name: Check if backend files changed
        id: backend-changes
        uses: dorny/paths-filter@v3
        with:
          filters: |
            backend:
              - 'surfsense_backend/**'

      - name: Cache dependencies
        if: steps.backend-changes.outputs.backend == 'true'
        uses: actions/cache@v4
        with:
          path: |
            ~/.cache/uv
            surfsense_backend/.venv
          key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}

      - name: Install dependencies
        if: steps.backend-changes.outputs.backend == 'true'
        working-directory: surfsense_backend
        run: uv sync

      - name: Install pre-commit for backend checks
        if: steps.backend-changes.outputs.backend == 'true'
        run: pip install pre-commit

      - name: Cache pre-commit hooks
        if: steps.backend-changes.outputs.backend == 'true'
        uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
          restore-keys: |
            pre-commit-backend-

      - name: Install hook environments (cache)
        if: steps.backend-changes.outputs.backend == 'true'
        run: pre-commit install-hooks

      - name: Run Python backend quality checks
        if: steps.backend-changes.outputs.backend == 'true'
        run: |
          # Get base ref for comparison
          if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
            BASE_REF="${{ github.base_ref }}"
          elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
            BASE_REF="origin/${{ github.base_ref }}"
          else
            echo "Base branch reference not found, running Python backend checks on all files"
            pre-commit run --all-files ruff ruff-format
            exit 0
          fi

          echo "Running Python backend checks on changed files against $BASE_REF"

          # Run only ruff hooks on changed Python files
          SKIP=detect-secrets,bandit,check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
          pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?

          # Exit with the same code as pre-commit
          exit ${exit_code:-0}

  quality-gate:
    name: Quality Gate
    runs-on: ubuntu-latest
    needs: [file-quality, security-scan, python-backend]
    if: always()

    steps:
      - name: Check all jobs status
        run: |
          if [[ "${{ needs.file-quality.result }}" == "failure" ||
                "${{ needs.security-scan.result }}" == "failure" ||
                "${{ needs.python-backend.result }}" == "failure" ]]; then
            echo "❌ Code quality checks failed"
            exit 1
          else
            echo "✅ All code quality checks passed"
          fi
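The workflow above scopes pre-commit to the PR's diff by exporting SKIP and passing a ref range. A contributor can reproduce roughly the same checks locally before opening a PR; this is a minimal sketch, assuming the PR targets dev and that .pre-commit-config.yaml defines the same hook ids:

    # Install pre-commit and build the hook environments once
    pip install pre-commit
    pre-commit install-hooks

    # File quality hooks on files changed against dev (mirrors the file-quality job)
    SKIP=detect-secrets,bandit,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
      pre-commit run --from-ref origin/dev --to-ref HEAD

    # Security hooks only, over the whole tree (mirrors the security-scan fallback path)
    pre-commit run --all-files detect-secrets bandit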
2  .github/workflows/pre-commit.yml (vendored)

@@ -3,7 +3,7 @@ name: pre-commit
on:
  push:
  pull_request:
    branches: [main]
    branches: [main, dev]

jobs:
  pre-commit:
2  .gitignore (vendored)

@@ -1,3 +1,5 @@
.flashrank_cache*
podcasts/
.env

.ruff_cache/
.pre-commit-config.yaml

@@ -6,8 +6,6 @@ repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
        exclude: '\.md$'
      - id: check-yaml
        args: [--multi, --unsafe]
      - id: check-json

@@ -31,52 +29,36 @@ repos:
          .*\.env\.template|
          .*/tests/.*|
          .*test.*\.py|
          test_.*\.py|
          .github/workflows/.*\.yml|
          .github/workflows/.*\.yaml|
          .*pnpm-lock\.yaml|
          .*alembic\.ini|
          .*alembic/versions/.*\.py|
          .*\.mdx$
        )$

  # Python Backend Hooks (surfsense_backend)
  - repo: https://github.com/psf/black
    rev: 25.1.0
    hooks:
      - id: black
        files: ^surfsense_backend/
        language_version: python3

  - repo: https://github.com/pycqa/isort
    rev: 6.0.1
    hooks:
      - id: isort
        files: ^surfsense_backend/
        args: ["--profile", "black", "--line-length", "88"]

  # Python Backend Hooks (surfsense_backend) - Using Ruff for linting and formatting
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.4
    rev: v0.12.5
    hooks:
      - id: ruff
        name: ruff-check
        files: ^surfsense_backend/
        args: [--fix, --exit-non-zero-on-fix]
        exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
        args: [--fix]
      - id: ruff-format
        name: ruff-format
        files: ^surfsense_backend/

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.17.0
    hooks:
      - id: mypy
        files: ^surfsense_backend/
        additional_dependencies: ['types-requests']
        args: [--ignore-missing-imports, --disallow-untyped-defs]
        exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)

  - repo: https://github.com/PyCQA/bandit
    rev: 1.8.6
    hooks:
      - id: bandit
        files: ^surfsense_backend/
        args: ['-r', '-f', 'json']
        exclude: ^surfsense_backend/(tests/|alembic/)
        args: ['-f', 'json', '--severity-level', 'high', '--confidence-level', 'high']
        exclude: ^surfsense_backend/(tests/|test_.*\.py|.*test.*\.py|alembic/)

  # Frontend/Extension Hooks (TypeScript/JavaScript)
  - repo: https://github.com/pre-commit/mirrors-prettier
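With black and isort replaced by the ruff hooks, the backend lint and format pass can also be exercised locally; a small sketch, assuming pre-commit is installed and the hook ids above are unchanged:

    # Run only the ruff hooks defined for surfsense_backend/
    pre-commit run ruff ruff-format --all-files

    # Bump a pinned hook revision (e.g. ruff-pre-commit v0.12.4 -> v0.12.5) in one step
    pre-commit autoupdate --repo https://github.com/astral-sh/ruff-pre-commit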
alembic env.py

@@ -1,8 +1,8 @@
import asyncio
from logging.config import fileConfig

import os
import sys
from logging.config import fileConfig

from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

@@ -11,10 +11,10 @@ from alembic import context

# Ensure the app directory is in the Python path
# This allows Alembic to find your models
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), "..")))

# Import your models base
from app.db import Base # Assuming your Base is defined in app.db
from app.db import Base  # Assuming your Base is defined in app.db

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -4,17 +4,15 @@ Revision ID: 10
Revises: 9
"""

from typing import Sequence, Union
from collections.abc import Sequence

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "10"
down_revision: Union[str, None] = "9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "9"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Define the ENUM type name
CHAT_TYPE_ENUM = "chattype"

@@ -22,87 +20,101 @@ CHAT_TYPE_ENUM = "chattype"

def upgrade() -> None:
    """Upgrade schema - replace ChatType enum values with new QNA/REPORT structure."""

    # Old enum name for temporary storage
    old_enum_name = f"{CHAT_TYPE_ENUM}_old"

    # New enum values
    new_values = (
        "QNA",
        "REPORT_GENERAL",
        "REPORT_DEEP",
        "REPORT_DEEPER"
    )
    new_values = ("QNA", "REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER")
    new_values_sql = ", ".join([f"'{v}'" for v in new_values])

    # Table and column info
    table_name = "chats"
    column_name = "type"

    # Step 1: Rename the current enum type
    op.execute(f"ALTER TYPE {CHAT_TYPE_ENUM} RENAME TO {old_enum_name}")

    # Step 2: Create the new enum type with new values
    op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({new_values_sql})")

    # Step 3: Add a temporary column with the new type
    op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
    op.execute(
        f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
    )

    # Step 4: Update the temporary column with mapped values
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'")
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'")
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'")
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'")
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'"
    )
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'"
    )
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'"
    )
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'"
    )

    # Step 5: Drop the old column
    op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")

    # Step 6: Rename the new column to the original name
    op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
    op.execute(
        f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
    )

    # Step 7: Drop the old enum type
    op.execute(f"DROP TYPE {old_enum_name}")


def downgrade() -> None:
    """Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure."""

    # Old enum name for temporary storage
    old_enum_name = f"{CHAT_TYPE_ENUM}_old"

    # Original enum values
    original_values = (
        "GENERAL",
        "DEEP",
        "DEEPER",
        "DEEPEST"
    )
    original_values = ("GENERAL", "DEEP", "DEEPER", "DEEPEST")
    original_values_sql = ", ".join([f"'{v}'" for v in original_values])

    # Table and column info
    table_name = "chats"
    column_name = "type"

    # Step 1: Rename the current enum type
    op.execute(f"ALTER TYPE {CHAT_TYPE_ENUM} RENAME TO {old_enum_name}")

    # Step 2: Create the new enum type with original values
    op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({original_values_sql})")

    # Step 3: Add a temporary column with the original type
    op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
    op.execute(
        f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
    )

    # Step 4: Update the temporary column with mapped values back to old values
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'")
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'")
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'")
    op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'")
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'"
    )
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'"
    )
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'"
    )
    op.execute(
        f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'"
    )

    # Step 5: Drop the old column
    op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")

    # Step 6: Rename the new column to the original name
    op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
    op.execute(
        f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
    )

    # Step 7: Drop the old enum type
    op.execute(f"DROP TYPE {old_enum_name}")
    op.execute(f"DROP TYPE {old_enum_name}")
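Because PostgreSQL cannot remove values from an existing enum type, this migration swaps the whole type out through a temporary column. Applying or reverting it is the usual Alembic invocation; a brief sketch, assuming it is run from the backend directory that holds alembic.ini:

    # Apply the ChatType enum rework (revision 10)
    alembic upgrade 10

    # Roll back to the previous GENERAL/DEEP/DEEPER/DEEPEST values (revision 9)
    alembic downgrade 9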
@@ -4,16 +4,17 @@ Revision ID: 11
Revises: 10
"""

from typing import Sequence, Union
from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "11"
down_revision: Union[str, None] = "10"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "10"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
@@ -4,16 +4,17 @@ Revision ID: 12
Revises: 11
"""

from typing import Sequence, Union
from collections.abc import Sequence

from sqlalchemy import inspect

from alembic import op
from sqlalchemy import inspect

# revision identifiers, used by Alembic.
revision: str = "12"
down_revision: Union[str, None] = "11"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "11"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
@@ -4,15 +4,15 @@ Revision ID: 13
Revises: 12
"""

from typing import Sequence, Union
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "13"
down_revision: Union[str, None] = "12"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "12"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
@@ -5,7 +5,7 @@ Revises:

"""

from typing import Sequence, Union
from collections.abc import Sequence

from alembic import op

@@ -15,9 +15,9 @@ from alembic import op

# revision identifiers, used by Alembic.
revision: str = "1"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:

@@ -63,10 +63,8 @@ def downgrade() -> None:
        "CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR')"
    )
    op.execute(
        (
            "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
            "connector_type::text::searchsourceconnectortype"
        )
        "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
        "connector_type::text::searchsourceconnectortype"
    )
    op.execute("DROP TYPE searchsourceconnectortype_old")
@@ -5,15 +5,15 @@ Revises: e55302644c51

"""

from typing import Sequence, Union
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "2"
down_revision: Union[str, None] = "e55302644c51"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "e55302644c51"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:

@@ -49,10 +49,8 @@ def downgrade() -> None:
        "CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR')"
    )
    op.execute(
        (
            "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
            "connector_type::text::searchsourceconnectortype"
        )
        "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
        "connector_type::text::searchsourceconnectortype"
    )
    op.execute("DROP TYPE searchsourceconnectortype_old")
@@ -5,15 +5,15 @@ Revises: 2

"""

from typing import Sequence, Union
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "3"
down_revision: Union[str, None] = "2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Define the ENUM type name and the new value
ENUM_NAME = "documenttype"  # Make sure this matches the name in your DB (usually lowercase class name)
@@ -5,37 +5,26 @@ Revises: 3

"""

from typing import Sequence, Union
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "4"
down_revision: Union[str, None] = "3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
down_revision: str | None = "3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    ENUM_NAME = "searchsourceconnectortype"
    NEW_VALUE = "LINKUP_API"
    # ### commands auto generated by Alembic - please adjust! ###

    op.execute(
        f"""
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_enum
                WHERE enumlabel = '{NEW_VALUE}'
                AND enumtypid = (
                    SELECT oid FROM pg_type WHERE typname = '{ENUM_NAME}'
                )
            ) THEN
                ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
            END IF;
        END$$;
        """
    )
    # Manually add the command to add the enum value
    op.execute("ALTER TYPE searchsourceconnectortype ADD VALUE 'LINKUP_API'")

    # Pass for the rest, as autogenerate didn't run to add other schema details
    pass
    # ### end Alembic commands ###


def downgrade() -> None:

@@ -49,10 +38,8 @@ def downgrade() -> None:
        "CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR', 'LINEAR_CONNECTOR')"
    )
    op.execute(
        (
            "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
            "connector_type::text::searchsourceconnectortype"
        )
        "ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
        "connector_type::text::searchsourceconnectortype"
    )
    op.execute("DROP TYPE searchsourceconnectortype_old")
@ -4,54 +4,73 @@ Revision ID: 5
|
|||
Revises: 4
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '5'
|
||||
down_revision: Union[str, None] = '4'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "5"
|
||||
down_revision: str | None = "4"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Alter Chat table
|
||||
op.alter_column('chats', 'title',
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Alter Document table
|
||||
op.alter_column('documents', 'title',
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Alter Podcast table
|
||||
op.alter_column('podcasts', 'title',
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False)
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert Chat table
|
||||
op.alter_column('chats', 'title',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Revert Document table
|
||||
op.alter_column('documents', 'title',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Revert Podcast table
|
||||
op.alter_column('podcasts', 'title',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False)
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,18 +5,19 @@ Revises: 5
|
|||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6"
|
||||
down_revision: Union[str, None] = "5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "5"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
|
|
|||
|
|
@ -5,17 +5,18 @@ Revises: 6
|
|||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7"
|
||||
down_revision: Union[str, None] = "6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "6"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
|
|
|||
|
|
@ -4,17 +4,18 @@ Revision ID: 8
|
|||
Revises: 7
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "8"
|
||||
down_revision: Union[str, None] = "7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "7"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
|
|
|||
|
|
@ -4,15 +4,15 @@ Revision ID: 9
|
|||
Revises: 8
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9"
|
||||
down_revision: Union[str, None] = "8"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "8"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Define the ENUM type name and the new value
|
||||
CONNECTOR_ENUM = "searchsourceconnectortype"
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e55302644c51"
|
||||
down_revision: Union[str, None] = "1"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "1"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Define the ENUM type name and the new value
|
||||
ENUM_NAME = "documenttype"
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,6 @@
from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Optional

from langchain_core.runnables import RunnableConfig

@@ -17,11 +16,11 @@ class Configuration:
    # create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
    # and when you invoke the graph
    podcast_title: str
    user_id: str
    user_id: str

    @classmethod
    def from_runnable_config(
        cls, config: Optional[RunnableConfig] = None
        cls, config: RunnableConfig | None = None
    ) -> Configuration:
        """Create a Configuration instance from a RunnableConfig object."""
        configurable = (config.get("configurable") or {}) if config else {}
@@ -1,14 +1,11 @@
from langgraph.graph import StateGraph

from .configuration import Configuration
from .nodes import create_merged_podcast_audio, create_podcast_transcript
from .state import State


from .nodes import create_merged_podcast_audio, create_podcast_transcript


def build_graph():

    # Define a new graph
    workflow = StateGraph(State, config_schema=Configuration)

@@ -24,8 +21,9 @@ def build_graph():
    # Compile the workflow into an executable graph
    graph = workflow.compile()
    graph.name = "Surfsense Podcaster"  # This defines the custom name in LangSmith

    return graph


# Compile the graph once when the module is loaded
graph = build_graph()
|||
|
|
@ -1,148 +1,154 @@
|
|||
from typing import Any, Dict
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from litellm import aspeech
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
|
||||
from .configuration import Configuration
|
||||
from .state import PodcastTranscriptEntry, State, PodcastTranscripts
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
from app.config import config as app_config
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
|
||||
|
||||
async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
|
||||
async def create_podcast_transcript(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Each node does work."""
|
||||
|
||||
|
||||
# Get configuration from runnable config
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
user_id = configuration.user_id
|
||||
|
||||
|
||||
# Get user's long context LLM
|
||||
llm = await get_user_long_context_llm(state.db_session, user_id)
|
||||
if not llm:
|
||||
error_message = f"No long context LLM configured for user {user_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
||||
# Get the prompt
|
||||
prompt = get_podcast_generation_prompt()
|
||||
|
||||
|
||||
# Create the messages
|
||||
messages = [
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content=f"<source_content>{state.source_content}</source_content>")
|
||||
HumanMessage(
|
||||
content=f"<source_content>{state.source_content}</source_content>"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Generate the podcast transcript
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
|
||||
# First try the direct approach
|
||||
try:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(llm_response.content))
|
||||
podcast_transcript = PodcastTranscripts.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {str(e)}")
|
||||
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
# Fallback: Parse the JSON response manually
|
||||
try:
|
||||
# Extract JSON content from the response
|
||||
content = llm_response.content
|
||||
|
||||
|
||||
# Find the JSON in the content (handle case where LLM might add additional text)
|
||||
json_start = content.find('{')
|
||||
json_end = content.rfind('}') + 1
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
|
||||
# Convert to Pydantic model
|
||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||
|
||||
print(f"Successfully parsed podcast transcript using fallback approach")
|
||||
|
||||
print("Successfully parsed podcast transcript using fallback approach")
|
||||
else:
|
||||
# If JSON structure not found, raise a clear error
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
# Log the error and re-raise it
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {str(e2)}"
|
||||
print(f"Error parsing LLM response: {str(e2)}")
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
raise
|
||||
|
||||
return {
|
||||
"podcast_transcript": podcast_transcript.podcast_transcripts
|
||||
}
|
||||
|
||||
|
||||
async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
|
||||
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||
|
||||
|
||||
async def create_merged_podcast_audio(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate audio for each transcript and merge them into a single podcast file."""
|
||||
|
||||
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
|
||||
|
||||
starting_transcript = PodcastTranscriptEntry(
|
||||
speaker_id=1,
|
||||
dialog=f"Welcome to {configuration.podcast_title} Podcast."
|
||||
speaker_id=1, dialog=f"Welcome to {configuration.podcast_title} Podcast."
|
||||
)
|
||||
|
||||
|
||||
transcript = state.podcast_transcript
|
||||
|
||||
|
||||
# Merge the starting transcript with the podcast transcript
|
||||
# Check if transcript is a PodcastTranscripts object or already a list
|
||||
if hasattr(transcript, 'podcast_transcripts'):
|
||||
if hasattr(transcript, "podcast_transcripts"):
|
||||
transcript_entries = transcript.podcast_transcripts
|
||||
else:
|
||||
transcript_entries = transcript
|
||||
|
||||
merged_transcript = [starting_transcript] + transcript_entries
|
||||
|
||||
|
||||
merged_transcript = [starting_transcript, *transcript_entries]
|
||||
|
||||
# Create a temporary directory for audio files
|
||||
temp_dir = Path("temp_audio")
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# Generate a unique session ID for this podcast
|
||||
session_id = str(uuid.uuid4())
|
||||
output_path = f"podcasts/{session_id}_podcast.mp3"
|
||||
os.makedirs("podcasts", exist_ok=True)
|
||||
|
||||
|
||||
# Map of speaker_id to voice
|
||||
voice_mapping = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
1: "echo", # First speaker
|
||||
# 2: "fable", # Second speaker
|
||||
# 3: "onyx", # Third speaker
|
||||
# 4: "nova", # Fourth speaker
|
||||
# 5: "shimmer" # Fifth speaker
|
||||
}
|
||||
|
||||
|
||||
# Generate audio for each transcript segment
|
||||
audio_files = []
|
||||
|
||||
|
||||
async def generate_speech_for_segment(segment, index):
|
||||
# Handle both dictionary and PodcastTranscriptEntry objects
|
||||
if hasattr(segment, 'speaker_id'):
|
||||
if hasattr(segment, "speaker_id"):
|
||||
speaker_id = segment.speaker_id
|
||||
dialog = segment.dialog
|
||||
else:
|
||||
speaker_id = segment.get("speaker_id", 0)
|
||||
dialog = segment.get("dialog", "")
|
||||
|
||||
|
||||
# Select voice based on speaker_id
|
||||
voice = voice_mapping.get(speaker_id, "alloy")
|
||||
|
||||
|
||||
# Generate a unique filename for this segment
|
||||
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
||||
|
||||
|
||||
try:
|
||||
if app_config.TTS_SERVICE_API_BASE:
|
||||
response = await aspeech(
|
||||
|
|
@ -163,55 +169,61 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
|
|||
max_retries=2,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
|
||||
# Save the audio to a file - use proper streaming method
|
||||
with open(filename, 'wb') as f:
|
||||
with open(filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
|
||||
return filename
|
||||
except Exception as e:
|
||||
print(f"Error generating speech for segment {index}: {str(e)}")
|
||||
print(f"Error generating speech for segment {index}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# Generate all audio files concurrently
|
||||
tasks = [generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript)]
|
||||
tasks = [
|
||||
generate_speech_for_segment(segment, i)
|
||||
for i, segment in enumerate(merged_transcript)
|
||||
]
|
||||
audio_files = await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
# Merge audio files using ffmpeg
|
||||
try:
|
||||
# Create FFmpeg instance with the first input
|
||||
ffmpeg = FFmpeg().option("y")
|
||||
|
||||
|
||||
# Add each audio file as input
|
||||
for audio_file in audio_files:
|
||||
ffmpeg = ffmpeg.input(audio_file)
|
||||
|
||||
|
||||
# Configure the concatenation and output
|
||||
filter_complex = []
|
||||
for i in range(len(audio_files)):
|
||||
filter_complex.append(f"[{i}:0]")
|
||||
|
||||
filter_complex_str = "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||
|
||||
filter_complex_str = (
|
||||
"".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||
)
|
||||
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
||||
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
||||
|
||||
|
||||
# Execute FFmpeg
|
||||
await ffmpeg.execute()
|
||||
|
||||
|
||||
print(f"Successfully created podcast audio: {output_path}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error merging audio files: {str(e)}")
|
||||
print(f"Error merging audio files: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
for audio_file in audio_files:
|
||||
try:
|
||||
os.remove(audio_file)
|
||||
except:
|
||||
except Exception as e:
|
||||
print(f"Error removing audio file {audio_file}: {e!s}")
|
||||
pass
|
||||
|
||||
|
||||
return {
|
||||
"podcast_transcript": merged_transcript,
|
||||
"final_podcast_file_path": output_path
|
||||
"final_podcast_file_path": output_path,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,4 +108,4 @@ Output:
|
|||
|
||||
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration.
|
||||
</podcast_generation_system>
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,14 +3,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class PodcastTranscriptEntry(BaseModel):
|
||||
"""
|
||||
Represents a single entry in a podcast transcript.
|
||||
"""
|
||||
|
||||
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
|
||||
dialog: str = Field(..., description="The dialog text spoken by the speaker")
|
||||
|
||||
|
|
@ -19,10 +21,11 @@ class PodcastTranscripts(BaseModel):
|
|||
"""
|
||||
Represents the full podcast transcript structure.
|
||||
"""
|
||||
podcast_transcripts: List[PodcastTranscriptEntry] = Field(
|
||||
...,
|
||||
description="List of transcript entries with alternating speakers"
|
||||
)
|
||||
|
||||
podcast_transcripts: list[PodcastTranscriptEntry] = Field(
|
||||
..., description="List of transcript entries with alternating speakers"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
|
|
@ -32,8 +35,9 @@ class State:
|
|||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
source_content: str
|
||||
podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
|
||||
final_podcast_file_path: Optional[str] = None
|
||||
podcast_transcript: list[PodcastTranscriptEntry] | None = None
|
||||
final_podcast_file_path: str | None = None
|
||||
|
|
|
|||
|
|
@ -4,17 +4,20 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
class SearchMode(Enum):
|
||||
|
||||
class SearchMode(Enum):
|
||||
"""Enum defining the type of search mode."""
|
||||
|
||||
CHUNKS = "CHUNKS"
|
||||
DOCUMENTS = "DOCUMENTS"
|
||||
|
||||
|
||||
class ResearchMode(Enum):
|
||||
"""Enum defining the type of research mode."""
|
||||
|
||||
QNA = "QNA"
|
||||
REPORT_GENERAL = "REPORT_GENERAL"
|
||||
REPORT_DEEP = "REPORT_DEEP"
|
||||
|
|
@ -28,16 +31,16 @@ class Configuration:
|
|||
# Input parameters provided at invocation
|
||||
user_query: str
|
||||
num_sections: int
|
||||
connectors_to_search: List[str]
|
||||
connectors_to_search: list[str]
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
search_mode: SearchMode
|
||||
research_mode: ResearchMode
|
||||
document_ids_to_add_in_context: List[int]
|
||||
document_ids_to_add_in_context: list[int]
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,31 +1,41 @@
|
|||
from typing import Any, TypedDict
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import reformulate_user_query, write_answer_outline, process_sections, handle_qna_workflow, generate_further_questions
|
||||
|
||||
from .configuration import Configuration, ResearchMode
|
||||
from typing import TypedDict, List, Dict, Any, Optional
|
||||
from .nodes import (
|
||||
generate_further_questions,
|
||||
handle_qna_workflow,
|
||||
process_sections,
|
||||
reformulate_user_query,
|
||||
write_answer_outline,
|
||||
)
|
||||
from .state import State
|
||||
|
||||
|
||||
# Define what keys are in our state dict
|
||||
class GraphState(TypedDict):
|
||||
# Intermediate data produced during workflow
|
||||
answer_outline: Optional[Any]
|
||||
answer_outline: Any | None
|
||||
# Final output
|
||||
final_written_report: Optional[str]
|
||||
final_written_report: str | None
|
||||
|
||||
|
||||
def build_graph():
|
||||
"""
|
||||
Build and return the LangGraph workflow.
|
||||
|
||||
|
||||
This function constructs the researcher agent graph with conditional routing
|
||||
based on research_mode - QNA mode uses a direct Q&A workflow while other modes
|
||||
use the full report generation pipeline. Both paths generate follow-up questions
|
||||
at the end using the reranked documents from the sub-agents.
|
||||
|
||||
|
||||
Returns:
|
||||
A compiled LangGraph workflow
|
||||
"""
|
||||
# Define a new graph with state class
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
|
||||
# Add nodes to the graph
|
||||
workflow.add_node("reformulate_user_query", reformulate_user_query)
|
||||
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
|
||||
|
|
@ -35,41 +45,42 @@ def build_graph():
|
|||
|
||||
# Define the edges
|
||||
workflow.add_edge("__start__", "reformulate_user_query")
|
||||
|
||||
|
||||
# Add conditional edges from reformulate_user_query based on research mode
|
||||
def route_after_reformulate(state: State, config) -> str:
|
||||
"""Route based on research_mode after reformulating the query."""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
|
||||
|
||||
if configuration.research_mode == ResearchMode.QNA.value:
|
||||
return "handle_qna_workflow"
|
||||
else:
|
||||
return "write_answer_outline"
|
||||
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"reformulate_user_query",
|
||||
route_after_reformulate,
|
||||
{
|
||||
"handle_qna_workflow": "handle_qna_workflow",
|
||||
"write_answer_outline": "write_answer_outline"
|
||||
}
|
||||
"write_answer_outline": "write_answer_outline",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__
|
||||
workflow.add_edge("handle_qna_workflow", "generate_further_questions")
|
||||
|
||||
|
||||
# Report generation workflow path: write_answer_outline -> process_sections -> generate_further_questions -> __end__
|
||||
workflow.add_edge("write_answer_outline", "process_sections")
|
||||
workflow.add_edge("process_sections", "generate_further_questions")
|
||||
|
||||
|
||||
# Both paths end after generating further questions
|
||||
workflow.add_edge("generate_further_questions", "__end__")
|
||||
|
||||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
|
||||
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Compile the graph once when the module is loaded
|
||||
graph = build_graph()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from app.db import Document, SearchSpace
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.query_service import QueryService
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
|
@ -13,6 +10,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
# Additional imports for document fetching
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Document, SearchSpace
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.query_service import QueryService
|
||||
|
||||
from .configuration import Configuration, SearchMode
|
||||
from .prompts import (
|
||||
get_answer_outline_system_prompt,
|
||||
|
|
@ -26,8 +27,8 @@ from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_na
|
|||
|
||||
|
||||
async def fetch_documents_by_ids(
|
||||
document_ids: List[int], user_id: str, db_session: AsyncSession
|
||||
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
document_ids: list[int], user_id: str, db_session: AsyncSession
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Fetch documents by their IDs with ownership check using DOCUMENTS mode approach.
|
||||
|
||||
|
|
@ -358,13 +359,13 @@ async def fetch_documents_by_ids(
|
|||
return source_objects, formatted_documents
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching documents by IDs: {str(e)}")
|
||||
print(f"Error fetching documents by IDs: {e!s}")
|
||||
return [], []
|
||||
|
||||
|
||||
async def write_answer_outline(
|
||||
state: State, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a structured answer outline based on the user query.
|
||||
|
||||
|
|
@ -502,27 +503,27 @@ async def write_answer_outline(
|
|||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
# Log the error and re-raise it
|
||||
error_message = f"Error parsing LLM response: {str(e)}"
|
||||
error_message = f"Error parsing LLM response: {e!s}"
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
|
||||
print(f"Error parsing LLM response: {str(e)}")
|
||||
print(f"Error parsing LLM response: {e!s}")
|
||||
print(f"Raw response: {response.content}")
|
||||
raise
|
||||
|
||||
|
||||
async def fetch_relevant_documents(
|
||||
research_questions: List[str],
|
||||
research_questions: list[str],
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
connectors_to_search: List[str],
|
||||
connectors_to_search: list[str],
|
||||
writer: StreamWriter = None,
|
||||
state: State = None,
|
||||
top_k: int = 10,
|
||||
connector_service: ConnectorService = None,
|
||||
search_mode: SearchMode = SearchMode.CHUNKS,
|
||||
user_selected_sources: List[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
user_selected_sources: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch relevant documents for research questions using the provided connectors.
|
||||
|
||||
|
|
@ -832,8 +833,11 @@ async def fetch_relevant_documents(
|
|||
|
||||
elif connector == "LINKUP_API":
|
||||
linkup_mode = "standard"
|
||||
|
||||
source_object, linkup_chunks = await connector_service.search_linkup(
|
||||
|
||||
(
|
||||
source_object,
|
||||
linkup_chunks,
|
||||
) = await connector_service.search_linkup(
|
||||
user_query=reformulated_query,
|
||||
user_id=user_id,
|
||||
mode=linkup_mode,
|
||||
|
|
@ -904,7 +908,7 @@ async def fetch_relevant_documents(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Error searching connector {connector}: {str(e)}"
|
||||
error_message = f"Error searching connector {connector}: {e!s}"
|
||||
print(error_message)
|
||||
|
||||
# Stream error message
|
||||
|
|
@ -913,7 +917,7 @@ async def fetch_relevant_documents(
|
|||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_error(
|
||||
f"Error searching {friendly_name}: {str(e)}"
|
||||
f"Error searching {friendly_name}: {e!s}"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
@@ -948,37 +952,49 @@ async def fetch_relevant_documents(
             if source_id and source_type:
                 source_key = f"{source_type}_{source_id}"
-                current_sources_count = len(source_obj.get('sources', []))
+                current_sources_count = len(source_obj.get("sources", []))

                 if source_key not in seen_source_keys:
                     seen_source_keys.add(source_key)
                     deduplicated_sources.append(source_obj)
-                    print(f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}")
+                    print(
+                        f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}"
+                    )
                 else:
                     # Check if this source object has more sources than the existing one
                     existing_index = None
                     for i, existing_source in enumerate(deduplicated_sources):
-                        existing_id = existing_source.get('id')
-                        existing_type = existing_source.get('type')
+                        existing_id = existing_source.get("id")
+                        existing_type = existing_source.get("type")
                         if existing_id == source_id and existing_type == source_type:
                             existing_index = i
                             break

                     if existing_index is not None:
-                        existing_sources_count = len(deduplicated_sources[existing_index].get('sources', []))
+                        existing_sources_count = len(
+                            deduplicated_sources[existing_index].get("sources", [])
+                        )
                         if current_sources_count > existing_sources_count:
                             # Replace the existing source object with the new one that has more sources
                             deduplicated_sources[existing_index] = source_obj
-                            print(f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}")
+                            print(
+                                f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}"
+                            )
                         else:
-                            print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}")
+                            print(
+                                f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}"
+                            )
                     else:
-                        print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)")
+                        print(
+                            f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)"
+                        )
             else:
                 # If there's no ID or type, just add it to be safe
                 deduplicated_sources.append(source_obj)
-                print(f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}")
+                print(
+                    f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}"
+                )

         # Stream info about deduplicated sources
         if streaming_service and writer:
             user_source_count = len(user_selected_sources) if user_selected_sources else 0
@@ -1039,7 +1055,7 @@ async def fetch_relevant_documents(

 async def process_sections(
     state: State, config: RunnableConfig, writer: StreamWriter
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Process all sections in parallel and combine the results.

@@ -1100,13 +1116,13 @@ async def process_sections(
     )

     if configuration.num_sections == 1:
-        TOP_K = 10
+        top_k = 10
     elif configuration.num_sections == 3:
-        TOP_K = 20
+        top_k = 20
     elif configuration.num_sections == 6:
-        TOP_K = 30
+        top_k = 30
     else:
-        TOP_K = 10
+        top_k = 10

     relevant_documents = []
     user_selected_documents = []
@ -1155,13 +1171,13 @@ async def process_sections(
|
|||
connectors_to_search=configuration.connectors_to_search,
|
||||
writer=writer,
|
||||
state=state,
|
||||
top_k=TOP_K,
|
||||
top_k=top_k,
|
||||
connector_service=connector_service,
|
||||
search_mode=configuration.search_mode,
|
||||
user_selected_sources=user_selected_sources,
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = f"Error fetching relevant documents: {str(e)}"
|
||||
error_message = f"Error fetching relevant documents: {e!s}"
|
||||
print(error_message)
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
# Log the error and continue with an empty list of documents
|
||||
|
|
@ -1251,7 +1267,7 @@ async def process_sections(
|
|||
for i, result in enumerate(section_results):
|
||||
if isinstance(result, Exception):
|
||||
section_title = answer_outline.answer_outline[i].section_title
|
||||
error_message = f"Error processing section '{section_title}': {str(result)}"
|
||||
error_message = f"Error processing section '{section_title}': {result!s}"
|
||||
print(error_message)
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
processed_results.append(error_message)
|
||||
|
|
@ -1260,8 +1276,8 @@ async def process_sections(
|
|||
|
||||
# Combine the results into a final report with section titles
|
||||
final_report = []
|
||||
for i, (section, content) in enumerate(
|
||||
zip(answer_outline.answer_outline, processed_results)
|
||||
for _i, (section, content) in enumerate(
|
||||
zip(answer_outline.answer_outline, processed_results, strict=False)
|
||||
):
|
||||
# Skip adding the section header since the content already contains the title
|
||||
final_report.append(content)
|
||||
|
|
@ -1299,15 +1315,15 @@ async def process_sections(
|
|||
async def process_section_with_documents(
|
||||
section_id: int,
|
||||
section_title: str,
|
||||
section_questions: List[str],
|
||||
section_questions: list[str],
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
relevant_documents: List[Dict[str, Any]],
|
||||
relevant_documents: list[dict[str, Any]],
|
||||
user_query: str,
|
||||
state: State = None,
|
||||
writer: StreamWriter = None,
|
||||
sub_section_type: SubSectionType = SubSectionType.MIDDLE,
|
||||
section_contents: Dict[int, Dict[str, Any]] = None,
|
||||
section_contents: dict[int, dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Process a single section using pre-fetched documents.
|
||||
|
|
@ -1388,7 +1404,7 @@ async def process_section_with_documents(
|
|||
# Variables to track streaming state
|
||||
complete_content = "" # Tracks the complete content received so far
|
||||
|
||||
async for chunk_type, chunk in sub_section_writer_graph.astream(
|
||||
async for _chunk_type, chunk in sub_section_writer_graph.astream(
|
||||
sub_state, config, stream_mode=["values"]
|
||||
):
|
||||
if "final_answer" in chunk:
|
||||
|
|
@ -1448,24 +1464,24 @@ async def process_section_with_documents(
|
|||
|
||||
return complete_content
|
||||
except Exception as e:
|
||||
print(f"Error processing section '{section_title}': {str(e)}")
|
||||
print(f"Error processing section '{section_title}': {e!s}")
|
||||
|
||||
# Send error update via streaming if available
|
||||
if state and state.streaming_service and writer:
|
||||
writer(
|
||||
{
|
||||
"yield_value": state.streaming_service.format_error(
|
||||
f'Error processing section "{section_title}": {str(e)}'
|
||||
f'Error processing section "{section_title}": {e!s}'
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return f"Error processing section: {section_title}. Details: {str(e)}"
|
||||
return f"Error processing section: {section_title}. Details: {e!s}"
|
||||
|
||||
|
||||
async def reformulate_user_query(
|
||||
state: State, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Reforms the user query based on the chat history.
|
||||
"""
|
||||
|
|
@ -1490,7 +1506,7 @@ async def reformulate_user_query(
|
|||
|
||||
async def handle_qna_workflow(
|
||||
state: State, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Handle the QNA research workflow.
|
||||
|
||||
|
|
@@ -1532,7 +1548,7 @@ async def handle_qna_workflow(
     )

     # Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
-    TOP_K = 15
+    top_k = 15

     relevant_documents = []
     user_selected_documents = []
@ -1584,13 +1600,13 @@ async def handle_qna_workflow(
|
|||
connectors_to_search=configuration.connectors_to_search,
|
||||
writer=writer,
|
||||
state=state,
|
||||
top_k=TOP_K,
|
||||
top_k=top_k,
|
||||
connector_service=connector_service,
|
||||
search_mode=configuration.search_mode,
|
||||
user_selected_sources=user_selected_sources,
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = f"Error fetching relevant documents for QNA: {str(e)}"
|
||||
error_message = f"Error fetching relevant documents for QNA: {e!s}"
|
||||
print(error_message)
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
# Continue with empty documents - the QNA agent will handle this gracefully
|
||||
|
|
@ -1688,16 +1704,16 @@ async def handle_qna_workflow(
|
|||
}
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Error generating QNA answer: {str(e)}"
|
||||
error_message = f"Error generating QNA answer: {e!s}"
|
||||
print(error_message)
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
|
||||
return {"final_written_report": f"Error generating answer: {str(e)}"}
|
||||
return {"final_written_report": f"Error generating answer: {e!s}"}
|
||||
|
||||
|
||||
async def generate_further_questions(
|
||||
state: State, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate contextually relevant follow-up questions based on chat history and available documents.
|
||||
|
||||
|
|
@ -1748,7 +1764,7 @@ async def generate_further_questions(
|
|||
chat_history_xml += f"<assistant>{message.content}</assistant>\n"
|
||||
else:
|
||||
# Handle other message types if needed
|
||||
chat_history_xml += f"<message>{str(message)}</message>\n"
|
||||
chat_history_xml += f"<message>{message!s}</message>\n"
|
||||
chat_history_xml += "</chat_history>"
|
||||
|
||||
# Format available documents for the prompt
|
||||
|
|
@ -1868,7 +1884,7 @@ async def generate_further_questions(
|
|||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
# Log the error and return empty list
|
||||
error_message = f"Error parsing further questions response: {str(e)}"
|
||||
error_message = f"Error parsing further questions response: {e!s}"
|
||||
print(error_message)
|
||||
writer(
|
||||
{"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
|
||||
|
|
@ -1880,7 +1896,7 @@ async def generate_further_questions(
|
|||
|
||||
except Exception as e:
|
||||
# Handle any other errors
|
||||
error_message = f"Error generating further questions: {str(e)}"
|
||||
error_message = f"Error generating further questions: {e!s}"
|
||||
print(error_message)
|
||||
writer(
|
||||
{"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
|
||||
@ -221,4 +221,4 @@ Output:
|
|||
}}
|
||||
</examples>
|
||||
</further_questions_system>
|
||||
"""
|
||||
"""
@@ -1,5 +1,4 @@
-"""QnA Agent.
-"""
+"""QnA Agent."""

 from .graph import graph
@@ -3,7 +3,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, fields
-from typing import Optional, List, Any
+from typing import Any

 from langchain_core.runnables import RunnableConfig
@ -15,13 +15,15 @@ class Configuration:
|
|||
# Configuration parameters for the Q&A agent
|
||||
user_query: str # The user's question to answer
|
||||
reformulated_query: str # The reformulated query
|
||||
relevant_documents: List[Any] # Documents provided directly to the agent for answering
|
||||
relevant_documents: list[
|
||||
Any
|
||||
] # Documents provided directly to the agent for answering
|
||||
user_id: str # User identifier
|
||||
search_space_id: int # Search space identifier
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import rerank_documents, answer_question
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import answer_question, rerank_documents
|
||||
from .state import State
|
||||
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,28 @@
|
|||
from app.services.reranker_service import RerankerService
|
||||
from .configuration import Configuration
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from ..utils import (
|
||||
optimize_documents_for_token_limit,
|
||||
calculate_token_count,
|
||||
format_documents_section
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
from ..utils import (
|
||||
calculate_token_count,
|
||||
format_documents_section,
|
||||
optimize_documents_for_token_limit,
|
||||
)
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||
from .state import State
|
||||
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Rerank the documents based on relevance to the user's question.
|
||||
|
||||
|
||||
This node takes the relevant documents provided in the configuration,
|
||||
reranks them using the reranker service based on the user's query,
|
||||
and updates the state with the reranked documents.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing the reranked documents.
|
||||
"""
|
||||
|
|
@ -30,16 +34,14 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
|
||||
# If no documents were provided, return empty list
|
||||
if not documents or len(documents) == 0:
|
||||
return {
|
||||
"reranked_documents": []
|
||||
}
|
||||
|
||||
return {"reranked_documents": []}
|
||||
|
||||
# Get reranker service from app config
|
||||
reranker_service = RerankerService.get_reranker_instance()
|
||||
|
||||
|
||||
# Use documents as is if no reranker service is available
|
||||
reranked_docs = documents
|
||||
|
||||
|
||||
if reranker_service:
|
||||
try:
|
||||
# Convert documents to format expected by reranker if needed
|
||||
|
|
@ -51,58 +53,64 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
"document": {
|
||||
"id": doc.get("document", {}).get("id", ""),
|
||||
"title": doc.get("document", {}).get("title", ""),
|
||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
||||
"metadata": doc.get("document", {}).get("metadata", {})
|
||||
}
|
||||
} for i, doc in enumerate(documents)
|
||||
"document_type": doc.get("document", {}).get(
|
||||
"document_type", ""
|
||||
),
|
||||
"metadata": doc.get("document", {}).get("metadata", {}),
|
||||
},
|
||||
}
|
||||
for i, doc in enumerate(documents)
|
||||
]
|
||||
|
||||
|
||||
# Rerank documents using the user's query
|
||||
reranked_docs = reranker_service.rerank_documents(user_query + "\n" + reformulated_query, reranker_input_docs)
|
||||
|
||||
reranked_docs = reranker_service.rerank_documents(
|
||||
user_query + "\n" + reformulated_query, reranker_input_docs
|
||||
)
|
||||
|
||||
# Sort by score in descending order
|
||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||
|
||||
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {str(e)}")
|
||||
# Use original docs if reranking fails
|
||||
|
||||
return {
|
||||
"reranked_documents": reranked_docs
|
||||
}
|
||||
|
||||
async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
print(
|
||||
f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {e!s}")
|
||||
# Use original docs if reranking fails
|
||||
|
||||
return {"reranked_documents": reranked_docs}
|
||||
|
||||
|
||||
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Answer the user's question using the provided documents.
|
||||
|
||||
|
||||
This node takes the relevant documents provided in the configuration and uses
|
||||
an LLM to generate a comprehensive answer to the user's question with
|
||||
proper citations. The citations follow IEEE format using source IDs from the
|
||||
documents. If no documents are provided, it will use chat history to generate
|
||||
an answer.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing the final answer in the "final_answer" key.
|
||||
"""
|
||||
from app.services.llm_service import get_user_fast_llm
|
||||
|
||||
|
||||
# Get configuration and relevant documents from configuration
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
documents = state.reranked_documents
|
||||
user_query = configuration.user_query
|
||||
user_id = configuration.user_id
|
||||
|
||||
|
||||
# Get user's fast LLM
|
||||
llm = await get_user_fast_llm(state.db_session, user_id)
|
||||
if not llm:
|
||||
error_message = f"No fast LLM configured for user {user_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
||||
# Determine if we have documents and optimize for token limits
|
||||
has_documents_initially = documents and len(documents) > 0
|
||||
|
||||
|
||||
if has_documents_initially:
|
||||
# Create base message template for token calculation (without documents)
|
||||
base_human_message_template = f"""
|
||||
|
|
@ -114,41 +122,49 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
|||
|
||||
Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner.
|
||||
"""
|
||||
|
||||
|
||||
# Use initial system prompt for token calculation
|
||||
initial_system_prompt = get_qna_citation_system_prompt()
|
||||
base_messages = state.chat_history + [
|
||||
base_messages = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=initial_system_prompt),
|
||||
HumanMessage(content=base_human_message_template)
|
||||
HumanMessage(content=base_human_message_template),
|
||||
]
|
||||
|
||||
|
||||
# Optimize documents to fit within token limits
|
||||
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
|
||||
documents, base_messages, llm.model
|
||||
optimized_documents, has_optimized_documents = (
|
||||
optimize_documents_for_token_limit(documents, base_messages, llm.model)
|
||||
)
|
||||
|
||||
|
||||
# Update state based on optimization result
|
||||
documents = optimized_documents
|
||||
has_documents = has_optimized_documents
|
||||
else:
|
||||
has_documents = False
|
||||
|
||||
|
||||
# Choose system prompt based on final document availability
|
||||
system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt()
|
||||
|
||||
system_prompt = (
|
||||
get_qna_citation_system_prompt()
|
||||
if has_documents
|
||||
else get_qna_no_documents_system_prompt()
|
||||
)
|
||||
|
||||
# Generate documents section
|
||||
documents_text = format_documents_section(
|
||||
documents,
|
||||
"Source material from your personal knowledge base"
|
||||
) if has_documents else ""
|
||||
|
||||
documents_text = (
|
||||
format_documents_section(
|
||||
documents, "Source material from your personal knowledge base"
|
||||
)
|
||||
if has_documents
|
||||
else ""
|
||||
)
|
||||
|
||||
# Create final human message content
|
||||
instruction_text = (
|
||||
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
|
||||
if has_documents else
|
||||
"Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
|
||||
if has_documents
|
||||
else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
|
||||
)
|
||||
|
||||
|
||||
human_message_content = f"""
|
||||
{documents_text}
|
||||
|
||||
|
|
@ -159,22 +175,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
|||
|
||||
{instruction_text}
|
||||
"""
|
||||
|
||||
|
||||
# Create final messages for the LLM
|
||||
messages_with_chat_history = state.chat_history + [
|
||||
messages_with_chat_history = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=human_message_content)
|
||||
HumanMessage(content=human_message_content),
|
||||
]
|
||||
|
||||
|
||||
# Log final token count
|
||||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
||||
print(f"Final token count: {total_tokens}")
|
||||
|
||||
|
||||
|
||||
# Call the LLM and get the response
|
||||
response = await llm.ainvoke(messages_with_chat_history)
|
||||
final_answer = response.content
|
||||
|
||||
return {
|
||||
"final_answer": final_answer
|
||||
}
|
||||
|
||||
return {"final_answer": final_answer}
|
||||
@ -3,14 +3,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the Q&A agent during execution.
|
||||
|
||||
This state tracks the database session, chat history, and the outputs
|
||||
This state tracks the database session, chat history, and the outputs
|
||||
generated by the agent's nodes during question answering.
|
||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
|
|
@ -18,8 +20,8 @@ class State:
|
|||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
|
||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: Optional[List[Any]] = None
|
||||
final_answer: Optional[str] = None
|
||||
reranked_documents: list[Any] | None = None
|
||||
final_answer: str | None = None
|
||||
|
|
|
|||
|
|
@ -3,10 +3,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the agent during execution.
|
||||
|
|
@ -15,23 +18,23 @@ class State:
|
|||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context (not part of actual graph state)
|
||||
db_session: AsyncSession
|
||||
|
||||
|
||||
# Streaming service
|
||||
streaming_service: StreamingService
|
||||
|
||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
||||
|
||||
reformulated_query: Optional[str] = field(default=None)
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
|
||||
reformulated_query: str | None = field(default=None)
|
||||
# Using field to explicitly mark as part of state
|
||||
answer_outline: Optional[Any] = field(default=None)
|
||||
further_questions: Optional[Any] = field(default=None)
|
||||
|
||||
answer_outline: Any | None = field(default=None)
|
||||
further_questions: Any | None = field(default=None)
|
||||
|
||||
# Temporary field to hold reranked documents from sub-agents for further question generation
|
||||
reranked_documents: Optional[List[Any]] = field(default=None)
|
||||
|
||||
reranked_documents: list[Any] | None = field(default=None)
|
||||
|
||||
# OUTPUT: Populated by agent nodes
|
||||
# Using field to explicitly mark as part of state
|
||||
final_written_report: Optional[str] = field(default=None)
|
||||
|
||||
final_written_report: str | None = field(default=None)
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Any
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
class SubSectionType(Enum):
|
||||
"""Enum defining the type of sub-section."""
|
||||
|
||||
START = "START"
|
||||
MIDDLE = "MIDDLE"
|
||||
END = "END"
|
||||
|
|
@ -22,17 +23,16 @@ class Configuration:
|
|||
|
||||
# Input parameters provided at invocation
|
||||
sub_section_title: str
|
||||
sub_section_questions: List[str]
|
||||
sub_section_questions: list[str]
|
||||
sub_section_type: SubSectionType
|
||||
user_query: str
|
||||
relevant_documents: List[Any] # Documents provided directly to the agent
|
||||
relevant_documents: list[Any] # Documents provided directly to the agent
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import write_sub_section, rerank_documents
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import rerank_documents, write_sub_section
|
||||
from .state import State
|
||||
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,28 @@
|
|||
from .configuration import Configuration
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from app.services.reranker_service import RerankerService
|
||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from .configuration import SubSectionType
|
||||
from ..utils import (
|
||||
optimize_documents_for_token_limit,
|
||||
calculate_token_count,
|
||||
format_documents_section
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
from ..utils import (
|
||||
calculate_token_count,
|
||||
format_documents_section,
|
||||
optimize_documents_for_token_limit,
|
||||
)
|
||||
from .configuration import Configuration, SubSectionType
|
||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||
from .state import State
|
||||
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Rerank the documents based on relevance to the sub-section title.
|
||||
|
||||
|
||||
This node takes the relevant documents provided in the configuration,
|
||||
reranks them using the reranker service based on the sub-section title,
|
||||
and updates the state with the reranked documents.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing the reranked documents.
|
||||
"""
|
||||
|
|
@ -30,23 +33,23 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
|
||||
# If no documents were provided, return empty list
|
||||
if not documents or len(documents) == 0:
|
||||
return {
|
||||
"reranked_documents": []
|
||||
}
|
||||
|
||||
return {"reranked_documents": []}
|
||||
|
||||
# Get reranker service from app config
|
||||
reranker_service = RerankerService.get_reranker_instance()
|
||||
|
||||
|
||||
# Use documents as is if no reranker service is available
|
||||
reranked_docs = documents
|
||||
|
||||
|
||||
if reranker_service:
|
||||
try:
|
||||
# Use the sub-section questions for reranking context
|
||||
# rerank_query = "\n".join(sub_section_questions)
|
||||
# rerank_query = configuration.user_query
|
||||
|
||||
rerank_query = configuration.user_query + "\n" + "\n".join(sub_section_questions)
|
||||
|
||||
rerank_query = (
|
||||
configuration.user_query + "\n" + "\n".join(sub_section_questions)
|
||||
)
|
||||
|
||||
# Convert documents to format expected by reranker if needed
|
||||
reranker_input_docs = [
|
||||
|
|
@ -57,54 +60,60 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
"document": {
|
||||
"id": doc.get("document", {}).get("id", ""),
|
||||
"title": doc.get("document", {}).get("title", ""),
|
||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
||||
"metadata": doc.get("document", {}).get("metadata", {})
|
||||
}
|
||||
} for i, doc in enumerate(documents)
|
||||
"document_type": doc.get("document", {}).get(
|
||||
"document_type", ""
|
||||
),
|
||||
"metadata": doc.get("document", {}).get("metadata", {}),
|
||||
},
|
||||
}
|
||||
for i, doc in enumerate(documents)
|
||||
]
|
||||
|
||||
|
||||
# Rerank documents using the section title
|
||||
reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
|
||||
|
||||
reranked_docs = reranker_service.rerank_documents(
|
||||
rerank_query, reranker_input_docs
|
||||
)
|
||||
|
||||
# Sort by score in descending order
|
||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||
|
||||
print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}")
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {str(e)}")
|
||||
# Use original docs if reranking fails
|
||||
|
||||
return {
|
||||
"reranked_documents": reranked_docs
|
||||
}
|
||||
|
||||
async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
print(
|
||||
f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {e!s}")
|
||||
# Use original docs if reranking fails
|
||||
|
||||
return {"reranked_documents": reranked_docs}
|
||||
|
||||
|
||||
async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Write the sub-section using the provided documents.
|
||||
|
||||
|
||||
This node takes the relevant documents provided in the configuration and uses
|
||||
an LLM to generate a comprehensive answer to the sub-section title with
|
||||
proper citations. The citations follow IEEE format using source IDs from the
|
||||
documents. If no documents are provided, it will use chat history to generate
|
||||
content.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing the final answer in the "final_answer" key.
|
||||
"""
|
||||
from app.services.llm_service import get_user_fast_llm
|
||||
|
||||
|
||||
# Get configuration and relevant documents from configuration
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
documents = state.reranked_documents
|
||||
user_id = configuration.user_id
|
||||
|
||||
|
||||
# Get user's fast LLM
|
||||
llm = await get_user_fast_llm(state.db_session, user_id)
|
||||
if not llm:
|
||||
error_message = f"No fast LLM configured for user {user_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
||||
# Extract configuration data
|
||||
section_title = configuration.sub_section_title
|
||||
sub_section_questions = configuration.sub_section_questions
|
||||
|
|
@ -113,18 +122,18 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
|
||||
# Format the questions as bullet points for clarity
|
||||
questions_text = "\n".join([f"- {question}" for question in sub_section_questions])
|
||||
|
||||
|
||||
# Provide context based on the subsection type
|
||||
section_position_context_map = {
|
||||
SubSectionType.START: "This is the INTRODUCTION section.",
|
||||
SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.",
|
||||
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure."
|
||||
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure.",
|
||||
}
|
||||
section_position_context = section_position_context_map.get(sub_section_type, "")
|
||||
|
||||
|
||||
# Determine if we have documents and optimize for token limits
|
||||
has_documents_initially = documents and len(documents) > 0
|
||||
|
||||
|
||||
if has_documents_initially:
|
||||
# Create base message template for token calculation (without documents)
|
||||
base_human_message_template = f"""
|
||||
|
|
@ -149,38 +158,45 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
|
||||
Please write content for this sub-section using the provided source material and cite all information appropriately.
|
||||
"""
|
||||
|
||||
|
||||
# Use initial system prompt for token calculation
|
||||
initial_system_prompt = get_citation_system_prompt()
|
||||
base_messages = state.chat_history + [
|
||||
base_messages = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=initial_system_prompt),
|
||||
HumanMessage(content=base_human_message_template)
|
||||
HumanMessage(content=base_human_message_template),
|
||||
]
|
||||
|
||||
|
||||
# Optimize documents to fit within token limits
|
||||
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
|
||||
documents, base_messages, llm.model
|
||||
optimized_documents, has_optimized_documents = (
|
||||
optimize_documents_for_token_limit(documents, base_messages, llm.model)
|
||||
)
|
||||
|
||||
|
||||
# Update state based on optimization result
|
||||
documents = optimized_documents
|
||||
has_documents = has_optimized_documents
|
||||
else:
|
||||
has_documents = False
|
||||
|
||||
|
||||
# Choose system prompt based on final document availability
|
||||
system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt()
|
||||
|
||||
system_prompt = (
|
||||
get_citation_system_prompt()
|
||||
if has_documents
|
||||
else get_no_documents_system_prompt()
|
||||
)
|
||||
|
||||
# Generate documents section
|
||||
documents_text = format_documents_section(documents, "Source material") if has_documents else ""
|
||||
|
||||
documents_text = (
|
||||
format_documents_section(documents, "Source material") if has_documents else ""
|
||||
)
|
||||
|
||||
# Create final human message content
|
||||
instruction_text = (
|
||||
"Please write content for this sub-section using the provided source material and cite all information appropriately."
|
||||
if has_documents else
|
||||
"Please write content for this sub-section based on our conversation history and your general knowledge."
|
||||
if has_documents
|
||||
else "Please write content for this sub-section based on our conversation history and your general knowledge."
|
||||
)
|
||||
|
||||
|
||||
human_message_content = f"""
|
||||
{documents_text}
|
||||
|
||||
|
|
@ -204,22 +220,20 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
|
||||
{instruction_text}
|
||||
"""
|
||||
|
||||
|
||||
# Create final messages for the LLM
|
||||
messages_with_chat_history = state.chat_history + [
|
||||
messages_with_chat_history = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=human_message_content)
|
||||
HumanMessage(content=human_message_content),
|
||||
]
|
||||
|
||||
|
||||
# Log final token count
|
||||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
||||
print(f"Final token count: {total_tokens}")
|
||||
|
||||
|
||||
# Call the LLM and get the response
|
||||
response = await llm.ainvoke(messages_with_chat_history)
|
||||
final_answer = response.content
|
||||
|
||||
return {
|
||||
"final_answer": final_answer
|
||||
}
|
||||
|
||||
return {"final_answer": final_answer}
|
||||
|
|
|
|||
|
|
@ -182,4 +182,4 @@ When writing content for a sub-section without access to personal documents:
|
|||
5. Address the guiding questions through natural content flow without explicitly listing them
|
||||
6. Suggest how adding relevant sources to SurfSense could enhance future content when appropriate
|
||||
</user_query_instructions>
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the agent during execution.
|
||||
|
|
@ -14,11 +16,11 @@ class State:
|
|||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
|
||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: Optional[List[Any]] = None
|
||||
final_answer: Optional[str] = None
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: list[Any] | None = None
|
||||
final_answer: str | None = None
|
||||
|
|
|
|||
|
|
@@ -1,27 +1,37 @@
-from typing import List, Dict, Any, Tuple, NamedTuple
+from typing import Any, NamedTuple

 from langchain_core.messages import BaseMessage
+from litellm import get_model_info, token_counter
 from pydantic import BaseModel, Field
-from litellm import token_counter, get_model_info


 class Section(BaseModel):
     """A section in the answer outline."""

     section_id: int = Field(..., description="The zero-based index of the section")
     section_title: str = Field(..., description="The title of the section")
-    questions: List[str] = Field(..., description="Questions to research for this section")
+    questions: list[str] = Field(
+        ..., description="Questions to research for this section"
+    )


 class AnswerOutline(BaseModel):
     """The complete answer outline with all sections."""
-    answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
+
+    answer_outline: list[Section] = Field(
+        ..., description="List of sections in the answer outline"
+    )


 class DocumentTokenInfo(NamedTuple):
     """Information about a document and its token cost."""

     index: int
-    document: Dict[str, Any]
+    document: dict[str, Any]
     formatted_content: str
     token_count: int


 def get_connector_emoji(connector_name: str) -> str:
     """Get an appropriate emoji for a connector type."""
     connector_emojis = {
@@ -36,7 +46,7 @@ def get_connector_emoji(connector_name: str) -> str:
         "JIRA_CONNECTOR": "🎫",
         "DISCORD_CONNECTOR": "🗨️",
         "TAVILY_API": "🔍",
-        "LINKUP_API": "🔗"
+        "LINKUP_API": "🔗",
     }
     return connector_emojis.get(connector_name, "🔎")
@ -55,31 +65,26 @@ def get_connector_friendly_name(connector_name: str) -> str:
|
|||
"JIRA_CONNECTOR": "Jira",
|
||||
"DISCORD_CONNECTOR": "Discord",
|
||||
"TAVILY_API": "Tavily Search",
|
||||
"LINKUP_API": "Linkup Search"
|
||||
"LINKUP_API": "Linkup Search",
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
||||
|
||||
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
||||
def convert_langchain_messages_to_dict(
|
||||
messages: list[BaseMessage],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Convert LangChain messages to format expected by token_counter."""
|
||||
role_mapping = {
|
||||
'system': 'system',
|
||||
'human': 'user',
|
||||
'ai': 'assistant'
|
||||
}
|
||||
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
|
||||
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = role_mapping.get(getattr(msg, 'type', None), 'user')
|
||||
converted_messages.append({
|
||||
"role": role,
|
||||
"content": str(msg.content)
|
||||
})
|
||||
role = role_mapping.get(getattr(msg, "type", None), "user")
|
||||
converted_messages.append({"role": role, "content": str(msg.content)})
|
||||
|
||||
return converted_messages
|
||||
|
||||
|
||||
def format_document_for_citation(document: Dict[str, Any]) -> str:
|
||||
def format_document_for_citation(document: dict[str, Any]) -> str:
|
||||
"""Format a single document for citation in the standard XML format."""
|
||||
content = document.get("content", "")
|
||||
doc_info = document.get("document", {})
|
||||
|
|
@ -97,7 +102,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str:
|
|||
</document>"""
|
||||
|
||||
|
||||
def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
|
||||
def format_documents_section(
|
||||
documents: list[dict[str, Any]], section_title: str = "Source material"
|
||||
) -> str:
|
||||
"""Format multiple documents into a complete documents section."""
|
||||
if not documents:
|
||||
return ""
|
||||
|
|
@ -110,7 +117,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str
|
|||
</documents>"""
|
||||
|
||||
|
||||
def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
|
||||
def calculate_document_token_costs(
|
||||
documents: list[dict[str, Any]], model: str
|
||||
) -> list[DocumentTokenInfo]:
|
||||
"""Pre-calculate token costs for each document."""
|
||||
document_token_info = []
|
||||
|
||||
|
|
@ -119,24 +128,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
|
|||
|
||||
# Calculate token count for this document
|
||||
token_count = token_counter(
|
||||
messages=[{"role": "user", "content": formatted_doc}],
|
||||
model=model
|
||||
messages=[{"role": "user", "content": formatted_doc}], model=model
|
||||
)
|
||||
|
||||
document_token_info.append(DocumentTokenInfo(
|
||||
index=i,
|
||||
document=doc,
|
||||
formatted_content=formatted_doc,
|
||||
token_count=token_count
|
||||
))
|
||||
document_token_info.append(
|
||||
DocumentTokenInfo(
|
||||
index=i,
|
||||
document=doc,
|
||||
formatted_content=formatted_doc,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
return document_token_info
|
||||
|
||||
|
||||
 def find_optimal_documents_with_binary_search(
-    document_tokens: List[DocumentTokenInfo],
-    available_tokens: int
-) -> List[DocumentTokenInfo]:
+    document_tokens: list[DocumentTokenInfo], available_tokens: int
+) -> list[DocumentTokenInfo]:
     """Use binary search to find the maximum number of documents that fit within token limit."""
     if not document_tokens or available_tokens <= 0:
         return []

@@ -147,8 +156,7 @@ def find_optimal_documents_with_binary_search(
     while left <= right:
         mid = (left + right) // 2
         current_docs = document_tokens[:mid]
-        current_token_sum = sum(
-            doc_info.token_count for doc_info in current_docs)
+        current_token_sum = sum(doc_info.token_count for doc_info in current_docs)

         if current_token_sum <= available_tokens:
             optimal_docs = current_docs
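For reference, the prefix-fitting idea the hunk above implements can be sketched standalone as follows (names and numbers are illustrative, not from the repository): binary-search the largest prefix of per-document token costs whose sum stays within the available budget.

    # Standalone sketch of the same idea: largest prefix of token costs within a budget.
    def max_docs_within_budget(token_costs: list[int], budget: int) -> int:
        left, right, best = 0, len(token_costs), 0
        while left <= right:
            mid = (left + right) // 2
            if sum(token_costs[:mid]) <= budget:
                best = mid       # this many documents still fit
                left = mid + 1   # try to fit more
            else:
                right = mid - 1  # over budget, shrink the prefix
        return best

    assert max_docs_within_budget([100, 200, 300, 400], 600) == 3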
@@ -163,20 +171,18 @@ def get_model_context_window(model_name: str) -> int:
     """Get the total context window size for a model (input + output tokens)."""
     try:
         model_info = get_model_info(model_name)
-        context_window = model_info.get(
-            'max_input_tokens', 4096)  # Default fallback
+        context_window = model_info.get("max_input_tokens", 4096)  # Default fallback
         return context_window
     except Exception as e:
         print(
-            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
+            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
+        )
         return 4096  # Conservative fallback


 def optimize_documents_for_token_limit(
-    documents: List[Dict[str, Any]],
-    base_messages: List[BaseMessage],
-    model_name: str
-) -> Tuple[List[Dict[str, Any]], bool]:
+    documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
+) -> tuple[list[dict[str, Any]], bool]:
     """
     Optimize documents to fit within token limits using binary search.
@ -201,7 +207,8 @@ def optimize_documents_for_token_limit(
|
|||
available_tokens_for_docs = context_window - base_tokens
|
||||
|
||||
print(
|
||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
|
||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
|
||||
)
|
||||
|
||||
if available_tokens_for_docs <= 0:
|
||||
print("No tokens available for documents after base content and output buffer")
|
||||
|
|
@ -212,8 +219,7 @@ def optimize_documents_for_token_limit(
|
|||
|
||||
# Find optimal number of documents using binary search
|
||||
optimal_doc_info = find_optimal_documents_with_binary_search(
|
||||
document_token_info,
|
||||
available_tokens_for_docs
|
||||
document_token_info, available_tokens_for_docs
|
||||
)
|
||||
|
||||
# Extract the original document objects
|
||||
|
|
@ -221,12 +227,13 @@ def optimize_documents_for_token_limit(
|
|||
has_documents_remaining = len(optimized_documents) > 0
|
||||
|
||||
print(
|
||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
|
||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
|
||||
)
|
||||
|
||||
return optimized_documents, has_documents_remaining
|
||||
|
||||
|
||||
def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
|
||||
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
|
||||
"""Calculate token count for a list of LangChain messages."""
|
||||
model = model_name
|
||||
messages_dict = convert_langchain_messages_to_dict(messages)
|
||||
|
|
|
|||
|
|
@ -2,22 +2,13 @@ from contextlib import asynccontextmanager
|
|||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
|
||||
|
||||
from app.routes import router as crud_router
|
||||
from app.config import config
|
||||
|
||||
from app.users import (
|
||||
SECRET,
|
||||
auth_backend,
|
||||
fastapi_users,
|
||||
current_active_user
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -64,12 +55,10 @@ app.include_router(
|
|||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from app.users import google_oauth_client
|
||||
|
||||
app.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
google_oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
is_verified_by_default=True
|
||||
google_oauth_client, auth_backend, SECRET, is_verified_by_default=True
|
||||
),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
|
|
@ -79,5 +68,8 @@ app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
|||
|
||||
|
||||
@app.get("/verify-token")
|
||||
async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
|
||||
async def authenticated_route(
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
return {"message": "Token is valid"}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from rerankers import Reranker
|
||||
|
||||
|
||||
# Get the base directory of the project
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
|
|
@ -18,37 +16,37 @@ load_dotenv(env_file)
|
|||
def is_ffmpeg_installed():
|
||||
"""
|
||||
Check if ffmpeg is installed on the current system.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if ffmpeg is installed, False otherwise.
|
||||
"""
|
||||
return shutil.which("ffmpeg") is not None
|
||||
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
import static_ffmpeg
|
||||
|
||||
# ffmpeg installed on first call to add_paths(), threadsafe.
|
||||
static_ffmpeg.add_paths()
|
||||
# check if ffmpeg is installed again
|
||||
if not is_ffmpeg_installed():
|
||||
raise ValueError("FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster.")
|
||||
|
||||
raise ValueError(
|
||||
"FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster."
|
||||
)
|
||||
|
||||
# Database
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
|
||||
|
||||
|
||||
# AUTH: Google OAuth
|
||||
AUTH_TYPE = os.getenv("AUTH_TYPE")
|
||||
if AUTH_TYPE == "GOOGLE":
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
|
||||
|
||||
|
||||
# LLM instances are now managed per-user through the LLMConfig system
|
||||
# Legacy environment variables removed in favor of user-specific configurations
|
||||
|
||||
|
|
@ -56,12 +54,12 @@ class Config:
|
|||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
||||
chunker_instance = RecursiveChunker(
|
||||
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
|
||||
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||
)
|
||||
code_chunker_instance = CodeChunker(
|
||||
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
|
||||
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||
)
|
||||
|
||||
|
||||
# Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
||||
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
|
||||
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
|
||||
|
|
@ -69,45 +67,46 @@ class Config:
|
|||
model_name=RERANKERS_MODEL_NAME,
|
||||
model_type=RERANKERS_MODEL_TYPE,
|
||||
)
|
||||
|
||||
|
||||
# OAuth JWT
|
||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||
|
||||
|
||||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
||||
|
||||
if ETL_SERVICE == "UNSTRUCTURED":
|
||||
# Unstructured API Key
|
||||
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
|
||||
|
||||
|
||||
elif ETL_SERVICE == "LLAMACLOUD":
|
||||
# LlamaCloud API Key
|
||||
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
|
||||
|
||||
# Firecrawl API Key
|
||||
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
|
||||
|
||||
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
|
||||
|
||||
# Litellm TTS Configuration
|
||||
TTS_SERVICE = os.getenv("TTS_SERVICE")
|
||||
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
|
||||
TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY")
|
||||
|
||||
|
||||
# Litellm STT Configuration
|
||||
STT_SERVICE = os.getenv("STT_SERVICE")
|
||||
STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
|
||||
STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")
|
||||
|
||||
|
||||
|
||||
# Validation Checks
|
||||
# Check embedding dimension
|
||||
if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
|
||||
if (
|
||||
hasattr(embedding_model_instance, "dimension")
|
||||
and embedding_model_instance.dimension > 2000
|
||||
):
|
||||
raise ValueError(
|
||||
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
|
||||
f"has {embedding_model_instance.dimension} dimensions, which "
|
||||
f"exceeds the maximum of 2000 allowed by PGVector."
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
"""Get all settings as a dictionary."""
|
||||
|
|
|
|||
|
|
@@ -1,26 +1,25 @@
 import os


 def _parse_bool(value):
     """Parse boolean value from string."""
     return value.lower() == "true" if value else False


 def _parse_int(value, var_name):
     """Parse integer value with error handling."""
     try:
         return int(value)
     except ValueError:
-        raise ValueError(f"Invalid integer value for {var_name}: {value}")
+        raise ValueError(f"Invalid integer value for {var_name}: {value}") from None


 def _parse_headers(value):
     """Parse headers from comma-separated string."""
     try:
-        return [
-            tuple(h.split(":", 1))
-            for h in value.split(",")
-            if ":" in h
-        ]
+        return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h]
     except Exception:
-        raise ValueError(f"Invalid headers format: {value}")
+        raise ValueError(f"Invalid headers format: {value}") from None


 def load_uvicorn_config(args=None):
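A quick usage sketch for the parsing helpers above; the definitions are copied from the new version of the hunk so the snippet runs standalone, and the example values are made up:

    def _parse_bool(value):
        return value.lower() == "true" if value else False


    def _parse_headers(value):
        return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h]


    print(_parse_bool("true"))   # True
    print(_parse_headers("X-Forwarded-Proto:https,X-Real-IP:10.0.0.1"))
    # [('X-Forwarded-Proto', 'https'), ('X-Real-IP', '10.0.0.1')]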
@ -28,16 +27,16 @@ def load_uvicorn_config(args=None):
|
|||
Load Uvicorn configuration from environment variables and CLI args.
|
||||
Returns a dict suitable for passing to uvicorn.Config.
|
||||
"""
|
||||
config_kwargs = dict(
|
||||
app="app.app:app",
|
||||
host=os.getenv("UVICORN_HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("UVICORN_PORT", 8000)),
|
||||
log_level=os.getenv("UVICORN_LOG_LEVEL", "info"),
|
||||
reload=args.reload if args else False,
|
||||
reload_dirs=["app"] if (args and args.reload) else None,
|
||||
)
|
||||
|
||||
# Configuration mapping for advanced options
|
||||
config_kwargs = {
|
||||
"app": "app.app:app",
|
||||
"host": os.getenv("UVICORN_HOST", "0.0.0.0"),
|
||||
"port": int(os.getenv("UVICORN_PORT", 8000)),
|
||||
"log_level": os.getenv("UVICORN_LOG_LEVEL", "info"),
|
||||
"reload": args.reload if args else False,
|
||||
"reload_dirs": ["app"] if (args and args.reload) else None,
|
||||
}
|
||||
|
||||
# Configuration mapping for advanced options
|
||||
config_mapping = {
|
||||
"UVICORN_PROXY_HEADERS": ("proxy_headers", _parse_bool),
|
||||
"UVICORN_FORWARDED_ALLOW_IPS": ("forwarded_allow_ips", str),
|
||||
|
|
@ -51,15 +50,33 @@ def load_uvicorn_config(args=None):
|
|||
"UVICORN_LOG_CONFIG": ("log_config", str),
|
||||
"UVICORN_SERVER_HEADER": ("server_header", _parse_bool),
|
||||
"UVICORN_DATE_HEADER": ("date_header", _parse_bool),
|
||||
"UVICORN_LIMIT_CONCURRENCY": ("limit_concurrency", lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY")),
|
||||
"UVICORN_LIMIT_MAX_REQUESTS": ("limit_max_requests", lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS")),
|
||||
"UVICORN_TIMEOUT_KEEP_ALIVE": ("timeout_keep_alive", lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE")),
|
||||
"UVICORN_TIMEOUT_NOTIFY": ("timeout_notify", lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY")),
|
||||
"UVICORN_LIMIT_CONCURRENCY": (
|
||||
"limit_concurrency",
|
||||
lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY"),
|
||||
),
|
||||
"UVICORN_LIMIT_MAX_REQUESTS": (
|
||||
"limit_max_requests",
|
||||
lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS"),
|
||||
),
|
||||
"UVICORN_TIMEOUT_KEEP_ALIVE": (
|
||||
"timeout_keep_alive",
|
||||
lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE"),
|
||||
),
|
||||
"UVICORN_TIMEOUT_NOTIFY": (
|
||||
"timeout_notify",
|
||||
lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY"),
|
||||
),
|
||||
"UVICORN_SSL_KEYFILE": ("ssl_keyfile", str),
|
||||
"UVICORN_SSL_CERTFILE": ("ssl_certfile", str),
|
||||
"UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str),
|
||||
"UVICORN_SSL_VERSION": ("ssl_version", lambda x: _parse_int(x, "UVICORN_SSL_VERSION")),
|
||||
"UVICORN_SSL_CERT_REQS": ("ssl_cert_reqs", lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS")),
|
||||
"UVICORN_SSL_VERSION": (
|
||||
"ssl_version",
|
||||
lambda x: _parse_int(x, "UVICORN_SSL_VERSION"),
|
||||
),
|
||||
"UVICORN_SSL_CERT_REQS": (
|
||||
"ssl_cert_reqs",
|
||||
lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS"),
|
||||
),
|
||||
"UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str),
|
||||
"UVICORN_SSL_CIPHERS": ("ssl_ciphers", str),
|
||||
"UVICORN_HEADERS": ("headers", _parse_headers),
|
||||
|
|
@ -76,7 +93,6 @@ def load_uvicorn_config(args=None):
|
|||
try:
|
||||
config_kwargs[config_key] = parser(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Configuration error for {env_var}: {e}")
|
||||
|
||||
raise ValueError(f"Configuration error for {env_var}: {e}") from e
|
||||
|
||||
return config_kwargs
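
The loader above only builds the keyword dict; its docstring says the result is suitable for `uvicorn.Config`. A minimal sketch of wiring it into a running server follows — the import path and the `--reload` flag are assumptions for illustration, not part of the commit.

```python
# Sketch: drive uvicorn from the env-driven kwargs built by load_uvicorn_config.
# The module path "app.uvicorn_config" is hypothetical.
import argparse

import uvicorn

from app.uvicorn_config import load_uvicorn_config  # hypothetical import path


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--reload", action="store_true")
    args = parser.parse_args()

    # The loader returns kwargs suitable for uvicorn.Config (per its docstring).
    config = uvicorn.Config(**load_uvicorn_config(args))
    uvicorn.Server(config).run()


if __name__ == "__main__":
    main()
```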
@ -6,11 +6,12 @@ A module for interacting with Discord's HTTP API to retrieve guilds, channels, a
|
|||
Requires a Discord bot token.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||
class DiscordConnector(commands.Bot):
|
||||
"""Class for retrieving guild, channel, and message history from Discord."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
def __init__(self, token: str | None = None):
|
||||
"""
|
||||
Initialize the DiscordConnector with a bot token.
|
||||
|
||||
|
|
@ -30,7 +31,9 @@ class DiscordConnector(commands.Bot):
|
|||
intents.messages = True # Required to fetch messages
|
||||
intents.message_content = True # Required to read message content
|
||||
intents.members = True # Required to fetch member information
|
||||
super().__init__(command_prefix="!", intents=intents) # command_prefix is required but not strictly used here
|
||||
super().__init__(
|
||||
command_prefix="!", intents=intents
|
||||
) # command_prefix is required but not strictly used here
|
||||
self.token = token
|
||||
self._bot_task = None # Holds the async bot task
|
||||
self._is_running = False # Flag to track if the bot is running
|
||||
|
|
@ -48,7 +51,7 @@ class DiscordConnector(commands.Bot):
|
|||
@self.event
|
||||
async def on_disconnect():
|
||||
logger.debug("Bot disconnected from Discord gateway.")
|
||||
self._is_running = False # Reset flag on disconnect
|
||||
self._is_running = False # Reset flag on disconnect
|
||||
|
||||
@self.event
|
||||
async def on_resumed():
|
||||
|
|
@ -63,17 +66,23 @@ class DiscordConnector(commands.Bot):
|
|||
|
||||
try:
|
||||
if self._is_running:
|
||||
logger.warning("Bot is already running. Use close_bot() to stop it before starting again.")
|
||||
logger.warning(
|
||||
"Bot is already running. Use close_bot() to stop it before starting again."
|
||||
)
|
||||
return
|
||||
|
||||
await self.start(self.token)
|
||||
logger.info("Discord bot started successfully.")
|
||||
except discord.LoginFailure:
|
||||
logger.error("Failed to log in: Invalid token was provided. Please check your bot token.")
|
||||
logger.error(
|
||||
"Failed to log in: Invalid token was provided. Please check your bot token."
|
||||
)
|
||||
self._is_running = False
|
||||
raise
|
||||
except discord.PrivilegedIntentsRequired as e:
|
||||
logger.error(f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page.")
|
||||
logger.error(
|
||||
f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page."
|
||||
)
|
||||
self._is_running = False
|
||||
raise
|
||||
except discord.ConnectionClosed as e:
|
||||
|
|
@ -96,7 +105,6 @@ class DiscordConnector(commands.Bot):
|
|||
else:
|
||||
logger.info("Bot is not running or already disconnected.")
|
||||
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the discord bot token.
|
||||
|
|
@ -106,8 +114,10 @@ class DiscordConnector(commands.Bot):
|
|||
"""
|
||||
logger.info("Setting Discord bot token.")
|
||||
self.token = token
|
||||
logger.info("Token set successfully. You can now start the bot with start_bot().")
|
||||
|
||||
logger.info(
|
||||
"Token set successfully. You can now start the bot with start_bot()."
|
||||
)
|
||||
|
||||
async def _wait_until_ready(self):
|
||||
"""Helper to wait until the bot is connected and ready."""
|
||||
logger.info("Waiting for the bot to be ready...")
|
||||
|
|
@ -115,16 +125,20 @@ class DiscordConnector(commands.Bot):
|
|||
# Give the event loop a chance to switch to the bot's startup task.
|
||||
# This allows self.start() to begin initializing the client.
|
||||
# Terrible solution, but necessary to avoid blocking the event loop.
|
||||
await asyncio.sleep(1) # Yield control to the event loop
|
||||
|
||||
await asyncio.sleep(1) # Yield control to the event loop
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.wait_until_ready(), timeout=60.0)
|
||||
logger.info("Bot is ready.")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Bot did not become ready within 60 seconds. Connection may have failed.")
|
||||
except TimeoutError:
|
||||
logger.error(
|
||||
"Bot did not become ready within 60 seconds. Connection may have failed."
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred while waiting for the bot to be ready: {e}")
|
||||
logger.error(
|
||||
f"An unexpected error occurred while waiting for the bot to be ready: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_guilds(self) -> list[dict]:
|
||||
|
|
@ -143,7 +157,9 @@ class DiscordConnector(commands.Bot):
|
|||
|
||||
guilds_data = []
|
||||
for guild in self.guilds:
|
||||
member_count = guild.member_count if guild.member_count is not None else "N/A"
|
||||
member_count = (
|
||||
guild.member_count if guild.member_count is not None else "N/A"
|
||||
)
|
||||
guilds_data.append(
|
||||
{
|
||||
"id": str(guild.id),
|
||||
|
|
@ -183,15 +199,17 @@ class DiscordConnector(commands.Bot):
|
|||
channels_data.append(
|
||||
{"id": str(channel.id), "name": channel.name, "type": "text"}
|
||||
)
|
||||
|
||||
logger.info(f"Fetched {len(channels_data)} text channels from guild {guild_id}.")
|
||||
|
||||
logger.info(
|
||||
f"Fetched {len(channels_data)} text channels from guild {guild_id}."
|
||||
)
|
||||
return channels_data
|
||||
|
||||
async def get_channel_history(
|
||||
self,
|
||||
channel_id: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Fetch message history from a text channel.
|
||||
|
|
@ -227,20 +245,26 @@ class DiscordConnector(commands.Bot):
|
|||
|
||||
if start_date:
|
||||
try:
|
||||
start_datetime = datetime.datetime.fromisoformat(start_date).replace(tzinfo=datetime.timezone.utc)
|
||||
start_datetime = datetime.datetime.fromisoformat(start_date).replace(
|
||||
tzinfo=datetime.UTC
|
||||
)
|
||||
after = start_datetime
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid start_date format: {start_date}. Ignoring.")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(tzinfo=datetime.timezone.utc)
|
||||
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(
|
||||
tzinfo=datetime.UTC
|
||||
)
|
||||
before = end_datetime
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid end_date format: {end_date}. Ignoring.")
|
||||
|
||||
try:
|
||||
async for message in channel.history(limit=None, before=before, after=after):
|
||||
async for message in channel.history(
|
||||
limit=None, before=before, after=after
|
||||
):
|
||||
messages_data.append(
|
||||
{
|
||||
"id": str(message.id),
|
||||
|
|
@ -251,12 +275,14 @@ class DiscordConnector(commands.Bot):
|
|||
}
|
||||
)
|
||||
except discord.Forbidden:
|
||||
logger.error(f"Bot does not have permissions to read message history in channel {channel_id}.")
|
||||
logger.error(
|
||||
f"Bot does not have permissions to read message history in channel {channel_id}."
|
||||
)
|
||||
raise
|
||||
except discord.HTTPException as e:
|
||||
logger.error(f"Failed to fetch messages from channel {channel_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
logger.info(f"Fetched {len(messages_data)} messages from channel {channel_id}.")
|
||||
return messages_data
|
||||
|
||||
|
|
@ -278,7 +304,9 @@ class DiscordConnector(commands.Bot):
|
|||
permissions to view members.
|
||||
"""
|
||||
await self._wait_until_ready()
|
||||
logger.info(f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}")
|
||||
logger.info(
|
||||
f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}"
|
||||
)
|
||||
|
||||
guild = self.get_guild(int(guild_id))
|
||||
if not guild:
|
||||
|
|
@ -294,7 +322,9 @@ class DiscordConnector(commands.Bot):
|
|||
return {
|
||||
"id": str(member.id),
|
||||
"name": member.name,
|
||||
"joined_at": member.joined_at.isoformat() if member.joined_at else None,
|
||||
"joined_at": member.joined_at.isoformat()
|
||||
if member.joined_at
|
||||
else None,
|
||||
"roles": roles,
|
||||
}
|
||||
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
||||
|
|
@ -303,8 +333,12 @@ class DiscordConnector(commands.Bot):
|
|||
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
||||
return None
|
||||
except discord.Forbidden:
|
||||
logger.error(f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled.")
|
||||
logger.error(
|
||||
f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled."
|
||||
)
|
||||
raise
|
||||
except discord.HTTPException as e:
|
||||
logger.error(f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}")
|
||||
logger.error(
|
||||
f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}"
|
||||
)
|
||||
return None
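
Taken together, the methods above are meant to be driven from a single event loop: start the bot as a background task, query it, then shut it down. A rough usage sketch, assuming `start_bot()` and `close_bot()` are the async entry points referenced in the log messages above; the token and IDs are placeholders.

```python
import asyncio


async def dump_recent_messages(token: str, channel_id: str) -> None:
    # Sketch only: start_bot()/close_bot() are assumed to be async,
    # as implied by the docstrings and log messages in this connector.
    bot = DiscordConnector(token=token)
    bot_task = asyncio.create_task(bot.start_bot())

    try:
        guilds = await bot.get_guilds()
        print(f"Bot can see {len(guilds)} guild(s)")

        messages = await bot.get_channel_history(
            channel_id, start_date="2024-01-01", end_date="2024-01-31"
        )
        print(f"Fetched {len(messages)} message(s)")
    finally:
        await bot.close_bot()
        await bot_task


# asyncio.run(dump_recent_messages("BOT_TOKEN", "123456789012345678"))
```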
@ -1,54 +1,91 @@
|
|||
import base64
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from github3 import login as github_login, exceptions as github_exceptions
|
||||
from github3.repos.contents import Contents
|
||||
from typing import Any
|
||||
|
||||
from github3 import exceptions as github_exceptions, login as github_login
|
||||
from github3.exceptions import ForbiddenError, NotFoundError
|
||||
from github3.repos.contents import Contents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of common code file extensions to target
|
||||
CODE_EXTENSIONS = {
|
||||
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
|
||||
'.cs', '.go', '.rb', '.php', '.swift', '.kt', '.scala', '.rs', '.m',
|
||||
'.sh', '.bash', '.ps1', '.lua', '.pl', '.pm', '.r', '.dart', '.sql'
|
||||
".py",
|
||||
".js",
|
||||
".jsx",
|
||||
".ts",
|
||||
".tsx",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".rs",
|
||||
".m",
|
||||
".sh",
|
||||
".bash",
|
||||
".ps1",
|
||||
".lua",
|
||||
".pl",
|
||||
".pm",
|
||||
".r",
|
||||
".dart",
|
||||
".sql",
|
||||
}
|
||||
|
||||
# List of common documentation/text file extensions
|
||||
DOC_EXTENSIONS = {
|
||||
'.md', '.txt', '.rst', '.adoc', '.html', '.htm', '.xml', '.json', '.yaml', '.yml', '.toml'
|
||||
".md",
|
||||
".txt",
|
||||
".rst",
|
||||
".adoc",
|
||||
".html",
|
||||
".htm",
|
||||
".xml",
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
}
|
||||
|
||||
# Maximum file size in bytes (e.g., 1MB)
|
||||
MAX_FILE_SIZE = 1 * 1024 * 1024
|
||||
|
||||
|
||||
class GitHubConnector:
|
||||
"""Connector for interacting with the GitHub API."""
|
||||
|
||||
# Directories to skip during file traversal
|
||||
SKIPPED_DIRS = {
|
||||
# Version control
|
||||
'.git',
|
||||
".git",
|
||||
# Dependencies
|
||||
'node_modules',
|
||||
'vendor',
|
||||
"node_modules",
|
||||
"vendor",
|
||||
# Build artifacts / Caches
|
||||
'build',
|
||||
'dist',
|
||||
'target',
|
||||
'__pycache__',
|
||||
"build",
|
||||
"dist",
|
||||
"target",
|
||||
"__pycache__",
|
||||
# Virtual environments
|
||||
'venv',
|
||||
'.venv',
|
||||
'env',
|
||||
"venv",
|
||||
".venv",
|
||||
"env",
|
||||
# IDE/Editor config
|
||||
'.vscode',
|
||||
'.idea',
|
||||
'.project',
|
||||
'.settings',
|
||||
".vscode",
|
||||
".idea",
|
||||
".project",
|
||||
".settings",
|
||||
# Temporary / Logs
|
||||
'tmp',
|
||||
'logs',
|
||||
"tmp",
|
||||
"logs",
|
||||
# Add other project-specific irrelevant directories if needed
|
||||
}
|
||||
|
||||
|
|
@ -68,35 +105,39 @@ class GitHubConnector:
|
|||
logger.info("Successfully authenticated with GitHub API.")
|
||||
except (github_exceptions.AuthenticationFailed, ForbiddenError) as e:
|
||||
logger.error(f"GitHub authentication failed: {e}")
|
||||
raise ValueError("Invalid GitHub token or insufficient permissions.")
|
||||
raise ValueError("Invalid GitHub token or insufficient permissions.") from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize GitHub client: {e}")
|
||||
raise
|
||||
raise e
|
||||
|
||||
def get_user_repositories(self) -> List[Dict[str, Any]]:
|
||||
def get_user_repositories(self) -> list[dict[str, Any]]:
|
||||
"""Fetches repositories accessible by the authenticated user."""
|
||||
repos_data = []
|
||||
try:
|
||||
# type='owner' fetches repos owned by the user
|
||||
# type='member' fetches repos the user is a collaborator on (including orgs)
|
||||
# type='all' fetches both
|
||||
for repo in self.gh.repositories(type='all', sort='updated'):
|
||||
repos_data.append({
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"full_name": repo.full_name,
|
||||
"private": repo.private,
|
||||
"url": repo.html_url,
|
||||
"description": repo.description or "",
|
||||
"last_updated": repo.updated_at if repo.updated_at else None,
|
||||
})
|
||||
for repo in self.gh.repositories(type="all", sort="updated"):
|
||||
repos_data.append(
|
||||
{
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"full_name": repo.full_name,
|
||||
"private": repo.private,
|
||||
"url": repo.html_url,
|
||||
"description": repo.description or "",
|
||||
"last_updated": repo.updated_at if repo.updated_at else None,
|
||||
}
|
||||
)
|
||||
logger.info(f"Fetched {len(repos_data)} repositories.")
|
||||
return repos_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch GitHub repositories: {e}")
|
||||
return [] # Return empty list on error
|
||||
return [] # Return empty list on error
|
||||
|
||||
def get_repository_files(self, repo_full_name: str, path: str = '') -> List[Dict[str, Any]]:
|
||||
def get_repository_files(
|
||||
self, repo_full_name: str, path: str = ""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Recursively fetches details of relevant files (code, docs) within a repository path.
|
||||
|
||||
|
|
@ -110,54 +151,72 @@ class GitHubConnector:
|
|||
"""
|
||||
files_list = []
|
||||
try:
|
||||
owner, repo_name = repo_full_name.split('/')
|
||||
owner, repo_name = repo_full_name.split("/")
|
||||
repo = self.gh.repository(owner, repo_name)
|
||||
if not repo:
|
||||
logger.warning(f"Repository '{repo_full_name}' not found.")
|
||||
return []
|
||||
contents = repo.directory_contents(directory_path=path) # Use directory_contents for clarity
|
||||
|
||||
contents = repo.directory_contents(
|
||||
directory_path=path
|
||||
) # Use directory_contents for clarity
|
||||
|
||||
# contents returns a list of tuples (name, content_obj)
|
||||
for item_name, content_item in contents:
|
||||
for _item_name, content_item in contents:
|
||||
if not isinstance(content_item, Contents):
|
||||
continue
|
||||
|
||||
if content_item.type == 'dir':
|
||||
if content_item.type == "dir":
|
||||
# Check if the directory name is in the skipped list
|
||||
if content_item.name in self.SKIPPED_DIRS:
|
||||
logger.debug(f"Skipping directory: {content_item.path}")
|
||||
continue # Skip recursion for this directory
|
||||
|
||||
continue # Skip recursion for this directory
|
||||
|
||||
# Recursively fetch contents of subdirectory
|
||||
files_list.extend(self.get_repository_files(repo_full_name, path=content_item.path))
|
||||
elif content_item.type == 'file':
|
||||
files_list.extend(
|
||||
self.get_repository_files(
|
||||
repo_full_name, path=content_item.path
|
||||
)
|
||||
)
|
||||
elif content_item.type == "file":
|
||||
# Check if the file extension is relevant and size is within limits
|
||||
file_extension = '.' + content_item.name.split('.')[-1].lower() if '.' in content_item.name else ''
|
||||
file_extension = (
|
||||
"." + content_item.name.split(".")[-1].lower()
|
||||
if "." in content_item.name
|
||||
else ""
|
||||
)
|
||||
is_code = file_extension in CODE_EXTENSIONS
|
||||
is_doc = file_extension in DOC_EXTENSIONS
|
||||
|
||||
|
||||
if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE:
|
||||
files_list.append({
|
||||
"path": content_item.path,
|
||||
"sha": content_item.sha,
|
||||
"url": content_item.html_url,
|
||||
"size": content_item.size,
|
||||
"type": "code" if is_code else "doc"
|
||||
})
|
||||
files_list.append(
|
||||
{
|
||||
"path": content_item.path,
|
||||
"sha": content_item.sha,
|
||||
"url": content_item.html_url,
|
||||
"size": content_item.size,
|
||||
"type": "code" if is_code else "doc",
|
||||
}
|
||||
)
|
||||
elif content_item.size > MAX_FILE_SIZE:
|
||||
logger.debug(f"Skipping large file: {content_item.path} ({content_item.size} bytes)")
|
||||
logger.debug(
|
||||
f"Skipping large file: {content_item.path} ({content_item.size} bytes)"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Skipping irrelevant file type: {content_item.path}")
|
||||
logger.debug(
|
||||
f"Skipping irrelevant file type: {content_item.path}"
|
||||
)
|
||||
|
||||
except (NotFoundError, ForbiddenError) as e:
|
||||
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
|
||||
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get files for {repo_full_name} at path '{path}': {e}")
|
||||
logger.error(
|
||||
f"Failed to get files for {repo_full_name} at path '{path}': {e}"
|
||||
)
|
||||
# Return what we have collected so far in case of partial failure
|
||||
|
||||
|
||||
return files_list
|
||||
|
||||
def get_file_content(self, repo_full_name: str, file_path: str) -> Optional[str]:
|
||||
def get_file_content(self, repo_full_name: str, file_path: str) -> str | None:
|
||||
"""
|
||||
Fetches the decoded content of a specific file.
|
||||
|
||||
|
|
@ -169,43 +228,69 @@ class GitHubConnector:
|
|||
The decoded file content as a string, or None if fetching fails or file is too large.
|
||||
"""
|
||||
try:
|
||||
owner, repo_name = repo_full_name.split('/')
|
||||
owner, repo_name = repo_full_name.split("/")
|
||||
repo = self.gh.repository(owner, repo_name)
|
||||
if not repo:
|
||||
logger.warning(f"Repository '{repo_full_name}' not found when fetching file '{file_path}'.")
|
||||
logger.warning(
|
||||
f"Repository '{repo_full_name}' not found when fetching file '{file_path}'."
|
||||
)
|
||||
return None
|
||||
|
||||
content_item = repo.file_contents(path=file_path) # Use file_contents for clarity
|
||||
|
||||
if not content_item or not isinstance(content_item, Contents) or content_item.type != 'file':
|
||||
logger.warning(f"File '{file_path}' not found or is not a file in '{repo_full_name}'.")
|
||||
content_item = repo.file_contents(
|
||||
path=file_path
|
||||
) # Use file_contents for clarity
|
||||
|
||||
if (
|
||||
not content_item
|
||||
or not isinstance(content_item, Contents)
|
||||
or content_item.type != "file"
|
||||
):
|
||||
logger.warning(
|
||||
f"File '{file_path}' not found or is not a file in '{repo_full_name}'."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
if content_item.size > MAX_FILE_SIZE:
|
||||
logger.warning(f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch.")
|
||||
logger.warning(
|
||||
f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch."
|
||||
)
|
||||
return None
|
||||
|
||||
# Content is base64 encoded
|
||||
if content_item.content:
|
||||
try:
|
||||
decoded_content = base64.b64decode(content_item.content).decode('utf-8')
|
||||
decoded_content = base64.b64decode(content_item.content).decode(
|
||||
"utf-8"
|
||||
)
|
||||
return decoded_content
|
||||
except UnicodeDecodeError:
|
||||
logger.warning(f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'.")
|
||||
logger.warning(
|
||||
f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'."
|
||||
)
|
||||
try:
|
||||
# Try a fallback encoding
|
||||
decoded_content = base64.b64decode(content_item.content).decode('latin-1')
|
||||
decoded_content = base64.b64decode(content_item.content).decode(
|
||||
"latin-1"
|
||||
)
|
||||
return decoded_content
|
||||
except Exception as decode_err:
|
||||
logger.error(f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}")
|
||||
return None # Give up if fallback fails
|
||||
logger.error(
|
||||
f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}"
|
||||
)
|
||||
return None # Give up if fallback fails
|
||||
else:
|
||||
logger.warning(f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty.")
|
||||
return "" # Return empty string for empty files
|
||||
logger.warning(
|
||||
f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty."
|
||||
)
|
||||
return "" # Return empty string for empty files
|
||||
|
||||
except (NotFoundError, ForbiddenError) as e:
|
||||
logger.warning(f"Cannot access file '{file_path}' in '{repo_full_name}': {e}")
|
||||
return None
|
||||
logger.warning(
|
||||
f"Cannot access file '{file_path}' in '{repo_full_name}': {e}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}")
|
||||
return None
|
||||
logger.error(
|
||||
f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}"
|
||||
)
|
||||
return None
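
A short sketch of how the connector above can be combined: list repositories, then pull the filtered file tree and contents for one of them. The constructor argument and the repository name are assumptions for illustration.

```python
# Sketch: enumerate repositories and pull filtered file contents.
# GitHubConnector(token=...) is assumed from the authentication code above;
# "octocat/Hello-World" is a placeholder repository.
connector = GitHubConnector(token="ghp_example_token")

for repo in connector.get_user_repositories():
    visibility = "private" if repo["private"] else "public"
    print(f"{repo['full_name']} ({visibility})")

files = connector.get_repository_files("octocat/Hello-World")
for entry in files[:5]:
    content = connector.get_file_content("octocat/Hello-World", entry["path"])
    if content is not None:
        print(f"{entry['path']}: {len(content)} characters ({entry['type']})")
```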
@ -7,7 +7,7 @@ Allows fetching issue lists and their comments, projects and more.
|
|||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -17,9 +17,9 @@ class JiraConnector:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
api_token: Optional[str] = None,
|
||||
base_url: str | None = None,
|
||||
email: str | None = None,
|
||||
api_token: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the JiraConnector class.
|
||||
|
|
@ -65,7 +65,7 @@ class JiraConnector:
|
|||
"""
|
||||
self.api_token = api_token
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""
|
||||
Get headers for Jira API requests using Basic Authentication.
|
||||
|
||||
|
|
@ -92,8 +92,8 @@ class JiraConnector:
|
|||
}
|
||||
|
||||
def make_api_request(
|
||||
self, endpoint: str, params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, endpoint: str, params: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Make a request to the Jira API.
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ class JiraConnector:
|
|||
"""
|
||||
return self.make_api_request("project/search")
|
||||
|
||||
def get_all_issues(self, project_key: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
def get_all_issues(self, project_key: str | None = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all issues from Jira.
|
||||
|
||||
|
|
@ -204,8 +204,8 @@ class JiraConnector:
|
|||
start_date: str,
|
||||
end_date: str,
|
||||
include_comments: bool = True,
|
||||
project_key: Optional[str] = None,
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
project_key: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
Fetch issues within a date range.
|
||||
|
||||
|
|
@ -226,9 +226,9 @@ class JiraConnector:
|
|||
)
|
||||
# TODO : This JQL needs some improvement to work as expected
|
||||
|
||||
jql = f"{date_filter}"
|
||||
_jql = f"{date_filter}"
|
||||
if project_key:
|
||||
jql = (
|
||||
_jql = (
|
||||
f'project = "{project_key}" AND {date_filter} ORDER BY created DESC'
|
||||
)
|
||||
|
||||
|
|
@ -250,7 +250,7 @@ class JiraConnector:
|
|||
fields.append("comment")
|
||||
|
||||
params = {
|
||||
# "jql": "", TODO : Add a JQL query to filter from a date range
|
||||
# "jql": "", TODO : Add a JQL query to filter from a date range
|
||||
"fields": ",".join(fields),
|
||||
"maxResults": 100,
|
||||
"startAt": 0,
|
||||
|
|
@ -283,9 +283,9 @@ class JiraConnector:
|
|||
return all_issues, None
|
||||
|
||||
except Exception as e:
|
||||
return [], f"Error fetching issues: {str(e)}"
|
||||
return [], f"Error fetching issues: {e!s}"
|
||||
|
||||
def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format an issue for easier consumption.
|
||||
|
||||
|
|
@ -401,7 +401,7 @@ class JiraConnector:
|
|||
|
||||
return formatted
|
||||
|
||||
def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
|
||||
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
|
||||
"""
|
||||
Convert an issue to markdown format.
@ -5,96 +5,94 @@ A module for retrieving issues and comments from Linear.
|
|||
Allows fetching issue lists and their comments with date range filtering.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class LinearConnector:
|
||||
"""Class for retrieving issues and comments from Linear."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
|
||||
def __init__(self, token: str | None = None):
|
||||
"""
|
||||
Initialize the LinearConnector class.
|
||||
|
||||
|
||||
Args:
|
||||
token: Linear API token (optional, can be set later with set_token)
|
||||
"""
|
||||
self.token = token
|
||||
self.api_url = "https://api.linear.app/graphql"
|
||||
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the Linear API token.
|
||||
|
||||
|
||||
Args:
|
||||
token: Linear API token
|
||||
"""
|
||||
self.token = token
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""
|
||||
Get headers for Linear API requests.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary of headers
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Linear token has been set
|
||||
"""
|
||||
if not self.token:
|
||||
raise ValueError("Linear token not initialized. Call set_token() first.")
|
||||
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': self.token
|
||||
}
|
||||
|
||||
def execute_graphql_query(self, query: str, variables: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
|
||||
return {"Content-Type": "application/json", "Authorization": self.token}
|
||||
|
||||
def execute_graphql_query(
|
||||
self, query: str, variables: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a GraphQL query against the Linear API.
|
||||
|
||||
|
||||
Args:
|
||||
query: GraphQL query string
|
||||
variables: Variables for the GraphQL query (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
Response data from the API
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Linear token has been set
|
||||
Exception: If the API request fails
|
||||
"""
|
||||
if not self.token:
|
||||
raise ValueError("Linear token not initialized. Call set_token() first.")
|
||||
|
||||
|
||||
headers = self.get_headers()
|
||||
payload = {'query': query}
|
||||
|
||||
payload = {"query": query}
|
||||
|
||||
if variables:
|
||||
payload['variables'] = variables
|
||||
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
payload["variables"] = variables
|
||||
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(f"Query failed with status code {response.status_code}: {response.text}")
|
||||
|
||||
def get_all_issues(self, include_comments: bool = True) -> List[Dict[str, Any]]:
|
||||
raise Exception(
|
||||
f"Query failed with status code {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all issues from Linear.
|
||||
|
||||
|
||||
Args:
|
||||
include_comments: Whether to include comments in the response
|
||||
|
||||
|
||||
Returns:
|
||||
List of issue objects
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Linear token has been set
|
||||
Exception: If the API request fails
|
||||
|
|
@ -116,7 +114,7 @@ class LinearConnector:
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
query = f"""
|
||||
query {{
|
||||
issues {{
|
||||
|
|
@ -147,29 +145,30 @@ class LinearConnector:
|
|||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
result = self.execute_graphql_query(query)
|
||||
|
||||
|
||||
# Extract issues from the response
|
||||
if "data" in result and "issues" in result["data"] and "nodes" in result["data"]["issues"]:
|
||||
if (
|
||||
"data" in result
|
||||
and "issues" in result["data"]
|
||||
and "nodes" in result["data"]["issues"]
|
||||
):
|
||||
return result["data"]["issues"]["nodes"]
|
||||
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_issues_by_date_range(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
include_comments: bool = True
|
||||
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
self, start_date: str, end_date: str, include_comments: bool = True
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
Fetch issues within a date range.
|
||||
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format (inclusive)
|
||||
include_comments: Whether to include comments in the response
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple containing (issues list, error message or None)
|
||||
"""
|
||||
|
|
@ -194,7 +193,7 @@ class LinearConnector:
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# Query issues that were either created OR updated within the date range
|
||||
# This ensures we catch both new issues and updated existing issues
|
||||
query = f"""
|
||||
|
|
@ -250,58 +249,65 @@ class LinearConnector:
|
|||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
all_issues = []
|
||||
has_next_page = True
|
||||
cursor = None
|
||||
|
||||
|
||||
# Handle pagination to get all issues
|
||||
while has_next_page:
|
||||
variables = {"after": cursor} if cursor else {}
|
||||
result = self.execute_graphql_query(query, variables)
|
||||
|
||||
|
||||
# Check for errors
|
||||
if "errors" in result:
|
||||
error_message = "; ".join([error.get("message", "Unknown error") for error in result["errors"]])
|
||||
error_message = "; ".join(
|
||||
[
|
||||
error.get("message", "Unknown error")
|
||||
for error in result["errors"]
|
||||
]
|
||||
)
|
||||
return [], f"GraphQL errors: {error_message}"
|
||||
|
||||
|
||||
# Extract issues from the response
|
||||
if "data" in result and "issues" in result["data"]:
|
||||
issues_page = result["data"]["issues"]
|
||||
|
||||
|
||||
# Add issues from this page
|
||||
if "nodes" in issues_page:
|
||||
all_issues.extend(issues_page["nodes"])
|
||||
|
||||
|
||||
# Check if there are more pages
|
||||
if "pageInfo" in issues_page:
|
||||
page_info = issues_page["pageInfo"]
|
||||
has_next_page = page_info.get("hasNextPage", False)
|
||||
cursor = page_info.get("endCursor") if has_next_page else None
|
||||
cursor = (
|
||||
page_info.get("endCursor") if has_next_page else None
|
||||
)
|
||||
else:
|
||||
has_next_page = False
|
||||
else:
|
||||
has_next_page = False
|
||||
|
||||
|
||||
if not all_issues:
|
||||
return [], "No issues found in the specified date range."
|
||||
|
||||
|
||||
return all_issues, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return [], f"Error fetching issues: {str(e)}"
|
||||
|
||||
return [], f"Error fetching issues: {e!s}"
|
||||
|
||||
except ValueError as e:
|
||||
return [], f"Invalid date format: {str(e)}. Please use YYYY-MM-DD."
|
||||
|
||||
def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return [], f"Invalid date format: {e!s}. Please use YYYY-MM-DD."
|
||||
|
||||
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format an issue for easier consumption.
|
||||
|
||||
|
||||
Args:
|
||||
issue: The issue object from Linear API
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted issue dictionary
|
||||
"""
|
||||
|
|
@ -311,23 +317,37 @@ class LinearConnector:
|
|||
"identifier": issue.get("identifier", ""),
|
||||
"title": issue.get("title", ""),
|
||||
"description": issue.get("description", ""),
|
||||
"state": issue.get("state", {}).get("name", "Unknown") if issue.get("state") else "Unknown",
|
||||
"state_type": issue.get("state", {}).get("type", "Unknown") if issue.get("state") else "Unknown",
|
||||
"state": issue.get("state", {}).get("name", "Unknown")
|
||||
if issue.get("state")
|
||||
else "Unknown",
|
||||
"state_type": issue.get("state", {}).get("type", "Unknown")
|
||||
if issue.get("state")
|
||||
else "Unknown",
|
||||
"created_at": issue.get("createdAt", ""),
|
||||
"updated_at": issue.get("updatedAt", ""),
|
||||
"creator": {
|
||||
"id": issue.get("creator", {}).get("id", "") if issue.get("creator") else "",
|
||||
"name": issue.get("creator", {}).get("name", "Unknown") if issue.get("creator") else "Unknown",
|
||||
"email": issue.get("creator", {}).get("email", "") if issue.get("creator") else ""
|
||||
} if issue.get("creator") else {"id": "", "name": "Unknown", "email": ""},
|
||||
"id": issue.get("creator", {}).get("id", "")
|
||||
if issue.get("creator")
|
||||
else "",
|
||||
"name": issue.get("creator", {}).get("name", "Unknown")
|
||||
if issue.get("creator")
|
||||
else "Unknown",
|
||||
"email": issue.get("creator", {}).get("email", "")
|
||||
if issue.get("creator")
|
||||
else "",
|
||||
}
|
||||
if issue.get("creator")
|
||||
else {"id": "", "name": "Unknown", "email": ""},
|
||||
"assignee": {
|
||||
"id": issue.get("assignee", {}).get("id", ""),
|
||||
"name": issue.get("assignee", {}).get("name", "Unknown"),
|
||||
"email": issue.get("assignee", {}).get("email", "")
|
||||
} if issue.get("assignee") else None,
|
||||
"comments": []
|
||||
"email": issue.get("assignee", {}).get("email", ""),
|
||||
}
|
||||
if issue.get("assignee")
|
||||
else None,
|
||||
"comments": [],
|
||||
}
|
||||
|
||||
|
||||
# Extract comments if available
|
||||
if "comments" in issue and "nodes" in issue["comments"]:
|
||||
for comment in issue["comments"]["nodes"]:
|
||||
|
|
@ -337,85 +357,93 @@ class LinearConnector:
|
|||
"created_at": comment.get("createdAt", ""),
|
||||
"updated_at": comment.get("updatedAt", ""),
|
||||
"user": {
|
||||
"id": comment.get("user", {}).get("id", "") if comment.get("user") else "",
|
||||
"name": comment.get("user", {}).get("name", "Unknown") if comment.get("user") else "Unknown",
|
||||
"email": comment.get("user", {}).get("email", "") if comment.get("user") else ""
|
||||
} if comment.get("user") else {"id": "", "name": "Unknown", "email": ""}
|
||||
"id": comment.get("user", {}).get("id", "")
|
||||
if comment.get("user")
|
||||
else "",
|
||||
"name": comment.get("user", {}).get("name", "Unknown")
|
||||
if comment.get("user")
|
||||
else "Unknown",
|
||||
"email": comment.get("user", {}).get("email", "")
|
||||
if comment.get("user")
|
||||
else "",
|
||||
}
|
||||
if comment.get("user")
|
||||
else {"id": "", "name": "Unknown", "email": ""},
|
||||
}
|
||||
formatted["comments"].append(formatted_comment)
|
||||
|
||||
|
||||
return formatted
|
||||
|
||||
def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
|
||||
|
||||
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
|
||||
"""
|
||||
Convert an issue to markdown format.
|
||||
|
||||
|
||||
Args:
|
||||
issue: The issue object (either raw or formatted)
|
||||
|
||||
|
||||
Returns:
|
||||
Markdown string representation of the issue
|
||||
"""
|
||||
# Format the issue if it's not already formatted
|
||||
if "identifier" not in issue:
|
||||
issue = self.format_issue(issue)
|
||||
|
||||
|
||||
# Build the markdown content
|
||||
markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n"
|
||||
|
||||
if issue.get('state'):
|
||||
|
||||
if issue.get("state"):
|
||||
markdown += f"**Status:** {issue['state']}\n\n"
|
||||
|
||||
if issue.get('assignee') and issue['assignee'].get('name'):
|
||||
|
||||
if issue.get("assignee") and issue["assignee"].get("name"):
|
||||
markdown += f"**Assignee:** {issue['assignee']['name']}\n"
|
||||
|
||||
if issue.get('creator') and issue['creator'].get('name'):
|
||||
|
||||
if issue.get("creator") and issue["creator"].get("name"):
|
||||
markdown += f"**Created by:** {issue['creator']['name']}\n"
|
||||
|
||||
if issue.get('created_at'):
|
||||
created_date = self.format_date(issue['created_at'])
|
||||
|
||||
if issue.get("created_at"):
|
||||
created_date = self.format_date(issue["created_at"])
|
||||
markdown += f"**Created:** {created_date}\n"
|
||||
|
||||
if issue.get('updated_at'):
|
||||
updated_date = self.format_date(issue['updated_at'])
|
||||
|
||||
if issue.get("updated_at"):
|
||||
updated_date = self.format_date(issue["updated_at"])
|
||||
markdown += f"**Updated:** {updated_date}\n\n"
|
||||
|
||||
if issue.get('description'):
|
||||
|
||||
if issue.get("description"):
|
||||
markdown += f"## Description\n\n{issue['description']}\n\n"
|
||||
|
||||
if issue.get('comments'):
|
||||
|
||||
if issue.get("comments"):
|
||||
markdown += f"## Comments ({len(issue['comments'])})\n\n"
|
||||
|
||||
for comment in issue['comments']:
|
||||
|
||||
for comment in issue["comments"]:
|
||||
user_name = "Unknown"
|
||||
if comment.get('user') and comment['user'].get('name'):
|
||||
user_name = comment['user']['name']
|
||||
|
||||
if comment.get("user") and comment["user"].get("name"):
|
||||
user_name = comment["user"]["name"]
|
||||
|
||||
comment_date = "Unknown date"
|
||||
if comment.get('created_at'):
|
||||
comment_date = self.format_date(comment['created_at'])
|
||||
|
||||
if comment.get("created_at"):
|
||||
comment_date = self.format_date(comment["created_at"])
|
||||
|
||||
markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
|
||||
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
@staticmethod
|
||||
def format_date(iso_date: str) -> str:
|
||||
"""
|
||||
Format an ISO date string to a more readable format.
|
||||
|
||||
|
||||
Args:
|
||||
iso_date: ISO format date string
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted date string
|
||||
"""
|
||||
if not iso_date or not isinstance(iso_date, str):
|
||||
return "Unknown date"
|
||||
|
||||
|
||||
try:
|
||||
dt = datetime.fromisoformat(iso_date.replace('Z', '+00:00'))
|
||||
return dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
return iso_date
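
The Linear connector mirrors the Jira one: fetch issues for a date window, then render each as markdown. A small sketch using only the methods defined above, with the token as a placeholder.

```python
# Sketch: pull one month of Linear issues and print them as markdown.
connector = LinearConnector(token="lin_api_example_token")  # placeholder token

issues, error = connector.get_issues_by_date_range(
    start_date="2024-01-01", end_date="2024-01-31", include_comments=True
)
if error:
    print(f"Linear fetch failed: {error}")
else:
    for issue in issues:
        print(connector.format_issue_to_markdown(issue))
```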
@ -1,176 +1,182 @@
|
|||
from notion_client import Client
|
||||
|
||||
|
||||
class NotionHistoryConnector:
|
||||
def __init__(self, token):
|
||||
"""
|
||||
Initialize the NotionPageFetcher with a token.
|
||||
|
||||
|
||||
Args:
|
||||
token (str): Notion integration token
|
||||
"""
|
||||
self.notion = Client(auth=token)
|
||||
|
||||
|
||||
def get_all_pages(self, start_date=None, end_date=None):
|
||||
"""
|
||||
Fetches all pages shared with your integration and their content.
|
||||
|
||||
|
||||
Args:
|
||||
start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
|
||||
end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing page data
|
||||
"""
|
||||
# Build the filter for the search
|
||||
# Note: Notion API requires specific filter structure
|
||||
search_params = {}
|
||||
|
||||
|
||||
# Filter for pages only (not databases)
|
||||
search_params["filter"] = {
|
||||
"value": "page",
|
||||
"property": "object"
|
||||
}
|
||||
|
||||
search_params["filter"] = {"value": "page", "property": "object"}
|
||||
|
||||
# Add date filters if provided
|
||||
if start_date or end_date:
|
||||
date_filter = {}
|
||||
|
||||
|
||||
if start_date:
|
||||
date_filter["on_or_after"] = start_date
|
||||
|
||||
|
||||
if end_date:
|
||||
date_filter["on_or_before"] = end_date
|
||||
|
||||
|
||||
# Add the date filter to the search params
|
||||
if date_filter:
|
||||
search_params["sort"] = {
|
||||
"direction": "descending",
|
||||
"timestamp": "last_edited_time"
|
||||
"timestamp": "last_edited_time",
|
||||
}
|
||||
|
||||
|
||||
# First, get a list of all pages the integration has access to
|
||||
search_results = self.notion.search(**search_params)
|
||||
|
||||
|
||||
pages = search_results["results"]
|
||||
all_page_data = []
|
||||
|
||||
|
||||
for page in pages:
|
||||
page_id = page["id"]
|
||||
|
||||
|
||||
# Get detailed page information
|
||||
page_content = self.get_page_content(page_id)
|
||||
|
||||
all_page_data.append({
|
||||
"page_id": page_id,
|
||||
"title": self.get_page_title(page),
|
||||
"content": page_content
|
||||
})
|
||||
|
||||
|
||||
all_page_data.append(
|
||||
{
|
||||
"page_id": page_id,
|
||||
"title": self.get_page_title(page),
|
||||
"content": page_content,
|
||||
}
|
||||
)
|
||||
|
||||
return all_page_data
|
||||
|
||||
|
||||
def get_page_title(self, page):
|
||||
"""
|
||||
Extracts the title from a page object.
|
||||
|
||||
|
||||
Args:
|
||||
page (dict): Notion page object
|
||||
|
||||
|
||||
Returns:
|
||||
str: Page title or a fallback string
|
||||
"""
|
||||
# Title can be in different properties depending on the page type
|
||||
if "properties" in page:
|
||||
# Try to find a title property
|
||||
for prop_name, prop_data in page["properties"].items():
|
||||
for _prop_name, prop_data in page["properties"].items():
|
||||
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
|
||||
return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
|
||||
|
||||
return " ".join(
|
||||
[text_obj["plain_text"] for text_obj in prop_data["title"]]
|
||||
)
|
||||
|
||||
# If no title found, return the page ID as fallback
|
||||
return f"Untitled page ({page['id']})"
|
||||
|
||||
|
||||
def get_page_content(self, page_id):
|
||||
"""
|
||||
Fetches the content (blocks) of a specific page.
|
||||
|
||||
|
||||
Args:
|
||||
page_id (str): The ID of the page to fetch
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of processed blocks from the page
|
||||
"""
|
||||
blocks = []
|
||||
has_more = True
|
||||
cursor = None
|
||||
|
||||
|
||||
# Paginate through all blocks
|
||||
while has_more:
|
||||
if cursor:
|
||||
response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
|
||||
response = self.notion.blocks.children.list(
|
||||
block_id=page_id, start_cursor=cursor
|
||||
)
|
||||
else:
|
||||
response = self.notion.blocks.children.list(block_id=page_id)
|
||||
|
||||
|
||||
blocks.extend(response["results"])
|
||||
has_more = response["has_more"]
|
||||
|
||||
|
||||
if has_more:
|
||||
cursor = response["next_cursor"]
|
||||
|
||||
|
||||
# Process nested blocks recursively
|
||||
processed_blocks = []
|
||||
for block in blocks:
|
||||
processed_block = self.process_block(block)
|
||||
processed_blocks.append(processed_block)
|
||||
|
||||
|
||||
return processed_blocks
|
||||
|
||||
|
||||
def process_block(self, block):
|
||||
"""
|
||||
Processes a block and recursively fetches any child blocks.
|
||||
|
||||
|
||||
Args:
|
||||
block (dict): The block to process
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Processed block with content and children
|
||||
"""
|
||||
block_id = block["id"]
|
||||
block_type = block["type"]
|
||||
|
||||
|
||||
# Extract block content based on its type
|
||||
content = self.extract_block_content(block)
|
||||
|
||||
|
||||
# Check if block has children
|
||||
has_children = block.get("has_children", False)
|
||||
child_blocks = []
|
||||
|
||||
|
||||
if has_children:
|
||||
# Fetch and process child blocks
|
||||
children_response = self.notion.blocks.children.list(block_id=block_id)
|
||||
for child_block in children_response["results"]:
|
||||
child_blocks.append(self.process_block(child_block))
|
||||
|
||||
|
||||
return {
|
||||
"id": block_id,
|
||||
"type": block_type,
|
||||
"content": content,
|
||||
"children": child_blocks
|
||||
"children": child_blocks,
|
||||
}
|
||||
|
||||
|
||||
def extract_block_content(self, block):
|
||||
"""
|
||||
Extracts the content from a block based on its type.
|
||||
|
||||
|
||||
Args:
|
||||
block (dict): The block to extract content from
|
||||
|
||||
|
||||
Returns:
|
||||
str: Extracted content as a string
|
||||
"""
|
||||
block_type = block["type"]
|
||||
|
||||
|
||||
# Different block types have different structures
|
||||
if block_type in block and "rich_text" in block[block_type]:
|
||||
return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
|
||||
return "".join(
|
||||
[text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]
|
||||
)
|
||||
elif block_type == "image":
|
||||
# Instead of returning the raw URL which may contain sensitive AWS credentials,
|
||||
# return a placeholder or reference to the image
|
||||
|
|
@ -183,18 +189,21 @@ class NotionHistoryConnector:
|
|||
# Only return the domain part of external URLs to avoid potential sensitive parameters
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
return f"[External Image from {parsed_url.netloc}]"
|
||||
except:
|
||||
except Exception:
|
||||
return "[External Image]"
|
||||
elif block_type == "code":
|
||||
language = block["code"]["language"]
|
||||
code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
|
||||
code_text = "".join(
|
||||
[text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]
|
||||
)
|
||||
return f"```{language}\n{code_text}\n```"
|
||||
elif block_type == "equation":
|
||||
return block["equation"]["expression"]
|
||||
# Add more block types as needed
|
||||
|
||||
|
||||
# Return empty string for unsupported block types
|
||||
return ""
|
||||
|
||||
|
|
@ -203,23 +212,23 @@ class NotionHistoryConnector:
|
|||
# if __name__ == "__main__":
|
||||
# # Simple example of how to use this module
|
||||
# import argparse
|
||||
|
||||
|
||||
# parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token")
|
||||
# parser.add_argument("--token", help="Your Notion integration token")
|
||||
# parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)")
|
||||
# parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)")
|
||||
# args = parser.parse_args()
|
||||
|
||||
|
||||
# token = args.token
|
||||
# if not token:
|
||||
# token = input("Enter your Notion integration token: ")
|
||||
|
||||
|
||||
# fetcher = NotionPageFetcher(token)
|
||||
|
||||
|
||||
# try:
|
||||
# pages = fetcher.get_all_pages(args.start_date, args.end_date)
|
||||
# print(f"Fetched {len(pages)} pages from Notion")
|
||||
# for page in pages:
|
||||
# print(f"- {page['title']}")
|
||||
# except Exception as e:
|
||||
# print(f"Error: {str(e)}")
|
||||
# print(f"Error: {str(e)}")
@ -5,47 +5,48 @@ A module for retrieving conversation history from Slack channels.
|
|||
Allows fetching channel lists and message history with date range filtering.
|
||||
"""
|
||||
|
||||
import time # Added import
|
||||
import logging # Added import
|
||||
import logging # Added import
|
||||
import time # Added import
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
|
||||
logger = logging.getLogger(__name__) # Added logger
|
||||
logger = logging.getLogger(__name__) # Added logger
|
||||
|
||||
|
||||
class SlackHistory:
|
||||
"""Class for retrieving conversation history from Slack channels."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
|
||||
def __init__(self, token: str | None = None):
|
||||
"""
|
||||
Initialize the SlackHistory class.
|
||||
|
||||
|
||||
Args:
|
||||
token: Slack API token (optional, can be set later with set_token)
|
||||
"""
|
||||
self.client = WebClient(token=token) if token else None
|
||||
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the Slack API token.
|
||||
|
||||
|
||||
Args:
|
||||
token: Slack API token
|
||||
"""
|
||||
self.client = WebClient(token=token)
|
||||
|
||||
def get_all_channels(self, include_private: bool = True) -> List[Dict[str, Any]]:
|
||||
|
||||
def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all channels that the bot has access to, with rate limit handling.
|
||||
|
||||
|
||||
Args:
|
||||
include_private: Whether to include private channels
|
||||
|
||||
|
||||
Returns:
|
||||
List of dictionaries, each representing a channel with id, name, is_private, is_member.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an unrecoverable error calling the Slack API
|
||||
|
|
@ -53,8 +54,8 @@ class SlackHistory:
|
|||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
channels_list = [] # Changed from dict to list
|
||||
|
||||
channels_list = [] # Changed from dict to list
|
||||
types = "public_channel"
|
||||
if include_private:
|
||||
types += ",private_channel"
|
||||
|
|
@ -65,16 +66,16 @@ class SlackHistory:
|
|||
while is_first_request or next_cursor:
|
||||
try:
|
||||
if not is_first_request: # Add delay only for paginated requests
|
||||
logger.info(f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}")
|
||||
logger.info(
|
||||
f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}"
|
||||
)
|
||||
time.sleep(3)
|
||||
|
||||
current_limit = 1000 # Max limit
|
||||
api_result = self.client.conversations_list(
|
||||
types=types,
|
||||
cursor=next_cursor,
|
||||
limit=current_limit
|
||||
types=types, cursor=next_cursor, limit=current_limit
|
||||
)
|
||||
|
||||
|
||||
channels_on_page = api_result["channels"]
|
||||
for channel in channels_on_page:
|
||||
if "name" in channel and "id" in channel:
|
||||
|
|
@ -86,12 +87,13 @@ class SlackHistory:
|
|||
# It indicates if the authenticated user (bot) is a member.
|
||||
# For public channels, this might be true or the API might not focus on it
|
||||
# if the bot can read it anyway. For private, it's crucial.
|
||||
"is_member": channel.get("is_member", False)
|
||||
"is_member": channel.get("is_member", False),
|
||||
}
|
||||
channels_list.append(channel_data)
|
||||
else:
|
||||
logger.warning(f"Channel found with missing name or id. Data: {channel}")
|
||||
|
||||
logger.warning(
|
||||
f"Channel found with missing name or id. Data: {channel}"
|
||||
)
|
||||
|
||||
next_cursor = api_result.get("response_metadata", {}).get("next_cursor")
|
||||
is_first_request = False # Subsequent requests are not the first
|
||||
|
|
@ -101,57 +103,65 @@ class SlackHistory:
|
|||
|
||||
except SlackApiError as e:
|
||||
if e.response is not None and e.response.status_code == 429:
|
||||
retry_after_header = e.response.headers.get('Retry-After')
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
wait_duration = 60 # Default wait time
|
||||
if retry_after_header and retry_after_header.isdigit():
|
||||
wait_duration = int(retry_after_header)
|
||||
|
||||
logger.warning(f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}")
|
||||
|
||||
logger.warning(
|
||||
f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}"
|
||||
)
|
||||
time.sleep(wait_duration)
|
||||
# The loop will continue, retrying with the same cursor
|
||||
else:
|
||||
# Not a 429 error, or no response object, re-raise
|
||||
raise SlackApiError(f"Error retrieving channels: {e}", e.response)
|
||||
raise SlackApiError(
|
||||
f"Error retrieving channels: {e}", e.response
|
||||
) from e
|
||||
except Exception as general_error:
|
||||
# Handle other potential errors like network issues if necessary, or re-raise
|
||||
logger.error(f"An unexpected error occurred during channel fetching: {general_error}")
|
||||
raise RuntimeError(f"An unexpected error occurred during channel fetching: {general_error}")
|
||||
|
||||
logger.error(
|
||||
f"An unexpected error occurred during channel fetching: {general_error}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"An unexpected error occurred during channel fetching: {general_error}"
|
||||
) from general_error
|
||||
|
||||
return channels_list
|
||||
|
||||
|
||||
def get_conversation_history(
|
||||
self,
|
||||
channel_id: str,
|
||||
limit: int = 1000,
|
||||
oldest: Optional[int] = None,
|
||||
latest: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
self,
|
||||
channel_id: str,
|
||||
limit: int = 1000,
|
||||
oldest: int | None = None,
|
||||
latest: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch conversation history for a channel.
|
||||
|
||||
|
||||
Args:
|
||||
channel_id: The ID of the channel to fetch history for
|
||||
limit: Maximum number of messages to return per request (default 1000)
|
||||
oldest: Start of time range (Unix timestamp)
|
||||
latest: End of time range (Unix timestamp)
|
||||
|
||||
|
||||
Returns:
|
||||
List of message objects
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an error calling the Slack API
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
|
||||
messages = []
|
||||
next_cursor = None
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Proactive delay for conversations.history (Tier 3)
|
||||
time.sleep(1.2) # Wait 1.2 seconds before each history call.
|
||||
time.sleep(1.2) # Wait 1.2 seconds before each history call.
|
||||
|
||||
kwargs = {
|
||||
"channel": channel_id,
|
||||
|
|
@@ -163,16 +173,19 @@ class SlackHistory:
                     kwargs["latest"] = latest
                 if next_cursor:
                     kwargs["cursor"] = next_cursor
 
                 current_api_call_successful = False
-                result = None # Ensure result is defined
+                result = None  # Ensure result is defined
                 try:
                     result = self.client.conversations_history(**kwargs)
                     current_api_call_successful = True
                 except SlackApiError as e_history:
-                    if e_history.response is not None and e_history.response.status_code == 429:
-                        retry_after_str = e_history.response.headers.get('Retry-After')
-                        wait_time = 60 # Default
+                    if (
+                        e_history.response is not None
+                        and e_history.response.status_code == 429
+                    ):
+                        retry_after_str = e_history.response.headers.get("Retry-After")
+                        wait_time = 60  # Default
                         if retry_after_str and retry_after_str.isdigit():
                             wait_time = int(retry_after_str)
                         logger.warning(
@@ -182,47 +195,54 @@ class SlackHistory:
                         time.sleep(wait_time)
                         # current_api_call_successful remains False, loop will retry this page
                     else:
-                        raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors
+                        raise  # Re-raise to outer handler for not_in_channel or other SlackApiErrors
 
                 if not current_api_call_successful:
-                    continue # Retry the current page fetch due to handled rate limit
+                    continue  # Retry the current page fetch due to handled rate limit
 
                 # Process result if successful
                 batch = result["messages"]
                 messages.extend(batch)
 
                 if result.get("has_more", False) and len(messages) < limit:
                     next_cursor = result["response_metadata"]["next_cursor"]
                 else:
-                    break # Exit pagination loop
-
-            except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
-                if (e.response is not None and
-                    hasattr(e.response, 'data') and
-                    isinstance(e.response.data, dict) and
-                    e.response.data.get('error') == 'not_in_channel'):
+                    break  # Exit pagination loop
+
+            except SlackApiError as e:  # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
+                if (
+                    e.response is not None
+                    and hasattr(e.response, "data")
+                    and isinstance(e.response.data, dict)
+                    and e.response.data.get("error") == "not_in_channel"
+                ):
                     logger.warning(
                         f"Bot is not in channel '{channel_id}'. Cannot fetch history. "
                         "Please add the bot to this channel."
                     )
-                    return []
+                    return []
                 # For other SlackApiErrors from inner block or this level
-                raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
-            except Exception as general_error: # Catch any other unexpected errors
-                logger.error(f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}")
+                raise SlackApiError(
+                    f"Error retrieving history for channel {channel_id}: {e}",
+                    e.response,
+                ) from e
+            except Exception as general_error:  # Catch any other unexpected errors
+                logger.error(
+                    f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}"
+                )
                 # Re-raise the general error to allow higher-level handling or visibility
-                raise
+                raise general_error from general_error
 
         return messages[:limit]
     @staticmethod
-    def convert_date_to_timestamp(date_str: str) -> Optional[int]:
+    def convert_date_to_timestamp(date_str: str) -> int | None:
         """
         Convert a date string in format YYYY-MM-DD to Unix timestamp.
 
         Args:
             date_str: Date string in YYYY-MM-DD format
 
         Returns:
             Unix timestamp (seconds since epoch) or None if invalid format
         """
@@ -231,67 +251,63 @@ class SlackHistory:
             return int(dt.timestamp())
         except ValueError:
             return None
 
     def get_history_by_date_range(
-        self,
-        channel_id: str,
-        start_date: str,
-        end_date: str,
-        limit: int = 1000
-    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+        self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
+    ) -> tuple[list[dict[str, Any]], str | None]:
         """
         Fetch conversation history within a date range.
 
         Args:
             channel_id: The ID of the channel to fetch history for
             start_date: Start date in YYYY-MM-DD format
             end_date: End date in YYYY-MM-DD format (inclusive)
             limit: Maximum number of messages to return
 
         Returns:
             Tuple containing (messages list, error message or None)
         """
         oldest = self.convert_date_to_timestamp(start_date)
         if not oldest:
-            return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
+            return (
+                [],
+                f"Invalid start date format: {start_date}. Please use YYYY-MM-DD.",
+            )
 
         latest = self.convert_date_to_timestamp(end_date)
         if not latest:
             return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD."
 
         # Add one day to end date to make it inclusive
         latest += 86400 # seconds in a day
 
         try:
             messages = self.get_conversation_history(
-                channel_id=channel_id,
-                limit=limit,
-                oldest=oldest,
-                latest=latest
+                channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
             )
             return messages, None
         except SlackApiError as e:
-            return [], f"Slack API error: {str(e)}"
+            return [], f"Slack API error: {e!s}"
         except ValueError as e:
             return [], str(e)
 
-    def get_user_info(self, user_id: str) -> Dict[str, Any]:
+    def get_user_info(self, user_id: str) -> dict[str, Any]:
         """
         Get information about a user.
 
         Args:
             user_id: The ID of the user to get info for
 
         Returns:
             User information dictionary
 
         Raises:
             ValueError: If no Slack client has been initialized
             SlackApiError: If there's an error calling the Slack API
         """
         if not self.client:
             raise ValueError("Slack client not initialized. Call set_token() first.")
 
         while True:
             try:
                 # Proactive delay for users.info (Tier 4) - generally not needed unless called extremely rapidly.
@@ -299,46 +315,60 @@ class SlackHistory:
                 # time.sleep(0.6) # Optional: ~100 req/min if ever needed.
 
                 result = self.client.users_info(user=user_id)
-                return result["user"] # Success, return and exit loop implicitly
+                return result["user"]  # Success, return and exit loop implicitly
 
             except SlackApiError as e_user_info:
-                if e_user_info.response is not None and e_user_info.response.status_code == 429:
-                    retry_after_str = e_user_info.response.headers.get('Retry-After')
+                if (
+                    e_user_info.response is not None
+                    and e_user_info.response.status_code == 429
+                ):
+                    retry_after_str = e_user_info.response.headers.get("Retry-After")
                     wait_time = 30 # Default for Tier 4, can be adjusted
                     if retry_after_str and retry_after_str.isdigit():
                         wait_time = int(retry_after_str)
-                    logger.warning(f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds.")
+                    logger.warning(
+                        f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds."
+                    )
                     time.sleep(wait_time)
                     continue # Retry the API call
                 else:
                     # Not a 429 error, or no response object, re-raise
-                    raise SlackApiError(f"Error retrieving user info for {user_id}: {e_user_info}", e_user_info.response)
-            except Exception as general_error: # Catch any other unexpected errors
-                logger.error(f"Unexpected error in get_user_info for user {user_id}: {general_error}")
-                raise # Re-raise unexpected errors
-
-    def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
+                    raise SlackApiError(
+                        f"Error retrieving user info for {user_id}: {e_user_info}",
+                        e_user_info.response,
+                    ) from e_user_info
+            except Exception as general_error:  # Catch any other unexpected errors
+                logger.error(
+                    f"Unexpected error in get_user_info for user {user_id}: {general_error}"
+                )
+                raise general_error from general_error  # Re-raise unexpected errors
+
+    def format_message(
+        self, msg: dict[str, Any], include_user_info: bool = False
+    ) -> dict[str, Any]:
         """
         Format a message for easier consumption.
 
         Args:
             msg: The message object from Slack API
             include_user_info: Whether to fetch and include user info
 
         Returns:
             Formatted message dictionary
         """
         formatted = {
             "text": msg.get("text", ""),
             "timestamp": msg.get("ts"),
-            "datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
+            "datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime(
+                "%Y-%m-%d %H:%M:%S"
+            ),
             "user_id": msg.get("user", "UNKNOWN"),
             "has_attachments": bool(msg.get("attachments")),
             "has_files": bool(msg.get("files")),
             "thread_ts": msg.get("thread_ts"),
             "is_thread": "thread_ts" in msg,
         }
 
         if include_user_info and "user" in msg and self.client:
             try:
                 user_info = self.get_user_info(msg["user"])
@@ -347,7 +377,7 @@ class SlackHistory:
             except Exception:
                 # If we can't get user info, just continue without it
                 formatted["user_name"] = "Unknown"
 
         return formatted
 
@@ -388,4 +418,4 @@ if __name__ == "__main__":
 
     except Exception as e:
         print(f"Error: {e}")
-"""
+"""
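Several hunks above reformat the same rate-limit handling in slack_history.py: catch a 429 SlackApiError, honor the Retry-After header with a fixed fallback, sleep, and retry the same call. A standalone sketch of that pattern follows; the function name, defaults, and the generic except clause are illustrative and not part of the patch:

    import logging
    import time

    logger = logging.getLogger(__name__)


    def call_with_retry_after(api_call, default_wait=60, max_attempts=5):
        """Retry a Slack-style API call when it is rate limited (HTTP 429)."""
        for _ in range(max_attempts):
            try:
                return api_call()
            except Exception as e:  # the real code catches SlackApiError here
                response = getattr(e, "response", None)
                if response is not None and getattr(response, "status_code", None) == 429:
                    retry_after = response.headers.get("Retry-After", "")
                    wait = int(retry_after) if retry_after.isdigit() else default_wait
                    logger.warning("Rate limited, waiting %s seconds before retrying", wait)
                    time.sleep(wait)
                    continue  # retry the same call with the same arguments
                raise  # anything other than a 429 propagates unchanged
        raise RuntimeError("Gave up after repeated rate limits")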
@@ -1,23 +1,24 @@
 import unittest
-from unittest.mock import patch, Mock
 from datetime import datetime
+from unittest.mock import Mock, patch
+
+from github3.exceptions import ForbiddenError  # Import the specific exception
 
 # Adjust the import path based on the actual location if test_github_connector.py
 # is not in the same directory as github_connector.py or if paths are set up differently.
 # Assuming surfsend_backend/app/connectors/test_github_connector.py
 from surfsense_backend.app.connectors.github_connector import GitHubConnector
-from github3.exceptions import ForbiddenError # Import the specific exception
 
 
 class TestGitHubConnector(unittest.TestCase):
-
-    @patch('surfsense_backend.app.connectors.github_connector.github_login')
+    @patch("surfsense_backend.app.connectors.github_connector.github_login")
     def test_get_user_repositories_uses_type_all(self, mock_github_login):
         # Mock the GitHub client object and its methods
         mock_gh_instance = Mock()
         mock_github_login.return_value = mock_gh_instance
 
         # Mock the self.gh.me() call in __init__ to prevent an actual API call
-        mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization
+        mock_gh_instance.me.return_value = Mock()  # Simple mock to pass initialization
 
         # Prepare mock repository data
         mock_repo1_data = Mock()
@ -27,7 +28,9 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
mock_repo1_data.private = False
|
||||
mock_repo1_data.html_url = "http://example.com/user/repo1"
|
||||
mock_repo1_data.description = "Test repo 1"
|
||||
mock_repo1_data.updated_at = datetime(2023, 1, 1, 10, 30, 0) # Added time component
|
||||
mock_repo1_data.updated_at = datetime(
|
||||
2023, 1, 1, 10, 30, 0
|
||||
) # Added time component
|
||||
|
||||
mock_repo2_data = Mock()
|
||||
mock_repo2_data.id = 2
|
||||
|
|
@ -36,8 +39,10 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
mock_repo2_data.private = True
|
||||
mock_repo2_data.html_url = "http://example.com/org/org-repo"
|
||||
mock_repo2_data.description = "Org repo"
|
||||
mock_repo2_data.updated_at = datetime(2023, 1, 2, 12, 0, 0) # Added time component
|
||||
|
||||
mock_repo2_data.updated_at = datetime(
|
||||
2023, 1, 2, 12, 0, 0
|
||||
) # Added time component
|
||||
|
||||
# Configure the mock for gh.repositories() call
|
||||
# This method is an iterator, so it should return an iterable (e.g., a list)
|
||||
mock_gh_instance.repositories.return_value = [mock_repo1_data, mock_repo2_data]
|
||||
|
|
@ -46,26 +51,38 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
repositories = connector.get_user_repositories()
|
||||
|
||||
# Assert that gh.repositories was called correctly
|
||||
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
|
||||
mock_gh_instance.repositories.assert_called_once_with(
|
||||
type="all", sort="updated"
|
||||
)
|
||||
|
||||
# Assert the structure and content of the returned data
|
||||
expected_repositories = [
|
||||
{
|
||||
"id": 1, "name": "repo1", "full_name": "user/repo1", "private": False,
|
||||
"url": "http://example.com/user/repo1", "description": "Test repo 1",
|
||||
"last_updated": datetime(2023, 1, 1, 10, 30, 0)
|
||||
"id": 1,
|
||||
"name": "repo1",
|
||||
"full_name": "user/repo1",
|
||||
"private": False,
|
||||
"url": "http://example.com/user/repo1",
|
||||
"description": "Test repo 1",
|
||||
"last_updated": datetime(2023, 1, 1, 10, 30, 0),
|
||||
},
|
||||
{
|
||||
"id": 2, "name": "org-repo", "full_name": "org/org-repo", "private": True,
|
||||
"url": "http://example.com/org/org-repo", "description": "Org repo",
|
||||
"last_updated": datetime(2023, 1, 2, 12, 0, 0)
|
||||
}
|
||||
"id": 2,
|
||||
"name": "org-repo",
|
||||
"full_name": "org/org-repo",
|
||||
"private": True,
|
||||
"url": "http://example.com/org/org-repo",
|
||||
"description": "Org repo",
|
||||
"last_updated": datetime(2023, 1, 2, 12, 0, 0),
|
||||
},
|
||||
]
|
||||
self.assertEqual(repositories, expected_repositories)
|
||||
self.assertEqual(len(repositories), 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
def test_get_user_repositories_handles_empty_description_and_none_updated_at(self, mock_github_login):
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_get_user_repositories_handles_empty_description_and_none_updated_at(
|
||||
self, mock_github_login
|
||||
):
|
||||
# Mock the GitHub client object and its methods
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
|
@ -77,61 +94,73 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
mock_repo_data.full_name = "user/repo_no_desc"
|
||||
mock_repo_data.private = False
|
||||
mock_repo_data.html_url = "http://example.com/user/repo_no_desc"
|
||||
mock_repo_data.description = None # Test None description
|
||||
mock_repo_data.updated_at = None # Test None updated_at
|
||||
mock_repo_data.description = None # Test None description
|
||||
mock_repo_data.updated_at = None # Test None updated_at
|
||||
|
||||
mock_gh_instance.repositories.return_value = [mock_repo_data]
|
||||
connector = GitHubConnector(token="fake_token")
|
||||
repositories = connector.get_user_repositories()
|
||||
|
||||
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
|
||||
mock_gh_instance.repositories.assert_called_once_with(
|
||||
type="all", sort="updated"
|
||||
)
|
||||
expected_repositories = [
|
||||
{
|
||||
"id": 1, "name": "repo_no_desc", "full_name": "user/repo_no_desc", "private": False,
|
||||
"url": "http://example.com/user/repo_no_desc", "description": "", # Expect empty string
|
||||
"last_updated": None # Expect None
|
||||
"id": 1,
|
||||
"name": "repo_no_desc",
|
||||
"full_name": "user/repo_no_desc",
|
||||
"private": False,
|
||||
"url": "http://example.com/user/repo_no_desc",
|
||||
"description": "", # Expect empty string
|
||||
"last_updated": None, # Expect None
|
||||
}
|
||||
]
|
||||
self.assertEqual(repositories, expected_repositories)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_github_connector_initialization_failure_forbidden(self, mock_github_login):
|
||||
# Test that __init__ raises ValueError on auth failure (ForbiddenError)
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
||||
|
||||
# Create a mock response object for the ForbiddenError
|
||||
# The actual response structure might vary, but github3.py's ForbiddenError
|
||||
# can be instantiated with just a response object that has a status_code.
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 403 # Typically Forbidden
|
||||
|
||||
mock_response.status_code = 403 # Typically Forbidden
|
||||
|
||||
# Setup the side_effect for self.gh.me()
|
||||
mock_gh_instance.me.side_effect = ForbiddenError(mock_response)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
GitHubConnector(token="invalid_token_forbidden")
|
||||
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
|
||||
self.assertIn(
|
||||
"Invalid GitHub token or insufficient permissions.", str(context.exception)
|
||||
)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
def test_github_connector_initialization_failure_authentication_failed(self, mock_github_login):
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_github_connector_initialization_failure_authentication_failed(
|
||||
self, mock_github_login
|
||||
):
|
||||
# Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError)
|
||||
# For github3.py, AuthenticationFailed is more specific for token issues.
|
||||
from github3.exceptions import AuthenticationFailed
|
||||
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401 # Typically Unauthorized
|
||||
|
||||
mock_response.status_code = 401 # Typically Unauthorized
|
||||
|
||||
mock_gh_instance.me.side_effect = AuthenticationFailed(mock_response)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
GitHubConnector(token="invalid_token_authfailed")
|
||||
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
self.assertIn(
|
||||
"Invalid GitHub token or insufficient permissions.", str(context.exception)
|
||||
)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_get_user_repositories_handles_api_exception(self, mock_github_login):
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
|
@@ -142,13 +171,18 @@ class TestGitHubConnector(unittest.TestCase):
 
         connector = GitHubConnector(token="fake_token")
         # We expect it to log an error and return an empty list
-        with patch('surfsense_backend.app.connectors.github_connector.logger') as mock_logger:
+        with patch(
+            "surfsense_backend.app.connectors.github_connector.logger"
+        ) as mock_logger:
             repositories = connector.get_user_repositories()
 
         self.assertEqual(repositories, [])
         mock_logger.error.assert_called_once()
-        self.assertIn("Failed to fetch GitHub repositories: API Error", mock_logger.error.call_args[0][0])
+        self.assertIn(
+            "Failed to fetch GitHub repositories: API Error",
+            mock_logger.error.call_args[0][0],
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
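The connector tests in this commit are plain unittest suites, so no extra tooling is needed to run them. A minimal runner sketch; the start directory is an assumption about the repository layout, not something stated in the patch:

    import unittest

    # Discover and run the connector tests; adjust the start directory to the actual checkout path.
    suite = unittest.defaultTestLoader.discover(
        "surfsense_backend/app/connectors", pattern="test_*.py"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)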
@@ -1,373 +1,448 @@
 import unittest
 import time # Imported to be available for patching target module
-from unittest.mock import patch, Mock, call
+from unittest.mock import Mock, call, patch
 
 from slack_sdk.errors import SlackApiError
 
 # Since test_slack_history.py is in the same directory as slack_history.py
 from .slack_history import SlackHistory
 
-class TestSlackHistoryGetAllChannels(unittest.TestCase):
-
-    @patch('surfsense_backend.app.connectors.slack_history.logger')
-    @patch('surfsense_backend.app.connectors.slack_history.time.sleep')
-    @patch('slack_sdk.WebClient')
-    def test_get_all_channels_pagination_with_delay(self, MockWebClient, mock_sleep, mock_logger):
-        mock_client_instance = MockWebClient.return_value
+
+class TestSlackHistoryGetAllChannels(unittest.TestCase):
+    @patch("surfsense_backend.app.connectors.slack_history.logger")
+    @patch("surfsense_backend.app.connectors.slack_history.time.sleep")
+    @patch("slack_sdk.WebClient")
+    def test_get_all_channels_pagination_with_delay(
+        self, mock_web_client, mock_sleep, mock_logger
+    ):
+        mock_client_instance = mock_web_client.return_value
# Mock API responses now include is_private and is_member
|
||||
page1_response = {
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
|
||||
{"name": "dev", "id": "C0", "is_private": False, "is_member": True}
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
|
||||
{"name": "dev", "id": "C0", "is_private": False, "is_member": True},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor123"}
|
||||
"response_metadata": {"next_cursor": "cursor123"},
|
||||
}
|
||||
page2_response = {
|
||||
"channels": [{"name": "random", "id": "C2", "is_private": True, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"channels": [
|
||||
{"name": "random", "id": "C2", "is_private": True, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
page1_response,
|
||||
page2_response
|
||||
page2_response,
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True},
|
||||
{"id": "C0", "name": "dev", "is_private": False, "is_member": True},
|
||||
{"id": "C2", "name": "random", "is_private": True, "is_member": True}
|
||||
{"id": "C2", "name": "random", "is_private": True, "is_member": True},
|
||||
]
|
||||
|
||||
|
||||
self.assertEqual(len(channels_list), 3)
|
||||
self.assertListEqual(channels_list, expected_channels_list) # Assert list equality
|
||||
|
||||
self.assertListEqual(
|
||||
channels_list, expected_channels_list
|
||||
) # Assert list equality
|
||||
|
||||
expected_calls = [
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
call(types="public_channel,private_channel", cursor="cursor123", limit=1000)
|
||||
call(
|
||||
types="public_channel,private_channel", cursor="cursor123", limit=1000
|
||||
),
|
||||
]
|
||||
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
mock_sleep.assert_called_once_with(3)
|
||||
mock_logger.info.assert_called_once_with("Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123")
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_rate_limit_with_retry_after(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_sleep.assert_called_once_with(3)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123"
|
||||
)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_rate_limit_with_retry_after(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': '5'}
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "5"}
|
||||
|
||||
successful_response = {
|
||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response
|
||||
successful_response,
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertEqual(len(channels_list), 1)
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None")
|
||||
|
||||
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None"
|
||||
)
|
||||
|
||||
expected_calls = [
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000)
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
]
|
||||
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_rate_limit_no_retry_after_valid_header(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_rate_limit_no_retry_after_valid_header(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': 'invalid_value'}
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "invalid_value"}
|
||||
|
||||
successful_response = {
|
||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response
|
||||
successful_response,
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
|
||||
)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_rate_limit_no_retry_after_header(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_rate_limit_no_retry_after_header(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {}
|
||||
|
||||
successful_response = {
|
||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
}
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response
|
||||
]
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_other_slack_api_error(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500
|
||||
mock_error_response.headers = {}
|
||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||
|
||||
original_error = SlackApiError(message="server error", response=mock_error_response)
|
||||
mock_client_instance.conversations_list.side_effect = original_error
|
||||
|
||||
|
||||
successful_response = {
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response,
|
||||
]
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
|
||||
)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_other_slack_api_error(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500
|
||||
mock_error_response.headers = {}
|
||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||
|
||||
original_error = SlackApiError(
|
||||
message="server error", response=mock_error_response
|
||||
)
|
||||
mock_client_instance.conversations_list.side_effect = original_error
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
with self.assertRaises(SlackApiError) as context:
|
||||
slack_history.get_all_channels(include_private=True)
|
||||
|
||||
|
||||
self.assertEqual(context.exception.response.status_code, 500)
|
||||
self.assertIn("server error", str(context.exception))
|
||||
mock_sleep.assert_not_called()
|
||||
mock_logger.warning.assert_not_called() # Ensure no rate limit log
|
||||
mock_logger.warning.assert_not_called() # Ensure no rate limit log
|
||||
mock_client_instance.conversations_list.assert_called_once_with(
|
||||
types="public_channel,private_channel", cursor=None, limit=1000
|
||||
)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_handles_missing_name_id_gracefully(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_handles_missing_name_id_gracefully(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
response_with_malformed_data = {
|
||||
"channels": [
|
||||
{"id": "C1_missing_name", "is_private": False, "is_member": True},
|
||||
{"id": "C1_missing_name", "is_private": False, "is_member": True},
|
||||
{"name": "channel_missing_id", "is_private": False, "is_member": True},
|
||||
{"name": "general", "id": "C2_valid", "is_private": False, "is_member": True}
|
||||
{
|
||||
"name": "general",
|
||||
"id": "C2_valid",
|
||||
"is_private": False,
|
||||
"is_member": True,
|
||||
},
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
mock_client_instance.conversations_list.return_value = response_with_malformed_data
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.return_value = (
|
||||
response_with_malformed_data
|
||||
)
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C2_valid", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertEqual(len(channels_list), 1)
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
|
||||
self.assertEqual(mock_logger.warning.call_count, 2)
|
||||
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}")
|
||||
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}")
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
expected_channels_list = [
|
||||
{
|
||||
"id": "C2_valid",
|
||||
"name": "general",
|
||||
"is_private": False,
|
||||
"is_member": True,
|
||||
}
|
||||
]
|
||||
self.assertEqual(len(channels_list), 1)
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
|
||||
self.assertEqual(mock_logger.warning.call_count, 2)
|
||||
mock_logger.warning.assert_any_call(
|
||||
"Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}"
|
||||
)
|
||||
mock_logger.warning.assert_any_call(
|
||||
"Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}"
|
||||
)
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
mock_client_instance.conversations_list.assert_called_once_with(
|
||||
types="public_channel,private_channel", cursor=None, limit=1000
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_proactive_delay_single_page(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_proactive_delay_single_page(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
mock_client_instance.conversations_history.return_value = {
|
||||
"messages": [{"text": "msg1"}],
|
||||
"has_more": False
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_proactive_delay_multiple_pages(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_proactive_delay_multiple_pages(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
mock_client_instance.conversations_history.side_effect = [
|
||||
{
|
||||
"messages": [{"text": "msg1"}],
|
||||
"has_more": True,
|
||||
"response_metadata": {"next_cursor": "cursor1"}
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"messages": [{"text": "msg2"}],
|
||||
"has_more": False
|
||||
}
|
||||
{"messages": [{"text": "msg2"}], "has_more": False},
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
|
||||
# Expected calls: 1.2 (page1), 1.2 (page2)
|
||||
self.assertEqual(mock_time_sleep.call_count, 2)
|
||||
mock_time_sleep.assert_has_calls([call(1.2), call(1.2)])
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': '5'}
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "5"}
|
||||
|
||||
mock_client_instance.conversations_history.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
{"messages": [{"text": "msg1"}], "has_more": False}
|
||||
{"messages": [{"text": "msg1"}], "has_more": False},
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
messages = slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
|
||||
self.assertEqual(len(messages), 1)
|
||||
self.assertEqual(messages[0]["text"], "msg1")
|
||||
|
||||
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
|
||||
mock_time_sleep.assert_has_calls([call(1.2), call(5), call(1.2)], any_order=False)
|
||||
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_not_in_channel_error(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
|
||||
mock_time_sleep.assert_has_calls(
|
||||
[call(1.2), call(5), call(1.2)], any_order=False
|
||||
)
|
||||
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_not_in_channel_error(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 403 # Typical for not_in_channel, but data matters more
|
||||
mock_error_response.data = {'ok': False, 'error': 'not_in_channel'}
|
||||
|
||||
mock_error_response.status_code = (
|
||||
403 # Typical for not_in_channel, but data matters more
|
||||
)
|
||||
mock_error_response.data = {"ok": False, "error": "not_in_channel"}
|
||||
|
||||
# This error is now raised by the inner try-except, then caught by the outer one
|
||||
mock_client_instance.conversations_history.side_effect = SlackApiError(
|
||||
message="not_in_channel error",
|
||||
response=mock_error_response
|
||||
message="not_in_channel error", response=mock_error_response
|
||||
)
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
messages = slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
|
||||
self.assertEqual(messages, [])
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel."
|
||||
)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay before the API call
|
||||
mock_time_sleep.assert_called_once_with(
|
||||
1.2
|
||||
) # Proactive delay before the API call
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_other_slack_api_error_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500
|
||||
mock_error_response.data = {'ok': False, 'error': 'internal_error'}
|
||||
original_error = SlackApiError(message="server error", response=mock_error_response)
|
||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||
original_error = SlackApiError(
|
||||
message="server error", response=mock_error_response
|
||||
)
|
||||
|
||||
mock_client_instance.conversations_history.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
|
||||
with self.assertRaises(SlackApiError) as context:
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
self.assertIn("Error retrieving history for channel C123", str(context.exception))
|
||||
self.assertIs(context.exception.response, mock_error_response)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
self.assertIn(
|
||||
"Error retrieving history for channel C123", str(context.exception)
|
||||
)
|
||||
self.assertIs(context.exception.response, mock_error_response)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_general_exception_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
original_error = Exception("Something broke")
|
||||
mock_client_instance.conversations_history.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
with self.assertRaises(Exception) as context: # Check for generic Exception
|
||||
|
||||
with self.assertRaises(Exception) as context: # Check for generic Exception
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
self.assertIs(context.exception, original_error) # Should re-raise the original error
|
||||
mock_logger.error.assert_called_once_with("Unexpected error in get_conversation_history for channel C123: Something broke")
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
self.assertIs(
|
||||
context.exception, original_error
|
||||
) # Should re-raise the original error
|
||||
mock_logger.error.assert_called_once_with(
|
||||
"Unexpected error in get_conversation_history for channel C123: Something broke"
|
||||
)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
|
||||
class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': '3'} # Using 3 seconds for test
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "3"} # Using 3 seconds for test
|
||||
|
||||
successful_user_data = {"id": "U123", "name": "testuser"}
|
||||
|
||||
|
||||
mock_client_instance.users_info.side_effect = [
|
||||
SlackApiError(message="ratelimited_userinfo", response=mock_error_response),
|
||||
{"user": successful_user_data}
|
||||
{"user": successful_user_data},
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
user_info = slack_history.get_user_info(user_id="U123")
|
||||
|
||||
|
||||
self.assertEqual(user_info, successful_user_data)
|
||||
|
||||
|
||||
# Assert that time.sleep was called for the rate limit
|
||||
mock_time_sleep.assert_called_once_with(3)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
|
|
@ -375,46 +450,58 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
|||
)
|
||||
# Assert users_info was called twice (original + retry)
|
||||
self.assertEqual(mock_client_instance.users_info.call_count, 2)
|
||||
mock_client_instance.users_info.assert_has_calls([call(user="U123"), call(user="U123")])
|
||||
mock_client_instance.users_info.assert_has_calls(
|
||||
[call(user="U123"), call(user="U123")]
|
||||
)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch(
|
||||
"surfsense_backend.app.connectors.slack_history.time.sleep"
|
||||
) # time.sleep might be called by other logic, but not expected here
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_other_slack_api_error_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep') # time.sleep might be called by other logic, but not expected here
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500 # Some other error
|
||||
mock_error_response.data = {'ok': False, 'error': 'internal_server_error'}
|
||||
original_error = SlackApiError(message="internal server error", response=mock_error_response)
|
||||
mock_error_response.status_code = 500 # Some other error
|
||||
mock_error_response.data = {"ok": False, "error": "internal_server_error"}
|
||||
original_error = SlackApiError(
|
||||
message="internal server error", response=mock_error_response
|
||||
)
|
||||
|
||||
mock_client_instance.users_info.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
|
||||
with self.assertRaises(SlackApiError) as context:
|
||||
slack_history.get_user_info(user_id="U123")
|
||||
|
||||
|
||||
# Check that the raised error is the one we expect
|
||||
self.assertIn("Error retrieving user info for U123", str(context.exception))
|
||||
self.assertIs(context.exception.response, mock_error_response)
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_general_exception_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
original_error = Exception("A very generic problem")
|
||||
mock_client_instance.users_info.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
|
||||
with self.assertRaises(Exception) as context:
|
||||
slack_history.get_user_info(user_id="U123")
|
||||
|
||||
self.assertIs(context.exception, original_error) # Check it's the exact same exception
|
||||
|
||||
self.assertIs(
|
||||
context.exception, original_error
|
||||
) # Check it's the exact same exception
|
||||
mock_logger.error.assert_called_once_with(
|
||||
"Unexpected error in get_user_info for user U123: A very generic problem"
|
||||
)
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
|
|
|
|||
|
|
@@ -1,11 +1,9 @@
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from datetime import UTC, datetime
from enum import Enum

from app.config import config
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
    ARRAY,
@@ -13,9 +11,7 @@ from sqlalchemy import (
    TIMESTAMP,
    Boolean,
    Column,
)
from sqlalchemy import Enum as SQLAlchemyEnum
from sqlalchemy import (
    Enum as SQLAlchemyEnum,
    ForeignKey,
    Integer,
    String,
@@ -26,17 +22,12 @@ from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship

from app.config import config
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever

if config.AUTH_TYPE == "GOOGLE":
    from fastapi_users.db import (
        SQLAlchemyBaseOAuthAccountTableUUID,
        SQLAlchemyBaseUserTableUUID,
        SQLAlchemyUserDatabase,
    )
else:
    from fastapi_users.db import (
        SQLAlchemyBaseUserTableUUID,
        SQLAlchemyUserDatabase,
    )
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID

DATABASE_URL = config.DATABASE_URL
@@ -118,11 +109,11 @@ class Base(DeclarativeBase):
 
 class TimestampMixin:
     @declared_attr
-    def created_at(cls):
+    def created_at(cls):  # noqa: N805
         return Column(
             TIMESTAMP(timezone=True),
             nullable=False,
-            default=lambda: datetime.now(timezone.utc),
+            default=lambda: datetime.now(UTC),
             index=True,
         )
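The created_at default now calls datetime.now(UTC); UTC is an alias for timezone.utc introduced in Python 3.11, so the stored timestamps are unchanged. A quick check, assuming Python 3.11 or newer:

    from datetime import UTC, datetime, timezone

    # UTC is the same object as timezone.utc, so both defaults are equivalent.
    assert UTC is timezone.utc
    print(datetime.now(UTC).isoformat())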
@@ -1,9 +1,12 @@
+from datetime import UTC, datetime
+
 from langchain_core.prompts.prompt import PromptTemplate
-from datetime import datetime, timezone
 
-DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
+DATE_TODAY = "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n"
 
-SUMMARY_PROMPT = DATE_TODAY + """
+SUMMARY_PROMPT = (
+    DATE_TODAY
+    + """
 <INSTRUCTIONS>
 <context>
 You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
@@ -96,8 +99,8 @@ SUMMARY_PROMPT = DATE_TODAY + """
 </document_to_summarize>
 </INSTRUCTIONS>
 """
+)
 
 SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
-    input_variables=["document"],
-    template=SUMMARY_PROMPT
-)
+    input_variables=["document"], template=SUMMARY_PROMPT
+)
|
||||
|
|
|
|||
|
|
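# Illustrative sketch (not part of this commit): how the reformatted template is
# rendered. The prompt body below is a stand-in; only the {document} placeholder
# and the PromptTemplate wiring mirror the real module.
from datetime import UTC, datetime

from langchain_core.prompts.prompt import PromptTemplate

EXAMPLE_DATE_TODAY = (
    "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n"
)
EXAMPLE_SUMMARY_PROMPT = (
    EXAMPLE_DATE_TODAY + "Summarize the following document:\n{document}"
)
EXAMPLE_TEMPLATE = PromptTemplate(
    input_variables=["document"], template=EXAMPLE_SUMMARY_PROMPT
)

print(EXAMPLE_TEMPLATE.format(document="An example document body."))
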
@@ -2,34 +2,41 @@ class ChucksHybridSearchRetriever:
|
|||
def __init__(self, db_session):
|
||||
"""
|
||||
Initialize the hybrid search retriever with a database session.
|
||||
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
||||
"""
|
||||
self.db_session = db_session
|
||||
|
||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def vector_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform vector similarity search on chunks.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by vector similarity
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
|
|
@@ -38,45 +45,48 @@ class ChucksHybridSearchRetriever:
|
|||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add vector similarity ordering
|
||||
query = (
|
||||
query
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
|
||||
return chunks
|
||||
|
||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def full_text_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform full-text keyword search on chunks.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by text relevance
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
|
|
@@ -84,64 +94,70 @@ class ChucksHybridSearchRetriever:
|
|||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
||||
.where(
|
||||
tsvector.op("@@")(tsquery)
|
||||
) # Only include results that match the query
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add text search ranking
|
||||
query = (
|
||||
query
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
|
||||
return chunks
|
||||
|
||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
document_type: str | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and relevance scores
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace, DocumentType
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Chunk, Document, DocumentType, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Base conditions for document filtering
|
||||
base_conditions = [SearchSpace.user_id == user_id]
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
base_conditions.append(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add document type filter if provided
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
|
|
@@ -154,90 +170,97 @@ class ChucksHybridSearchRetriever:
|
|||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
|
||||
# CTE for semantic search with user ownership check
|
||||
semantic_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=Chunk.embedding.op("<=>")(query_embedding))
|
||||
.label("rank"),
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
|
||||
|
||||
semantic_search_cte = (
|
||||
semantic_search_cte
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(n_results)
|
||||
.cte("semantic_search")
|
||||
)
|
||||
|
||||
|
||||
# CTE for keyword search with user ownership check
|
||||
keyword_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.label("rank"),
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
.where(tsvector.op("@@")(tsquery))
|
||||
)
|
||||
|
||||
|
||||
keyword_search_cte = (
|
||||
keyword_search_cte
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(n_results)
|
||||
.cte("keyword_search")
|
||||
)
|
||||
|
||||
|
||||
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
||||
final_query = (
|
||||
select(
|
||||
Chunk,
|
||||
(
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score")
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
||||
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score"),
|
||||
)
|
||||
.select_from(
|
||||
semantic_search_cte.outerjoin(
|
||||
keyword_search_cte,
|
||||
keyword_search_cte,
|
||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||
full=True
|
||||
full=True,
|
||||
)
|
||||
)
|
||||
.join(
|
||||
Chunk,
|
||||
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
||||
Chunk.id
|
||||
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
||||
)
|
||||
.options(joinedload(Chunk.document))
|
||||
.order_by(text("score DESC"))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(final_query)
|
||||
chunks_with_scores = result.all()
|
||||
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not chunks_with_scores:
|
||||
return []
|
||||
|
||||
|
||||
# Convert to serializable dictionaries if no reranker is available or if reranking failed
|
||||
serialized_results = []
|
||||
for chunk, score in chunks_with_scores:
|
||||
serialized_results.append({
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"document": {
|
||||
"id": chunk.document.id,
|
||||
"title": chunk.document.title,
|
||||
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
|
||||
"metadata": chunk.document.document_metadata
|
||||
serialized_results.append(
|
||||
{
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"document": {
|
||||
"id": chunk.document.id,
|
||||
"title": chunk.document.title,
|
||||
"document_type": chunk.document.document_type.value
|
||||
if hasattr(chunk.document, "document_type")
|
||||
else None,
|
||||
"metadata": chunk.document.document_metadata,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
)
|
||||
|
||||
return serialized_results
|
||||
|
|
|
|||
|
|
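# Illustrative sketch (not part of this commit): hybrid_search fuses the two ranked
# lists with Reciprocal Rank Fusion. Each chunk scores 1 / (k + rank) per list and
# the scores are summed; the coalesce(..., 0.0) calls above cover chunks that appear
# in only one list. The same idea in plain Python, with toy chunk IDs:
def reciprocal_rank_fusion(
    semantic_ids: list[int], keyword_ids: list[int], k: int = 60
) -> list[int]:
    """Combine two 1-based ranked lists of chunk IDs into one ordering."""
    scores: dict[int, float] = {}
    for ranked in (semantic_ids, keyword_ids):
        for rank, chunk_id in enumerate(ranked, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    # Highest combined score first, mirroring ORDER BY score DESC in the SQL above.
    return sorted(scores, key=scores.get, reverse=True)


# Chunk 7 ranks highly in both lists, so it wins overall.
print(reciprocal_rank_fusion([7, 3, 9], [2, 7, 3]))  # -> [7, 3, 2, 9]
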
@@ -2,34 +2,41 @@ class DocumentHybridSearchRetriever:
|
|||
def __init__(self, db_session):
|
||||
"""
|
||||
Initialize the hybrid search retriever with a database session.
|
||||
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
||||
"""
|
||||
self.db_session = db_session
|
||||
|
||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def vector_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform vector similarity search on documents.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of documents sorted by vector similarity
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Document)
|
||||
|
|
@@ -37,107 +44,118 @@ class DocumentHybridSearchRetriever:
|
|||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add vector similarity ordering
|
||||
query = (
|
||||
query
|
||||
.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||
.limit(top_k)
|
||||
query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit(
|
||||
top_k
|
||||
)
|
||||
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
|
||||
return documents
|
||||
|
||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def full_text_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform full-text keyword search on documents.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of documents sorted by text relevance
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Document.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Document)
|
||||
.options(joinedload(Document.search_space))
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
||||
.where(
|
||||
tsvector.op("@@")(tsquery)
|
||||
) # Only include results that match the query
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add text search ranking
|
||||
query = (
|
||||
query
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
|
||||
return documents
|
||||
|
||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
document_type: str | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Document, SearchSpace, DocumentType
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Document, DocumentType, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Document.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Base conditions for document filtering
|
||||
base_conditions = [SearchSpace.user_id == user_id]
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
base_conditions.append(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add document type filter if provided
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
|
|
@@ -150,98 +168,112 @@ class DocumentHybridSearchRetriever:
|
|||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
|
||||
# CTE for semantic search with user ownership check
|
||||
semantic_search_cte = (
|
||||
select(
|
||||
Document.id,
|
||||
func.rank().over(order_by=Document.embedding.op("<=>")(query_embedding)).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=Document.embedding.op("<=>")(query_embedding))
|
||||
.label("rank"),
|
||||
)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
|
||||
|
||||
semantic_search_cte = (
|
||||
semantic_search_cte
|
||||
.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||
semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||
.limit(n_results)
|
||||
.cte("semantic_search")
|
||||
)
|
||||
|
||||
|
||||
# CTE for keyword search with user ownership check
|
||||
keyword_search_cte = (
|
||||
select(
|
||||
Document.id,
|
||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.label("rank"),
|
||||
)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
.where(tsvector.op("@@")(tsquery))
|
||||
)
|
||||
|
||||
|
||||
keyword_search_cte = (
|
||||
keyword_search_cte
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(n_results)
|
||||
.cte("keyword_search")
|
||||
)
|
||||
|
||||
|
||||
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
||||
final_query = (
|
||||
select(
|
||||
Document,
|
||||
(
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score")
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
||||
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score"),
|
||||
)
|
||||
.select_from(
|
||||
semantic_search_cte.outerjoin(
|
||||
keyword_search_cte,
|
||||
keyword_search_cte,
|
||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||
full=True
|
||||
full=True,
|
||||
)
|
||||
)
|
||||
.join(
|
||||
Document,
|
||||
Document.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
||||
Document.id
|
||||
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
||||
)
|
||||
.options(joinedload(Document.search_space))
|
||||
.order_by(text("score DESC"))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(final_query)
|
||||
documents_with_scores = result.all()
|
||||
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not documents_with_scores:
|
||||
return []
|
||||
|
||||
|
||||
# Convert to serializable dictionaries
|
||||
serialized_results = []
|
||||
for document, score in documents_with_scores:
|
||||
# Fetch associated chunks for this document
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Chunk
|
||||
|
||||
chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||
|
||||
chunks_query = (
|
||||
select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
|
||||
# Concatenate chunks content
|
||||
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
|
||||
|
||||
serialized_results.append({
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"content": document.content,
|
||||
"chunks_content": concatenated_chunks_content,
|
||||
"document_type": document.document_type.value if hasattr(document, 'document_type') else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"search_space_id": document.search_space_id
|
||||
})
|
||||
|
||||
return serialized_results
|
||||
concatenated_chunks_content = (
|
||||
" ".join([chunk.content for chunk in chunks])
|
||||
if chunks
|
||||
else document.content
|
||||
)
|
||||
|
||||
serialized_results.append(
|
||||
{
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"content": document.content,
|
||||
"chunks_content": concatenated_chunks_content,
|
||||
"document_type": document.document_type.value
|
||||
if hasattr(document, "document_type")
|
||||
else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"search_space_id": document.search_space_id,
|
||||
}
|
||||
)
|
||||
|
||||
return serialized_results
|
||||
|
|
|
|||
|
|
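# Illustrative sketch (not part of this commit): both retrievers are built from an
# AsyncSession, so a route can obtain one via FastAPI dependency injection. The
# endpoint path and the hard-coded user_id below are placeholders.
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_async_session
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever

example_router = APIRouter()


@example_router.get("/example-search")  # hypothetical endpoint
async def example_search(
    query: str,
    session: AsyncSession = Depends(get_async_session),
):
    retriever = DocumentHybridSearchRetriever(session)
    return await retriever.hybrid_search(
        query_text=query,
        top_k=10,
        user_id="00000000-0000-0000-0000-000000000000",  # normally the current user's id
        search_space_id=None,
    )
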
@@ -1,11 +1,12 @@
from fastapi import APIRouter
from .search_spaces_routes import router as search_spaces_router
from .documents_routes import router as documents_router
from .podcasts_routes import router as podcasts_router

from .chats_routes import router as chats_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .documents_routes import router as documents_router
from .llm_config_routes import router as llm_config_router
from .logs_routes import router as logs_router
from .podcasts_routes import router as podcasts_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .search_spaces_routes import router as search_spaces_router

router = APIRouter()

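# Presumably the rest of this __init__ module mounts each sub-router on the
# aggregate router, roughly as below (sketch only; not shown in the hunk):
router.include_router(search_spaces_router)
router.include_router(documents_router)
router.include_router(podcasts_router)
router.include_router(chats_router)
router.include_router(search_source_connectors_router)
router.include_router(llm_config_router)
router.include_router(logs_router)
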
@@ -1,38 +1,40 @@
|
|||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Chat, SearchSpace, User, get_async_session
|
||||
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
|
||||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from langchain.schema import HumanMessage, AIMessage
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def handle_chat_data(
|
||||
request: AISDKChatRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
messages = request.messages
|
||||
if messages[-1]['role'] != "user":
|
||||
if messages[-1]["role"] != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message must be a user message")
|
||||
status_code=400, detail="Last message must be a user message"
|
||||
)
|
||||
|
||||
user_query = messages[-1]['content']
|
||||
search_space_id = request.data.get('search_space_id')
|
||||
research_mode: str = request.data.get('research_mode')
|
||||
selected_connectors: List[str] = request.data.get('selected_connectors')
|
||||
document_ids_to_add_in_context: List[int] = request.data.get('document_ids_to_add_in_context')
|
||||
|
||||
search_mode_str = request.data.get('search_mode', "CHUNKS")
|
||||
user_query = messages[-1]["content"]
|
||||
search_space_id = request.data.get("search_space_id")
|
||||
research_mode: str = request.data.get("research_mode")
|
||||
selected_connectors: list[str] = request.data.get("selected_connectors")
|
||||
document_ids_to_add_in_context: list[int] = request.data.get(
|
||||
"document_ids_to_add_in_context"
|
||||
)
|
||||
|
||||
search_mode_str = request.data.get("search_mode", "CHUNKS")
|
||||
|
||||
# Convert search_space_id to integer if it's a string
|
||||
if search_space_id and isinstance(search_space_id, str):
|
||||
|
|
@@ -40,21 +42,23 @@ async def handle_chat_data(
|
|||
search_space_id = int(search_space_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid search_space_id format")
|
||||
status_code=400, detail="Invalid search_space_id format"
|
||||
) from None
|
||||
|
||||
# Check if the search space belongs to the current user
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
except HTTPException:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this search space")
|
||||
|
||||
status_code=403, detail="You don't have access to this search space"
|
||||
) from None
|
||||
|
||||
langchain_chat_history = []
|
||||
for message in messages[:-1]:
|
||||
if message['role'] == "user":
|
||||
langchain_chat_history.append(HumanMessage(content=message['content']))
|
||||
elif message['role'] == "assistant":
|
||||
langchain_chat_history.append(AIMessage(content=message['content']))
|
||||
if message["role"] == "user":
|
||||
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
||||
elif message["role"] == "assistant":
|
||||
langchain_chat_history.append(AIMessage(content=message["content"]))
|
||||
|
||||
response = StreamingResponse(
|
||||
stream_connector_search_results(
|
||||
|
|
@@ -69,7 +73,7 @@ async def handle_chat_data(
|
|||
document_ids_to_add_in_context,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
response.headers["x-vercel-ai-data-stream"] = "v1"
|
||||
return response
|
||||
|
||||
|
|
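# Illustrative sketch (not part of this commit): the recurring typing change in
# these hunks replaces typing.List and implicit Optional defaults with built-in
# generics and PEP 604 unions. The signature below is made up to show the shape.
#
# Before:
#   from typing import List
#   def read_items(ids: List[int], search_space_id: int = None) -> List[dict]: ...
#
# After:
def read_items(ids: list[int], search_space_id: int | None = None) -> list[dict]:
    return [{"id": item_id, "search_space_id": search_space_id} for item_id in ids]
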
@@ -78,7 +82,7 @@ async def handle_chat_data(
|
|||
async def create_chat(
|
||||
chat: ChatCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, chat.search_space_id, user)
|
||||
|
|
@@ -89,52 +93,57 @@ async def create_chat(
|
|||
return db_chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
||||
except OperationalError as e:
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception as e:
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while creating the chat.")
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while creating the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats/", response_model=List[ChatRead])
|
||||
@router.get("/chats/", response_model=list[ChatRead])
|
||||
async def read_chats(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int = None,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||
|
||||
|
||||
# Filter by search_space_id if provided
|
||||
if search_space_id is not None:
|
||||
query = query.filter(Chat.search_space_id == search_space_id)
|
||||
|
||||
result = await session.execute(
|
||||
query.offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
return result.scalars().all()
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching chats.")
|
||||
status_code=500, detail="An unexpected error occurred while fetching chats."
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
||||
async def read_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@@ -145,14 +154,19 @@ async def read_chat(
|
|||
chat = result.scalars().first()
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat not found or you don't have permission to access it")
|
||||
status_code=404,
|
||||
detail="Chat not found or you don't have permission to access it",
|
||||
)
|
||||
return chat
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching the chat.")
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while fetching the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
||||
|
|
@@ -160,7 +174,7 @@ async def update_chat(
|
|||
chat_id: int,
|
||||
chat_update: ChatUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_chat = await read_chat(chat_id, session, user)
|
||||
|
|
@@ -175,22 +189,27 @@ async def update_chat(
|
|||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while updating the chat.")
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while updating the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/chats/{chat_id}", response_model=dict)
|
||||
async def delete_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_chat = await read_chat(chat_id, session, user)
|
||||
|
|
@@ -202,81 +221,16 @@ async def delete_chat(
|
|||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete chat due to existing dependencies.")
|
||||
status_code=400, detail="Cannot delete chat due to existing dependencies."
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while deleting the chat.")
|
||||
|
||||
|
||||
# test_data = [
|
||||
# {
|
||||
# "type": "TERMINAL_INFO",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "text": "Starting to search for crawled URLs...",
|
||||
# "type": "info"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "text": "Found 2 relevant crawled URLs",
|
||||
# "type": "success"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "SOURCES",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "name": "Crawled URLs",
|
||||
# "type": "CRAWLED_URL",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "name": "Files",
|
||||
# "type": "FILE",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 3,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 4,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "ANSWER",
|
||||
# "content": [
|
||||
# "## SurfSense Introduction",
|
||||
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while deleting the chat.",
|
||||
) from None
|
||||
|
|
|
|||
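# Illustrative sketch (not part of this commit): the "from None" added to these
# handlers suppresses implicit exception chaining, so only the HTTPException is
# reported instead of "During handling of the above exception, another exception
# occurred" noise. Minimal self-contained example:
from fastapi import HTTPException
from sqlalchemy.exc import OperationalError


def load_rows():
    """Stand-in for a database call that fails."""
    raise OperationalError("SELECT 1", {}, Exception("connection lost"))


def example_handler():
    try:
        return load_rows()
    except OperationalError:
        # Without "from None", the OperationalError would remain attached as
        # __context__ and leak into tracebacks and error logs.
        raise HTTPException(
            status_code=503,
            detail="Database operation failed. Please try again later.",
        ) from None
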
|
|
@@ -1,23 +1,35 @@
|
|||
from litellm import atranscription
|
||||
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import Log, get_async_session, User, SearchSpace, Document, DocumentType
|
||||
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document_using_unstructured, add_crawled_url_document, add_youtube_video_document, add_received_file_document_using_llamacloud, add_received_file_document_using_docling
|
||||
from app.config import config as app_config
|
||||
# Force asyncio to use standard event loop before unstructured imports
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, UploadFile
|
||||
from litellm import atranscription
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentType, Log, SearchSpace, User, get_async_session
|
||||
from app.schemas import DocumentRead, DocumentsCreate, DocumentUpdate
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.background_tasks import (
|
||||
add_crawled_url_document,
|
||||
add_extension_received_document,
|
||||
add_received_file_document_using_docling,
|
||||
add_received_file_document_using_llamacloud,
|
||||
add_received_file_document_using_unstructured,
|
||||
add_received_markdown_file_document,
|
||||
add_youtube_video_document,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
try:
|
||||
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
||||
except RuntimeError:
|
||||
except RuntimeError as e:
|
||||
print("Error setting event loop policy", e)
|
||||
pass
|
||||
|
||||
import os
|
||||
|
||||
os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
||||
|
||||
|
||||
|
|
@@ -29,7 +41,7 @@ async def create_documents(
|
|||
request: DocumentsCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
try:
|
||||
# Check if the user owns the search space
|
||||
|
|
@@ -41,7 +53,7 @@ async def create_documents(
|
|||
process_extension_document_with_new_session,
|
||||
individual_document,
|
||||
request.search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
elif request.document_type == DocumentType.CRAWLED_URL:
|
||||
for url in request.content:
|
||||
|
|
@@ -49,7 +61,7 @@ async def create_documents(
|
|||
process_crawled_url_with_new_session,
|
||||
url,
|
||||
request.search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
elif request.document_type == DocumentType.YOUTUBE_VIDEO:
|
||||
for url in request.content:
|
||||
|
|
@@ -57,13 +69,10 @@ async def create_documents(
|
|||
process_youtube_video_with_new_session,
|
||||
url,
|
||||
request.search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid document type"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid document type")
|
||||
|
||||
await session.commit()
|
||||
return {"message": "Documents processed successfully"}
|
||||
|
|
@@ -72,18 +81,17 @@ async def create_documents(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to process documents: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to process documents: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/documents/fileupload")
|
||||
async def create_documents(
|
||||
async def create_documents_file_upload(
|
||||
files: list[UploadFile],
|
||||
search_space_id: int = Form(...),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
|
@@ -94,31 +102,32 @@ async def create_documents(
|
|||
for file in files:
|
||||
try:
|
||||
# Save file to a temporary location to avoid stream issues
|
||||
import tempfile
|
||||
import aiofiles
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Create temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=os.path.splitext(file.filename)[1]
|
||||
) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Write uploaded file to temp file
|
||||
content = await file.read()
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
fastapi_background_tasks.add_task(
|
||||
process_file_in_background_with_new_session,
|
||||
temp_path,
|
||||
file.filename,
|
||||
search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to process file {file.filename}: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to process file {file.filename}: {e!s}",
|
||||
) from e
|
||||
|
||||
await session.commit()
|
||||
return {"message": "Files uploaded for processing"}
|
||||
|
|
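# Illustrative sketch (not part of this commit): the upload endpoint persists each
# UploadFile to a named temp file and hands the path to a background task. The same
# pattern condensed into one helper; the background task remains responsible for
# deleting the file, as in the route above.
import os
import tempfile

from fastapi import UploadFile


async def save_upload_to_temp(file: UploadFile) -> str:
    """Write an UploadFile to disk and return the temporary path."""
    suffix = os.path.splitext(file.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(await file.read())
        return temp_file.name
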
@@ -127,9 +136,8 @@ async def create_documents(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to upload files: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to upload files: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def process_file_in_background(
|
||||
|
|
@@ -139,64 +147,71 @@ async def process_file_in_background(
|
|||
user_id: str,
|
||||
session: AsyncSession,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry: Log
|
||||
log_entry: Log,
|
||||
):
|
||||
try:
|
||||
# Check if the file is a markdown or text file
|
||||
if filename.lower().endswith(('.md', '.markdown', '.txt')):
|
||||
if filename.lower().endswith((".md", ".markdown", ".txt")):
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing markdown/text file: {filename}",
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"}
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||
)
|
||||
|
||||
|
||||
# For markdown files, read the content directly
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
markdown_content = f.read()
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Creating document from markdown content: {filename}",
|
||||
{"processing_stage": "creating_document", "content_length": len(markdown_content)}
|
||||
{
|
||||
"processing_stage": "creating_document",
|
||||
"content_length": len(markdown_content),
|
||||
},
|
||||
)
|
||||
|
||||
# Process markdown directly through specialized function
|
||||
result = await add_received_markdown_file_document(
|
||||
session,
|
||||
filename,
|
||||
markdown_content,
|
||||
search_space_id,
|
||||
user_id
|
||||
session, filename, markdown_content, search_space_id, user_id
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed markdown file: {filename}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "markdown"}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "markdown",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Markdown file already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "markdown"}
|
||||
{"duplicate_detected": True, "file_type": "markdown"},
|
||||
)
|
||||
|
||||
|
||||
# Check if the file is an audio file
|
||||
elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
|
||||
elif filename.lower().endswith(
|
||||
(".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
|
||||
):
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing audio file for transcription: {filename}",
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"}
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||
)
|
||||
|
||||
|
||||
# Open the audio file for transcription
|
||||
with open(file_path, "rb") as audio_file:
|
||||
# Use LiteLLM for audio transcription
|
||||
|
|
@@ -205,65 +220,76 @@ async def process_file_in_background(
|
|||
model=app_config.STT_SERVICE,
|
||||
file=audio_file,
|
||||
api_base=app_config.STT_SERVICE_API_BASE,
|
||||
api_key=app_config.STT_SERVICE_API_KEY
|
||||
api_key=app_config.STT_SERVICE_API_KEY,
|
||||
)
|
||||
else:
|
||||
transcription_response = await atranscription(
|
||||
model=app_config.STT_SERVICE,
|
||||
api_key=app_config.STT_SERVICE_API_KEY,
|
||||
file=audio_file
|
||||
file=audio_file,
|
||||
)
|
||||
|
||||
# Extract the transcribed text
|
||||
transcribed_text = transcription_response.get("text", "")
|
||||
|
||||
# Add metadata about the transcription
|
||||
transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}"
|
||||
transcribed_text = (
|
||||
f"# Transcription of {filename}\n\n{transcribed_text}"
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Transcription completed, creating document: {filename}",
|
||||
{"processing_stage": "transcription_complete", "transcript_length": len(transcribed_text)}
|
||||
{
|
||||
"processing_stage": "transcription_complete",
|
||||
"transcript_length": len(transcribed_text),
|
||||
},
|
||||
)
|
||||
|
||||
# Clean up the temp file
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
# Process transcription as markdown document
|
||||
result = await add_received_markdown_file_document(
|
||||
session,
|
||||
filename,
|
||||
transcribed_text,
|
||||
search_space_id,
|
||||
user_id
|
||||
session, filename, transcribed_text, search_space_id, user_id
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully transcribed and processed audio file: {filename}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "audio", "transcript_length": len(transcribed_text)}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "audio",
|
||||
"transcript_length": len(transcribed_text),
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Audio file transcript already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "audio"}
|
||||
{"duplicate_detected": True, "file_type": "audio"},
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with Unstructured ETL: {filename}",
|
||||
{"file_type": "document", "etl_service": "UNSTRUCTURED", "processing_stage": "loading"}
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
"processing_stage": "loading",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
|
||||
# Process the file
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
|
|
@@ -280,212 +306,257 @@ async def process_file_in_background(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Unstructured ETL completed, creating document: {filename}",
|
||||
{"processing_stage": "etl_complete", "elements_count": len(docs)}
|
||||
{"processing_stage": "etl_complete", "elements_count": len(docs)},
|
||||
)
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
# Pass the documents to the existing background task
|
||||
result = await add_received_file_document_using_unstructured(
|
||||
session,
|
||||
filename,
|
||||
docs,
|
||||
search_space_id,
|
||||
user_id
|
||||
session, filename, docs, search_space_id, user_id
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file with Unstructured: {filename}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "document", "etl_service": "UNSTRUCTURED"}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "UNSTRUCTURED"}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
elif app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with LlamaCloud ETL: {filename}",
|
||||
{"file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing"}
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
"processing_stage": "parsing",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
|
||||
# Create LlamaParse parser instance
|
||||
parser = LlamaParse(
|
||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||
num_workers=1, # Use single worker for file processing
|
||||
verbose=True,
|
||||
language="en",
|
||||
result_type=ResultType.MD
|
||||
result_type=ResultType.MD,
|
||||
)
|
||||
|
||||
|
||||
# Parse the file asynchronously
|
||||
result = await parser.aparse(file_path)
|
||||
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
|
||||
# Get markdown documents from the result
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
|
||||
markdown_documents = await result.aget_markdown_documents(
|
||||
split_by_page=False
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"LlamaCloud parsing completed, creating documents: {filename}",
|
||||
{"processing_stage": "parsing_complete", "documents_count": len(markdown_documents)}
|
||||
{
|
||||
"processing_stage": "parsing_complete",
|
||||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
for doc in markdown_documents:
|
||||
# Extract text content from the markdown documents
|
||||
markdown_content = doc.text
|
||||
|
||||
|
||||
# Process the documents using our LlamaCloud background task
|
||||
doc_result = await add_received_file_document_using_llamacloud(
|
||||
session,
|
||||
filename,
|
||||
llamacloud_markdown_document=markdown_content,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
if doc_result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file with LlamaCloud: {filename}",
|
||||
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "LLAMACLOUD"}
|
||||
{
|
||||
"document_id": doc_result.id,
|
||||
"content_hash": doc_result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "LLAMACLOUD"}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
elif app_config.ETL_SERVICE == "DOCLING":
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with Docling ETL: {filename}",
|
||||
{"file_type": "document", "etl_service": "DOCLING", "processing_stage": "parsing"}
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
"processing_stage": "parsing",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Use Docling service for document processing
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
|
||||
# Create Docling service
|
||||
docling_service = create_docling_service()
|
||||
|
||||
|
||||
# Process the document
|
||||
result = await docling_service.process_document(file_path, filename)
|
||||
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Docling parsing completed, creating document: {filename}",
|
||||
{"processing_stage": "parsing_complete", "content_length": len(result['content'])}
|
||||
{
|
||||
"processing_stage": "parsing_complete",
|
||||
"content_length": len(result["content"]),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Process the document using our Docling background task
|
||||
doc_result = await add_received_file_document_using_docling(
|
||||
session,
|
||||
filename,
|
||||
docling_markdown_document=result['content'],
|
||||
docling_markdown_document=result["content"],
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
if doc_result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file with Docling: {filename}",
|
||||
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "DOCLING"}
|
||||
{
|
||||
"document_id": doc_result.id,
|
||||
"content_hash": doc_result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "DOCLING"}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process file: {filename}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__, "filename": filename}
|
||||
{"error_type": type(e).__name__, "filename": filename},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing file in background: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing file in background: {e!s}")
|
||||
raise # Re-raise so the wrapper can also handle it
|
||||
|
||||
|
||||
@router.get("/documents/", response_model=List[DocumentRead])
|
||||
@router.get("/documents/", response_model=list[DocumentRead])
|
||||
async def read_documents(
|
||||
skip: int = 0,
|
||||
limit: int = 300,
|
||||
search_space_id: int = None,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
query = select(Document).join(SearchSpace).filter(
|
||||
SearchSpace.user_id == user.id)
|
||||
query = (
|
||||
select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||
)
|
||||
|
||||
# Filter by search_space_id if provided
|
||||
if search_space_id is not None:
|
||||
query = query.filter(Document.search_space_id == search_space_id)
|
||||
|
||||
result = await session.execute(
|
||||
query.offset(skip).limit(limit)
|
||||
)
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
db_documents = result.scalars().all()
|
||||
|
||||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
api_documents.append(DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
created_at=doc.created_at,
|
||||
search_space_id=doc.search_space_id
|
||||
))
|
||||
api_documents.append(
|
||||
DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
created_at=doc.created_at,
|
||||
search_space_id=doc.search_space_id,
|
||||
)
|
||||
)
|
||||
|
||||
return api_documents
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch documents: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch documents: {e!s}"
|
||||
) from e
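# A hedged alternative to the field-by-field conversion loop in read_documents
# above. It assumes a Pydantic v2 model configured with from_attributes=True,
# which this diff does not confirm, and it uses a trimmed field list purely for
# illustration; treat it as a sketch, not the project's actual approach.
from pydantic import BaseModel, ConfigDict


class DocumentReadSketch(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: int
    title: str
    content: str
    search_space_id: int


def to_api_documents(db_documents) -> list[DocumentReadSketch]:
    # model_validate reads the declared fields straight off each ORM object,
    # so the explicit keyword-by-keyword copy becomes unnecessary.
    return [DocumentReadSketch.model_validate(doc) for doc in db_documents]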
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=DocumentRead)
|
||||
async def read_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@@ -497,8 +568,7 @@ async def read_document(
|
|||
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
status_code=404, detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
# Convert database object to API-friendly format
|
||||
|
|
@@ -509,13 +579,12 @@ async def read_document(
|
|||
document_metadata=document.document_metadata,
|
||||
content=document.content,
|
||||
created_at=document.created_at,
|
||||
search_space_id=document.search_space_id
|
||||
search_space_id=document.search_space_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch document: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/documents/{document_id}", response_model=DocumentRead)
|
||||
|
|
@@ -523,7 +592,7 @@ async def update_document(
|
|||
document_id: int,
|
||||
document_update: DocumentUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
# Query the document directly instead of using read_document function
|
||||
|
|
@@ -536,8 +605,7 @@ async def update_document(
|
|||
|
||||
if not db_document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
status_code=404, detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
update_data = document_update.model_dump(exclude_unset=True)
|
||||
|
|
@@ -554,23 +622,22 @@ async def update_document(
|
|||
document_metadata=db_document.document_metadata,
|
||||
content=db_document.content,
|
||||
created_at=db_document.created_at,
|
||||
search_space_id=db_document.search_space_id
|
||||
search_space_id=db_document.search_space_id,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update document: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/documents/{document_id}", response_model=dict)
|
||||
async def delete_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
# Query the document directly instead of using read_document function
|
||||
|
|
@@ -583,8 +650,7 @@ async def delete_document(
|
|||
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
status_code=404, detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
await session.delete(document)
|
||||
|
|
@@ -595,15 +661,12 @@ async def delete_document(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete document: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete document: {e!s}"
|
||||
) from e
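# The hunks above repeatedly apply the same two changes: f"{str(e)}" becomes
# f"{e!s}", and the HTTPException is re-raised with "from e" so the original
# traceback is preserved. A small self-contained sketch of that pattern; the
# endpoint itself is hypothetical, not one of the routes in this diff.
from fastapi import HTTPException


async def delete_item(item_id: int) -> dict:
    try:
        ...  # look up and delete the row here
        return {"message": "Item deleted successfully"}
    except HTTPException:
        raise  # keep deliberate 4xx responses instead of wrapping them in a 500
    except Exception as e:
        # Chaining with `from e` keeps the real cause in the traceback;
        # `{e!s}` formats only the error message.
        raise HTTPException(
            status_code=500, detail=f"Failed to delete item: {e!s}"
        ) from e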
|
||||
|
||||
|
||||
async def process_extension_document_with_new_session(
|
||||
individual_document,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
individual_document, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process extension document."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@@ -612,7 +675,7 @@ async def process_extension_document_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_extension_document",
|
||||
|
|
@@ -622,40 +685,41 @@ async def process_extension_document_with_new_session(
|
|||
"document_type": "EXTENSION",
|
||||
"url": individual_document.metadata.VisitedWebPageURL,
|
||||
"title": individual_document.metadata.VisitedWebPageTitle,
|
||||
"user_id": user_id
|
||||
}
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await add_extension_received_document(session, individual_document, search_space_id, user_id)
|
||||
|
||||
result = await add_extension_received_document(
|
||||
session, individual_document, search_space_id, user_id
|
||||
)
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash}
|
||||
{"document_id": result.id, "content_hash": result.content_hash},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
|
||||
{"duplicate_detected": True}
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing extension document: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing extension document: {e!s}")
|
||||
|
||||
|
||||
async def process_crawled_url_with_new_session(
|
||||
url: str,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
url: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process crawled URL."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@@ -664,50 +728,50 @@ async def process_crawled_url_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_crawled_url",
|
||||
source="document_processor",
|
||||
message=f"Starting URL crawling and processing for: {url}",
|
||||
metadata={
|
||||
"document_type": "CRAWLED_URL",
|
||||
"url": url,
|
||||
"user_id": user_id
|
||||
}
|
||||
metadata={"document_type": "CRAWLED_URL", "url": url, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await add_crawled_url_document(session, url, search_space_id, user_id)
|
||||
|
||||
result = await add_crawled_url_document(
|
||||
session, url, search_space_id, user_id
|
||||
)
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully crawled and processed URL: {url}",
|
||||
{"document_id": result.id, "title": result.title, "content_hash": result.content_hash}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"title": result.title,
|
||||
"content_hash": result.content_hash,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"URL document already exists (duplicate): {url}",
|
||||
{"duplicate_detected": True}
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to crawl URL: {url}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing crawled URL: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing crawled URL: {e!s}")
|
||||
|
||||
|
||||
async def process_file_in_background_with_new_session(
|
||||
file_path: str,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
file_path: str, filename: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process file."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@@ -716,7 +780,7 @@ async def process_file_in_background_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload",
|
||||
|
|
@@ -726,29 +790,36 @@ async def process_file_in_background_with_new_session(
|
|||
"document_type": "FILE",
|
||||
"filename": filename,
|
||||
"file_path": file_path,
|
||||
"user_id": user_id
|
||||
}
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
await process_file_in_background(file_path, filename, search_space_id, user_id, session, task_logger, log_entry)
|
||||
|
||||
await process_file_in_background(
|
||||
file_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
user_id,
|
||||
session,
|
||||
task_logger,
|
||||
log_entry,
|
||||
)
|
||||
|
||||
# Note: success/failure logging is handled within process_file_in_background
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process file: {filename}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing file: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing file: {e!s}")
|
||||
|
||||
|
||||
async def process_youtube_video_with_new_session(
|
||||
url: str,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
url: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process YouTube video."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@@ -757,42 +828,43 @@ async def process_youtube_video_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_youtube_video",
|
||||
source="document_processor",
|
||||
message=f"Starting YouTube video processing for: {url}",
|
||||
metadata={
|
||||
"document_type": "YOUTUBE_VIDEO",
|
||||
"url": url,
|
||||
"user_id": user_id
|
||||
}
|
||||
metadata={"document_type": "YOUTUBE_VIDEO", "url": url, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await add_youtube_video_document(session, url, search_space_id, user_id)
|
||||
|
||||
result = await add_youtube_video_document(
|
||||
session, url, search_space_id, user_id
|
||||
)
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed YouTube video: {result.title}",
|
||||
{"document_id": result.id, "video_id": result.document_metadata.get("video_id"), "content_hash": result.content_hash}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"video_id": result.document_metadata.get("video_id"),
|
||||
"content_hash": result.content_hash,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"YouTube video document already exists (duplicate): {url}",
|
||||
{"duplicate_detected": True}
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process YouTube video: {url}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing YouTube video: {str(e)}")
|
||||
|
||||
|
||||
logging.error(f"Error processing YouTube video: {e!s}")
|
||||
|
|
|
|||
|
|
@@ -1,35 +1,40 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List, Optional
from pydantic import BaseModel
from app.db import get_async_session, User, LLMConfig
from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead

from app.db import LLMConfig, User, get_async_session
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership

router = APIRouter()


class LLMPreferencesUpdate(BaseModel):
    """Schema for updating user LLM preferences"""
    long_context_llm_id: Optional[int] = None
    fast_llm_id: Optional[int] = None
    strategic_llm_id: Optional[int] = None

    long_context_llm_id: int | None = None
    fast_llm_id: int | None = None
    strategic_llm_id: int | None = None


class LLMPreferencesRead(BaseModel):
    """Schema for reading user LLM preferences"""
    long_context_llm_id: Optional[int] = None
    fast_llm_id: Optional[int] = None
    strategic_llm_id: Optional[int] = None
    long_context_llm: Optional[LLMConfigRead] = None
    fast_llm: Optional[LLMConfigRead] = None
    strategic_llm: Optional[LLMConfigRead] = None

    long_context_llm_id: int | None = None
    fast_llm_id: int | None = None
    strategic_llm_id: int | None = None
    long_context_llm: LLMConfigRead | None = None
    fast_llm: LLMConfigRead | None = None
    strategic_llm: LLMConfigRead | None = None
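# The schema change above is purely a typing-syntax migration: Optional[X] and
# X | None are equivalent annotations. A two-line illustration; note that the
# PEP 604 union form, when evaluated at runtime as below, needs Python 3.10+
# (or `from __future__ import annotations` on older interpreters).
from typing import Optional

legacy_style: Optional[int] = None
modern_style: int | None = None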
|
||||
|
||||
|
||||
@router.post("/llm-configs/", response_model=LLMConfigRead)
|
||||
async def create_llm_config(
|
||||
llm_config: LLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new LLM configuration for the authenticated user"""
|
||||
try:
|
||||
|
|
@@ -43,16 +48,16 @@ async def create_llm_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.get("/llm-configs/", response_model=List[LLMConfigRead])
|
||||
|
||||
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
|
||||
async def read_llm_configs(
|
||||
skip: int = 0,
|
||||
limit: int = 200,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all LLM configurations for the authenticated user"""
|
||||
try:
|
||||
|
|
@@ -65,15 +70,15 @@ async def read_llm_configs(
|
|||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch LLM configurations: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
||||
async def read_llm_config(
|
||||
llm_config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific LLM configuration by ID"""
|
||||
try:
|
||||
|
|
@@ -83,25 +88,25 @@ async def read_llm_config(
|
|||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
||||
async def update_llm_config(
|
||||
llm_config_id: int,
|
||||
llm_config_update: LLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update an existing LLM configuration"""
|
||||
try:
|
||||
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
|
||||
update_data = llm_config_update.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(db_llm_config, key, value)
|
||||
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_llm_config)
|
||||
return db_llm_config
|
||||
|
|
@@ -110,15 +115,15 @@ async def update_llm_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
|
||||
async def delete_llm_config(
|
||||
llm_config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an LLM configuration"""
|
||||
try:
|
||||
|
|
@@ -131,22 +136,23 @@ async def delete_llm_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# User LLM Preferences endpoints
|
||||
|
||||
|
||||
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||
async def get_user_llm_preferences(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get the current user's LLM preferences"""
|
||||
try:
|
||||
# Refresh user to get latest relationships
|
||||
await session.refresh(user)
|
||||
|
||||
|
||||
result = {
|
||||
"long_context_llm_id": user.long_context_llm_id,
|
||||
"fast_llm_id": user.fast_llm_id,
|
||||
|
|
@@ -155,82 +161,79 @@ async def get_user_llm_preferences(
|
|||
"fast_llm": None,
|
||||
"strategic_llm": None,
|
||||
}
|
||||
|
||||
|
||||
# Fetch the actual LLM configs if they exist
|
||||
if user.long_context_llm_id:
|
||||
long_context_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.long_context_llm_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.user_id == user.id,
|
||||
)
|
||||
)
|
||||
llm_config = long_context_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["long_context_llm"] = llm_config
|
||||
|
||||
|
||||
if user.fast_llm_id:
|
||||
fast_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.fast_llm_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = fast_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["fast_llm"] = llm_config
|
||||
|
||||
|
||||
if user.strategic_llm_id:
|
||||
strategic_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.strategic_llm_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = strategic_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["strategic_llm"] = llm_config
|
||||
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch LLM preferences: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
|
||||
) from e
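# get_user_llm_preferences above issues one query per preferred config. A
# hedged alternative sketch that loads all three in a single query; it assumes
# the same LLMConfig model, columns, and module-level import shown in this
# diff, and is illustrative rather than the committed implementation.
from sqlalchemy.future import select


async def load_preferred_llm_configs(session, user) -> dict:
    wanted_ids = {user.long_context_llm_id, user.fast_llm_id, user.strategic_llm_id}
    wanted_ids.discard(None)
    if not wanted_ids:
        return {}
    result = await session.execute(
        select(LLMConfig).filter(
            LLMConfig.id.in_(wanted_ids), LLMConfig.user_id == user.id
        )
    )
    # Index by id so the caller can attach each config to the right slot.
    return {config.id: config for config in result.scalars().all()}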
|
||||
|
||||
|
||||
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||
async def update_user_llm_preferences(
|
||||
preferences: LLMPreferencesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update the current user's LLM preferences"""
|
||||
try:
|
||||
# Validate that all provided LLM config IDs belong to the user
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
|
||||
for key, llm_config_id in update_data.items():
|
||||
|
||||
for _key, llm_config_id in update_data.items():
|
||||
if llm_config_id is not None:
|
||||
# Verify ownership of the LLM config
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
if not llm_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it"
|
||||
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
|
||||
)
|
||||
|
||||
|
||||
# Update user preferences
|
||||
for key, value in update_data.items():
|
||||
setattr(user, key, value)
|
||||
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
|
||||
# Return updated preferences
|
||||
return await get_user_llm_preferences(session, user)
|
||||
except HTTPException:
|
||||
|
|
@@ -238,6 +241,5 @@ async def update_user_llm_preferences(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update LLM preferences: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@@ -1,22 +1,23 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import and_, desc
from typing import List, Optional
from datetime import datetime, timedelta

from app.db import get_async_session, User, SearchSpace, Log, LogLevel, LogStatus
from app.schemas import LogCreate, LogUpdate, LogRead, LogFilter
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import and_, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session
from app.schemas import LogCreate, LogRead, LogUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership

router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/logs/", response_model=LogRead)
|
||||
async def create_log(
|
||||
log: LogCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new log entry."""
|
||||
try:
|
||||
|
|
@@ -33,22 +34,22 @@ async def create_log(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to create log: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.get("/logs/", response_model=List[LogRead])
|
||||
|
||||
@router.get("/logs/", response_model=list[LogRead])
|
||||
async def read_logs(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: Optional[int] = None,
|
||||
level: Optional[LogLevel] = None,
|
||||
status: Optional[LogStatus] = None,
|
||||
source: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
search_space_id: int | None = None,
|
||||
level: LogLevel | None = None,
|
||||
status: LogStatus | None = None,
|
||||
source: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get logs with optional filtering."""
|
||||
try:
|
||||
|
|
@@ -62,23 +63,23 @@ async def read_logs(
|
|||
|
||||
# Apply filters
|
||||
filters = []
|
||||
|
||||
|
||||
if search_space_id is not None:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
filters.append(Log.search_space_id == search_space_id)
|
||||
|
||||
|
||||
if level is not None:
|
||||
filters.append(Log.level == level)
|
||||
|
||||
|
||||
if status is not None:
|
||||
filters.append(Log.status == status)
|
||||
|
||||
|
||||
if source is not None:
|
||||
filters.append(Log.source.ilike(f"%{source}%"))
|
||||
|
||||
|
||||
if start_date is not None:
|
||||
filters.append(Log.created_at >= start_date)
|
||||
|
||||
|
||||
if end_date is not None:
|
||||
filters.append(Log.created_at <= end_date)
|
||||
|
||||
|
|
@@ -93,15 +94,15 @@ async def read_logs(
|
|||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch logs: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch logs: {e!s}"
|
||||
) from e
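# read_logs collects its optional conditions into a `filters` list and applies
# them to the query in a part of the function this hunk does not show. A
# compact sketch of that composition step; the .filter(and_(*filters)) call is
# an assumption about the elided code, and Log comes from the module-level
# import shown above.
from sqlalchemy import and_, desc


def apply_log_filters(base_query, filters, skip: int, limit: int):
    # and_() needs at least one clause, so only apply it when filters exist.
    if filters:
        base_query = base_query.filter(and_(*filters))
    return base_query.order_by(desc(Log.created_at)).offset(skip).limit(limit)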
|
||||
|
||||
|
||||
@router.get("/logs/{log_id}", response_model=LogRead)
|
||||
async def read_log(
|
||||
log_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific log by ID."""
|
||||
try:
|
||||
|
|
@@ -112,25 +113,25 @@ async def read_log(
|
|||
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
log = result.scalars().first()
|
||||
|
||||
|
||||
if not log:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
|
||||
|
||||
return log
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch log: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/logs/{log_id}", response_model=LogRead)
|
||||
async def update_log(
|
||||
log_id: int,
|
||||
log_update: LogUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update a log entry."""
|
||||
try:
|
||||
|
|
@@ -141,7 +142,7 @@ async def update_log(
|
|||
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
db_log = result.scalars().first()
|
||||
|
||||
|
||||
if not db_log:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
|
||||
|
|
@@ -158,15 +159,15 @@ async def update_log(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update log: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/logs/{log_id}")
|
||||
async def delete_log(
|
||||
log_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete a log entry."""
|
||||
try:
|
||||
|
|
@@ -177,7 +178,7 @@ async def delete_log(
|
|||
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
db_log = result.scalars().first()
|
||||
|
||||
|
||||
if not db_log:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
|
||||
|
|
@@ -189,38 +190,35 @@ async def delete_log(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete log: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/logs/search-space/{search_space_id}/summary")
|
||||
async def get_logs_summary(
|
||||
search_space_id: int,
|
||||
hours: int = 24,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a summary of logs for a search space in the last X hours."""
|
||||
try:
|
||||
# Check ownership
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
||||
|
||||
# Calculate time window
|
||||
since = datetime.utcnow().replace(microsecond=0) - timedelta(hours=hours)
|
||||
|
||||
|
||||
# Get logs from the time window
|
||||
result = await session.execute(
|
||||
select(Log)
|
||||
.filter(
|
||||
and_(
|
||||
Log.search_space_id == search_space_id,
|
||||
Log.created_at >= since
|
||||
)
|
||||
and_(Log.search_space_id == search_space_id, Log.created_at >= since)
|
||||
)
|
||||
.order_by(desc(Log.created_at))
|
||||
)
|
||||
logs = result.scalars().all()
|
||||
|
||||
|
||||
# Create summary
|
||||
summary = {
|
||||
"total_logs": len(logs),
|
||||
|
|
@@ -229,52 +227,69 @@ async def get_logs_summary(
|
|||
"by_level": {},
|
||||
"by_source": {},
|
||||
"active_tasks": [],
|
||||
"recent_failures": []
|
||||
"recent_failures": [],
|
||||
}
|
||||
|
||||
|
||||
# Count by status and level
|
||||
for log in logs:
|
||||
# Status counts
|
||||
status_str = log.status.value
|
||||
summary["by_status"][status_str] = summary["by_status"].get(status_str, 0) + 1
|
||||
|
||||
summary["by_status"][status_str] = (
|
||||
summary["by_status"].get(status_str, 0) + 1
|
||||
)
|
||||
|
||||
# Level counts
|
||||
level_str = log.level.value
|
||||
summary["by_level"][level_str] = summary["by_level"].get(level_str, 0) + 1
|
||||
|
||||
|
||||
# Source counts
|
||||
if log.source:
|
||||
summary["by_source"][log.source] = summary["by_source"].get(log.source, 0) + 1
|
||||
|
||||
summary["by_source"][log.source] = (
|
||||
summary["by_source"].get(log.source, 0) + 1
|
||||
)
|
||||
|
||||
# Active tasks (IN_PROGRESS)
|
||||
if log.status == LogStatus.IN_PROGRESS:
|
||||
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
|
||||
summary["active_tasks"].append({
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"started_at": log.created_at,
|
||||
"source": log.source
|
||||
})
|
||||
|
||||
task_name = (
|
||||
log.log_metadata.get("task_name", "Unknown")
|
||||
if log.log_metadata
|
||||
else "Unknown"
|
||||
)
|
||||
summary["active_tasks"].append(
|
||||
{
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"started_at": log.created_at,
|
||||
"source": log.source,
|
||||
}
|
||||
)
|
||||
|
||||
# Recent failures
|
||||
if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10:
|
||||
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
|
||||
summary["recent_failures"].append({
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"failed_at": log.created_at,
|
||||
"source": log.source,
|
||||
"error_details": log.log_metadata.get("error_details") if log.log_metadata else None
|
||||
})
|
||||
|
||||
task_name = (
|
||||
log.log_metadata.get("task_name", "Unknown")
|
||||
if log.log_metadata
|
||||
else "Unknown"
|
||||
)
|
||||
summary["recent_failures"].append(
|
||||
{
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"failed_at": log.created_at,
|
||||
"source": log.source,
|
||||
"error_details": log.log_metadata.get("error_details")
|
||||
if log.log_metadata
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate logs summary: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to generate logs summary: {e!s}"
|
||||
) from e
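# The counting loop in get_logs_summary above increments dictionary entries by
# hand. An equivalent sketch of just the tallying step using
# collections.Counter; illustrative only, not the committed implementation.
from collections import Counter


def tally_logs(logs) -> dict:
    return {
        "total_logs": len(logs),
        "by_status": dict(Counter(log.status.value for log in logs)),
        "by_level": dict(Counter(log.level.value for log in logs)),
        "by_source": dict(Counter(log.source for log in logs if log.source)),
    }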
|
||||
|
|
|
|||
|
|
@@ -1,24 +1,31 @@
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from typing import List
from app.db import get_async_session, User, SearchSpace, Podcast, Chat
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from app.tasks.podcast_tasks import generate_chat_podcast
from fastapi.responses import StreamingResponse
import os
from pathlib import Path

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import Chat, Podcast, SearchSpace, User, get_async_session
from app.schemas import (
    PodcastCreate,
    PodcastGenerateRequest,
    PodcastRead,
    PodcastUpdate,
)
from app.tasks.podcast_tasks import generate_chat_podcast
from app.users import current_active_user
from app.utils.check_ownership import check_ownership

router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/podcasts/", response_model=PodcastRead)
|
||||
async def create_podcast(
|
||||
podcast: PodcastCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
|
||||
|
|
@@ -29,22 +36,30 @@ async def create_podcast(
|
|||
return db_podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
|
||||
except SQLAlchemyError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Podcast creation failed due to constraint violation",
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while creating podcast"
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred"
|
||||
) from None
|
||||
|
||||
@router.get("/podcasts/", response_model=List[PodcastRead])
|
||||
|
||||
@router.get("/podcasts/", response_model=list[PodcastRead])
|
||||
async def read_podcasts(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
|
|
@@ -58,13 +73,16 @@ async def read_podcasts(
|
|||
)
|
||||
return result.scalars().all()
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcasts"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def read_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@@ -76,20 +94,23 @@ async def read_podcast(
|
|||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found or you don't have permission to access it"
|
||||
detail="Podcast not found or you don't have permission to access it",
|
||||
)
|
||||
return podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def update_podcast(
|
||||
podcast_id: int,
|
||||
podcast_update: PodcastUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_podcast = await read_podcast(podcast_id, session, user)
|
||||
|
|
@@ -103,16 +124,21 @@ async def update_podcast(
|
|||
raise he
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Update failed due to constraint violation"
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while updating podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_podcast = await read_podcast(podcast_id, session, user)
|
||||
|
|
@@ -123,83 +149,100 @@ async def delete_podcast(
|
|||
raise he
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while deleting podcast"
|
||||
) from None
|
||||
|
||||
|
||||
async def generate_chat_podcast_with_new_session(
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
podcast_title: str,
|
||||
user_id: int
|
||||
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
|
||||
):
|
||||
"""Create a new session and process chat podcast generation."""
|
||||
from app.db import async_session_maker
|
||||
|
||||
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id)
|
||||
await generate_chat_podcast(
|
||||
session, chat_id, search_space_id, podcast_title, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.error(f"Error generating podcast from chat: {str(e)}")
|
||||
|
||||
logging.error(f"Error generating podcast from chat: {e!s}")
|
||||
|
||||
|
||||
@router.post("/podcasts/generate/")
|
||||
async def generate_podcast(
|
||||
request: PodcastGenerateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
try:
|
||||
# Check if the user owns the search space
|
||||
await check_ownership(session, SearchSpace, request.search_space_id, user)
|
||||
|
||||
|
||||
if request.type == "CHAT":
|
||||
# Verify that all chat IDs belong to this user and search space
|
||||
query = select(Chat).filter(
|
||||
Chat.id.in_(request.ids),
|
||||
Chat.search_space_id == request.search_space_id
|
||||
).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||
|
||||
query = (
|
||||
select(Chat)
|
||||
.filter(
|
||||
Chat.id.in_(request.ids),
|
||||
Chat.search_space_id == request.search_space_id,
|
||||
)
|
||||
.join(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
valid_chats = result.scalars().all()
|
||||
valid_chat_ids = [chat.id for chat in valid_chats]
|
||||
|
||||
|
||||
# If any requested ID is not in valid IDs, raise error immediately
|
||||
if len(valid_chat_ids) != len(request.ids):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="One or more chat IDs do not belong to this user or search space"
|
||||
status_code=403,
|
||||
detail="One or more chat IDs do not belong to this user or search space",
|
||||
)
|
||||
|
||||
|
||||
# Queue a podcast generation task for each validated chat ID
|
||||
for chat_id in valid_chat_ids:
|
||||
fastapi_background_tasks.add_task(
|
||||
generate_chat_podcast_with_new_session,
|
||||
chat_id,
|
||||
generate_chat_podcast_with_new_session,
|
||||
chat_id,
|
||||
request.search_space_id,
|
||||
request.podcast_title,
|
||||
user.id
|
||||
user.id,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"message": "Podcast generation started",
|
||||
}
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation")
|
||||
except SQLAlchemyError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Podcast generation failed due to constraint violation",
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while generating podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while generating podcast"
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"An unexpected error occurred: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}/stream")
|
||||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Stream a podcast audio file."""
|
||||
try:
|
||||
|
|
@@ -210,36 +253,38 @@ async def stream_podcast(
|
|||
.filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
|
||||
|
||||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found or you don't have permission to access it"
|
||||
detail="Podcast not found or you don't have permission to access it",
|
||||
)
|
||||
|
||||
|
||||
# Get the file path
|
||||
file_path = podcast.file_location
|
||||
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Podcast audio file not found")
|
||||
|
||||
|
||||
# Define a generator function to stream the file
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
|
||||
|
||||
# Return a streaming response with appropriate headers
|
||||
return StreamingResponse(
|
||||
iterfile(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Disposition": f"inline; filename={Path(file_path).name}"
|
||||
}
|
||||
"Content-Disposition": f"inline; filename={Path(file_path).name}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error streaming podcast: {e!s}"
|
||||
) from e
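# stream_podcast above reads the audio file with `yield from file_like`, which
# iterates the file object in whatever pieces it yields. A hedged variant with
# an explicit chunk size, using only standard library and FastAPI APIs; this is
# a sketch of the same idea, not the committed code.
from pathlib import Path

from fastapi.responses import StreamingResponse


def stream_audio_file(file_path: str, chunk_size: int = 64 * 1024) -> StreamingResponse:
    def iterfile():
        # Read in fixed-size chunks so memory use stays bounded for large files.
        with open(file_path, "rb") as audio:
            while chunk := audio.read(chunk_size):
                yield chunk

    return StreamingResponse(
        iterfile(),
        media_type="audio/mpeg",
        headers={
            "Accept-Ranges": "bytes",
            "Content-Disposition": f"inline; filename={Path(file_path).name}",
        },
    )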
|
||||
|
|
|
|||
|
|
@@ -12,7 +12,13 @@ Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API

import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List
from typing import Any

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.connectors.github_connector import GitHubConnector
from app.db import (
|
||||
|
|
@@ -39,11 +45,6 @@ from app.tasks.connectors_indexing_tasks import (
|
|||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@@ -57,7 +58,7 @@ class GitHubPATRequest(BaseModel):
|
|||
|
||||
|
||||
# --- New Endpoint to list GitHub Repositories ---
|
||||
@router.post("/github/repositories/", response_model=List[Dict[str, Any]])
|
||||
@router.post("/github/repositories/", response_model=list[dict[str, Any]])
|
||||
async def list_github_repositories(
|
||||
pat_request: GitHubPATRequest,
|
||||
user: User = Depends(current_active_user), # Ensure the user is logged in
|
||||
|
|
@@ -74,15 +75,13 @@ async def list_github_repositories(
|
|||
return repositories
|
||||
except ValueError as e:
|
||||
# Handle invalid token error specifically
|
||||
logger.error(f"GitHub PAT validation failed for user {user.id}: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {str(e)}")
|
||||
logger.error(f"GitHub PAT validation failed for user {user.id}: {e!s}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {e!s}") from e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch GitHub repositories for user {user.id}: {str(e)}"
|
||||
)
|
||||
logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to fetch GitHub repositories."
|
||||
)
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
|
||||
|
|
@@ -118,32 +117,32 @@ async def create_search_source_connector(
|
|||
return db_connector
|
||||
except ValidationError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=422, detail=f"Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=422, detail=f"Validation error: {e!s}") from e
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Integrity error: A connector with this type already exists. {str(e)}",
|
||||
)
|
||||
detail=f"Integrity error: A connector with this type already exists. {e!s}",
|
||||
) from e
|
||||
except HTTPException:
|
||||
await session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create search source connector: {str(e)}")
|
||||
logger.error(f"Failed to create search source connector: {e!s}")
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create search source connector: {str(e)}",
|
||||
)
|
||||
detail=f"Failed to create search source connector: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-source-connectors/", response_model=List[SearchSourceConnectorRead]
|
||||
"/search-source-connectors/", response_model=list[SearchSourceConnectorRead]
|
||||
)
|
||||
async def read_search_source_connectors(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int = None,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
@@ -160,8 +159,8 @@ async def read_search_source_connectors(
|
|||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search source connectors: {str(e)}",
|
||||
)
|
||||
detail=f"Failed to fetch search source connectors: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
|
|
@@ -179,8 +178,8 @@ async def read_search_source_connector(
|
|||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch search source connector: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch search source connector: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put(
|
||||
|
|
@@ -238,8 +237,8 @@ async def update_search_source_connector(
|
|||
except ValidationError as e:
|
||||
# Raise specific validation error for the merged config
|
||||
raise HTTPException(
|
||||
status_code=422, detail=f"Validation error for merged config: {str(e)}"
|
||||
)
|
||||
status_code=422, detail=f"Validation error for merged config: {e!s}"
|
||||
) from e
|
||||
|
||||
# If validation passes, update the main update_data dict with the merged config
|
||||
update_data["config"] = merged_config
|
||||
|
|
@@ -272,8 +271,8 @@ async def update_search_source_connector(
|
|||
await session.rollback()
|
||||
# This might occur if connector_type constraint is violated somehow after the check
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Database integrity error during update: {str(e)}"
|
||||
)
|
||||
status_code=409, detail=f"Database integrity error during update: {e!s}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(
|
||||
|
|
@@ -282,8 +281,8 @@ async def update_search_source_connector(
|
|||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update search source connector: {str(e)}",
|
||||
)
|
||||
detail=f"Failed to update search source connector: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
|
||||
|
|
@@ -306,12 +305,12 @@ async def delete_search_source_connector(
|
|||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete search source connector: {str(e)}",
|
||||
)
|
||||
detail=f"Failed to delete search source connector: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any]
|
||||
"/search-source-connectors/{connector_id}/index", response_model=dict[str, Any]
|
||||
)
|
||||
async def index_connector_content(
|
||||
connector_id: int,
|
||||
|
|
@@ -356,7 +355,7 @@ async def index_connector_content(
|
|||
)
|
||||
|
||||
# Check if the search space belongs to the user
|
||||
search_space = await check_ownership(
|
||||
_search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
|
||||
|
|
@@ -381,10 +380,7 @@ async def index_connector_content(
|
|||
else:
|
||||
indexing_from = start_date
|
||||
|
||||
if end_date is None:
|
||||
indexing_to = today_str
|
||||
else:
|
||||
indexing_to = end_date
|
||||
indexing_to = end_date if end_date else today_str
|
||||
|
||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
# Run indexing in background
|
||||
|
|
@@ -497,8 +493,8 @@ async def index_connector_content(
|
|||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate indexing: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to initiate indexing: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def update_connector_last_indexed(session: AsyncSession, connector_id: int):
|
||||
|
|
@@ -523,7 +519,7 @@ async def update_connector_last_indexed(session: AsyncSession, connector_id: int
|
|||
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}"
|
||||
f"Failed to update last_indexed_at for connector {connector_id}: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
|
||||
|
|
@@ -587,7 +583,7 @@ async def run_slack_indexing(
f"Slack indexing failed or no documents processed: {error_or_warning}"
)
except Exception as e:
logger.error(f"Error in background Slack indexing task: {str(e)}")
logger.error(f"Error in background Slack indexing task: {e!s}")


async def run_notion_indexing_with_new_session(
@@ -649,7 +645,7 @@ async def run_notion_indexing(
f"Notion indexing failed or no documents processed: {error_or_warning}"
)
except Exception as e:
logger.error(f"Error in background Notion indexing task: {str(e)}")
logger.error(f"Error in background Notion indexing task: {e!s}")


# Add new helper functions for GitHub indexing
@@ -829,7 +825,7 @@ async def run_discord_indexing(
f"Discord indexing failed or no documents processed: {error_or_warning}"
)
except Exception as e:
logger.error(f"Error in background Discord indexing task: {str(e)}")
logger.error(f"Error in background Discord indexing task: {e!s}")


# Add new helper functions for Jira indexing
@@ -1,20 +1,20 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List
from app.db import get_async_session, User, SearchSpace
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead

from app.db import SearchSpace, User, get_async_session
from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from fastapi import HTTPException

router = APIRouter()


@router.post("/searchspaces/", response_model=SearchSpaceRead)
async def create_search_space(
search_space: SearchSpaceCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
user: User = Depends(current_active_user),
):
try:
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
@ -27,16 +27,16 @@ async def create_search_space(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to create search space: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
|
||||
|
||||
@router.get("/searchspaces/", response_model=list[SearchSpaceRead])
|
||||
async def read_search_spaces(
|
||||
skip: int = 0,
|
||||
limit: int = 200,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -48,37 +48,41 @@ async def read_search_spaces(
|
|||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search spaces: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch search spaces: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||
async def read_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
return search_space
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch search space: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||
async def update_search_space(
|
||||
search_space_id: int,
|
||||
search_space_update: SearchSpaceUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
db_search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_search_space, key, value)
|
||||
|
|
@ -90,18 +94,20 @@ async def update_search_space(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update search space: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
db_search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
await session.delete(db_search_space)
|
||||
await session.commit()
|
||||
return {"message": "Search space deleted successfully"}
|
||||
|
|
@ -110,6 +116,5 @@ async def delete_search_space(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete search space: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -1,62 +1,78 @@
|
|||
from .base import TimestampModel, IDModel
|
||||
from .users import UserRead, UserCreate, UserUpdate
|
||||
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
from .base import IDModel, TimestampModel
|
||||
from .chats import AISDKChatRequest, ChatBase, ChatCreate, ChatRead, ChatUpdate
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||
from .documents import (
|
||||
ExtensionDocumentMetadata,
|
||||
ExtensionDocumentContent,
|
||||
DocumentBase,
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentUpdate,
|
||||
DocumentRead,
|
||||
ExtensionDocumentContent,
|
||||
ExtensionDocumentMetadata,
|
||||
)
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
|
||||
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
|
||||
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
||||
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
|
||||
from .logs import LogBase, LogCreate, LogUpdate, LogRead, LogFilter
|
||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .podcasts import (
|
||||
PodcastBase,
|
||||
PodcastCreate,
|
||||
PodcastGenerateRequest,
|
||||
PodcastRead,
|
||||
PodcastUpdate,
|
||||
)
|
||||
from .search_source_connector import (
|
||||
SearchSourceConnectorBase,
|
||||
SearchSourceConnectorCreate,
|
||||
SearchSourceConnectorRead,
|
||||
SearchSourceConnectorUpdate,
|
||||
)
|
||||
from .search_space import (
|
||||
SearchSpaceBase,
|
||||
SearchSpaceCreate,
|
||||
SearchSpaceRead,
|
||||
SearchSpaceUpdate,
|
||||
)
|
||||
from .users import UserCreate, UserRead, UserUpdate
|
||||
|
||||
__all__ = [
|
||||
"AISDKChatRequest",
|
||||
"TimestampModel",
|
||||
"IDModel",
|
||||
"UserRead",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceRead",
|
||||
"ExtensionDocumentMetadata",
|
||||
"ExtensionDocumentContent",
|
||||
"DocumentBase",
|
||||
"DocumentsCreate",
|
||||
"DocumentUpdate",
|
||||
"DocumentRead",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkUpdate",
|
||||
"ChunkRead",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastUpdate",
|
||||
"PodcastRead",
|
||||
"PodcastGenerateRequest",
|
||||
"ChatBase",
|
||||
"ChatCreate",
|
||||
"ChatUpdate",
|
||||
"ChatRead",
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSourceConnectorRead",
|
||||
"ChatUpdate",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentUpdate",
|
||||
"DocumentsCreate",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"IDModel",
|
||||
"LLMConfigBase",
|
||||
"LLMConfigCreate",
|
||||
"LLMConfigUpdate",
|
||||
"LLMConfigRead",
|
||||
"LLMConfigUpdate",
|
||||
"LogBase",
|
||||
"LogCreate",
|
||||
"LogUpdate",
|
||||
"LogRead",
|
||||
"LogFilter",
|
||||
]
|
||||
"LogRead",
|
||||
"LogUpdate",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastGenerateRequest",
|
||||
"PodcastRead",
|
||||
"PodcastUpdate",
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorRead",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
"TimestampModel",
|
||||
"UserCreate",
|
||||
"UserRead",
|
||||
"UserUpdate",
|
||||
]
|
||||
|
|
|
|||
|
|
@@ -1,10 +1,13 @@
from datetime import datetime

from pydantic import BaseModel, ConfigDict


class TimestampModel(BaseModel):
created_at: datetime
model_config = ConfigDict(from_attributes=True)


class IDModel(BaseModel):
id: int
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.db import ChatType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
|
@ -9,39 +10,43 @@ from .base import IDModel, TimestampModel
|
|||
class ChatBase(BaseModel):
|
||||
type: ChatType
|
||||
title: str
|
||||
initial_connectors: Optional[List[str]] = None
|
||||
messages: List[Any]
|
||||
initial_connectors: list[str] | None = None
|
||||
messages: list[Any]
|
||||
search_space_id: int
|
||||
|
||||
|
||||
|
||||
class ClientAttachment(BaseModel):
|
||||
name: str
|
||||
contentType: str
|
||||
content_type: str
|
||||
url: str
|
||||
|
||||
|
||||
class ToolInvocation(BaseModel):
|
||||
toolCallId: str
|
||||
toolName: str
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
args: dict
|
||||
result: dict
|
||||
|
||||
|
||||
|
||||
|
||||
# class ClientMessage(BaseModel):
|
||||
# role: str
|
||||
# content: str
|
||||
# experimental_attachments: Optional[List[ClientAttachment]] = None
|
||||
# toolInvocations: Optional[List[ToolInvocation]] = None
|
||||
|
||||
|
||||
|
||||
class AISDKChatRequest(BaseModel):
|
||||
messages: List[Any]
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
messages: list[Any]
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatUpdate(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatRead(ChatBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@@ -1,15 +1,20 @@
from pydantic import BaseModel, ConfigDict

from .base import IDModel, TimestampModel


class ChunkBase(BaseModel):
content: str
document_id: int


class ChunkCreate(ChunkBase):
pass


class ChunkUpdate(ChunkBase):
pass


class ChunkRead(ChunkBase, IDModel, TimestampModel):
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from app.db import DocumentType
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.db import DocumentType
|
||||
|
||||
|
||||
class ExtensionDocumentMetadata(BaseModel):
|
||||
BrowsingSessionId: str
|
||||
VisitedWebPageURL: str
|
||||
|
|
@ -11,21 +13,28 @@ class ExtensionDocumentMetadata(BaseModel):
|
|||
VisitedWebPageReffererURL: str
|
||||
VisitedWebPageVisitDurationInMilliseconds: str
|
||||
|
||||
|
||||
class ExtensionDocumentContent(BaseModel):
|
||||
metadata: ExtensionDocumentMetadata
|
||||
pageContent: str
|
||||
pageContent: str # noqa: N815
|
||||
|
||||
|
||||
class DocumentBase(BaseModel):
|
||||
document_type: DocumentType
|
||||
content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content
|
||||
content: (
|
||||
list[ExtensionDocumentContent] | list[str] | str
|
||||
) # Updated to allow string content
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class DocumentsCreate(DocumentBase):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentUpdate(DocumentBase):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentRead(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
|
|
@ -34,6 +43,5 @@ class DocumentRead(BaseModel):
|
|||
content: str # Changed to string to match frontend
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,34 +1,61 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class LLMConfigBase(BaseModel):
|
||||
name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration")
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the LLM configuration"
|
||||
)
|
||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
|
||||
model_name: str = Field(..., max_length=100, description="Model name without provider prefix")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name without provider prefix"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
|
||||
litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class LLMConfigCreate(LLMConfigBase):
|
||||
pass
|
||||
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration")
|
||||
provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type")
|
||||
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
|
||||
model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix")
|
||||
api_key: Optional[str] = Field(None, description="API key for the provider")
|
||||
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
|
||||
litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters")
|
||||
name: str | None = Field(
|
||||
None, max_length=100, description="User-friendly name for the LLM configuration"
|
||||
)
|
||||
provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
None, max_length=100, description="Model name without provider prefix"
|
||||
)
|
||||
api_key: str | None = Field(None, description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,37 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
from app.db import LogLevel, LogStatus
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class LogBase(BaseModel):
|
||||
level: LogLevel
|
||||
status: LogStatus
|
||||
message: str
|
||||
source: Optional[str] = None
|
||||
log_metadata: Optional[Dict[str, Any]] = None
|
||||
source: str | None = None
|
||||
log_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LogCreate(BaseModel):
|
||||
level: LogLevel
|
||||
status: LogStatus
|
||||
message: str
|
||||
source: Optional[str] = None
|
||||
log_metadata: Optional[Dict[str, Any]] = None
|
||||
source: str | None = None
|
||||
log_metadata: dict[str, Any] | None = None
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class LogUpdate(BaseModel):
|
||||
level: Optional[LogLevel] = None
|
||||
status: Optional[LogStatus] = None
|
||||
message: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
log_metadata: Optional[Dict[str, Any]] = None
|
||||
level: LogLevel | None = None
|
||||
status: LogStatus | None = None
|
||||
message: str | None = None
|
||||
source: str | None = None
|
||||
log_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LogRead(LogBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
|
|
@ -33,12 +40,13 @@ class LogRead(LogBase, IDModel, TimestampModel):
|
|||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class LogFilter(BaseModel):
|
||||
level: Optional[LogLevel] = None
|
||||
status: Optional[LogStatus] = None
|
||||
source: Optional[str] = None
|
||||
search_space_id: Optional[int] = None
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
class LogFilter(BaseModel):
|
||||
level: LogLevel | None = None
|
||||
status: LogStatus | None = None
|
||||
source: str | None = None
|
||||
search_space_id: int | None = None
|
||||
start_date: datetime | None = None
|
||||
end_date: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,31 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class PodcastBase(BaseModel):
|
||||
title: str
|
||||
podcast_transcript: List[Any]
|
||||
podcast_transcript: list[Any]
|
||||
file_location: str = ""
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class PodcastCreate(PodcastBase):
|
||||
pass
|
||||
|
||||
|
||||
class PodcastUpdate(PodcastBase):
|
||||
pass
|
||||
|
||||
|
||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PodcastGenerateRequest(BaseModel):
|
||||
type: Literal["DOCUMENT", "CHAT"]
|
||||
ids: List[int]
|
||||
ids: list[int]
|
||||
search_space_id: int
|
||||
podcast_title: str = "SurfSense Podcast"
|
||||
podcast_title: str = "SurfSense Podcast"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
|
@ -12,14 +13,14 @@ class SearchSourceConnectorBase(BaseModel):
|
|||
name: str
|
||||
connector_type: SearchSourceConnectorType
|
||||
is_indexable: bool
|
||||
last_indexed_at: Optional[datetime] = None
|
||||
config: Dict[str, Any]
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any]
|
||||
|
||||
@field_validator("config")
|
||||
@classmethod
|
||||
def validate_config_for_connector_type(
|
||||
cls, config: Dict[str, Any], values: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
cls, config: dict[str, Any], values: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
connector_type = values.data.get("connector_type")
|
||||
|
||||
if connector_type == SearchSourceConnectorType.SERPER_API:
|
||||
|
|
@ -150,11 +151,11 @@ class SearchSourceConnectorCreate(SearchSourceConnectorBase):
|
|||
|
||||
|
||||
class SearchSourceConnectorUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
connector_type: Optional[SearchSourceConnectorType] = None
|
||||
is_indexable: Optional[bool] = None
|
||||
last_indexed_at: Optional[datetime] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
name: str | None = None
|
||||
connector_type: SearchSourceConnectorType | None = None
|
||||
is_indexable: bool | None = None
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
|
||||
|
|
|
|||
|
|
@ -1,22 +1,27 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class SearchSpaceBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class SearchSpaceCreate(SearchSpaceBase):
|
||||
pass
|
||||
|
||||
|
||||
class SearchSpaceUpdate(SearchSpaceBase):
|
||||
pass
|
||||
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@@ -1,11 +1,15 @@
import uuid

from fastapi_users import schemas


class UserRead(schemas.BaseUser[uuid.UUID]):
pass


class UserCreate(schemas.BaseUserCreate):
pass


class UserUpdate(schemas.BaseUserUpdate):
pass
pass
|
|
|
|||
|
|
@@ -1 +1 @@
# Services package
# Services package
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
import asyncio
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from linkup import LinkupClient
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from tavily import TavilyClient
|
||||
|
||||
from app.agents.researcher.configuration import SearchMode
|
||||
from app.db import (
|
||||
|
|
@ -11,15 +17,10 @@ from app.db import (
|
|||
)
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
from linkup import LinkupClient
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from tavily import TavilyClient
|
||||
|
||||
|
||||
class ConnectorService:
|
||||
def __init__(self, session: AsyncSession, user_id: str = None):
|
||||
def __init__(self, session: AsyncSession, user_id: str | None = None):
|
||||
self.session = session
|
||||
self.chunk_retriever = ChucksHybridSearchRetriever(session)
|
||||
self.document_retriever = DocumentHybridSearchRetriever(session)
|
||||
|
|
@ -52,7 +53,7 @@ class ConnectorService:
|
|||
f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error initializing source_id_counter: {str(e)}")
|
||||
print(f"Error initializing source_id_counter: {e!s}")
|
||||
# Fallback to default value
|
||||
self.source_id_counter = 1
|
||||
|
||||
|
|
@ -204,7 +205,9 @@ class ConnectorService:
|
|||
|
||||
return result_object, files_chunks
|
||||
|
||||
def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]:
|
||||
def _transform_document_results(
|
||||
self, document_results: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Transform results from document_retriever.hybrid_search() to match the format
|
||||
expected by the processing code.
|
||||
|
|
@ -233,7 +236,7 @@ class ConnectorService:
|
|||
|
||||
async def get_connector_by_type(
|
||||
self, user_id: str, connector_type: SearchSourceConnectorType
|
||||
) -> Optional[SearchSourceConnector]:
|
||||
) -> SearchSourceConnector | None:
|
||||
"""
|
||||
Get a connector by type for a specific user
|
||||
|
||||
|
|
@ -350,7 +353,7 @@ class ConnectorService:
|
|||
|
||||
except Exception as e:
|
||||
# Log the error and return empty results
|
||||
print(f"Error searching with Tavily: {str(e)}")
|
||||
print(f"Error searching with Tavily: {e!s}")
|
||||
return {
|
||||
"id": 3,
|
||||
"name": "Tavily Search",
|
||||
|
|
@ -596,7 +599,7 @@ class ConnectorService:
|
|||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
async with self.counter_lock:
|
||||
for i, chunk in enumerate(extension_chunks):
|
||||
for _, chunk in enumerate(extension_chunks):
|
||||
# Extract document metadata
|
||||
document = chunk.get("document", {})
|
||||
metadata = document.get("metadata", {})
|
||||
|
|
@ -608,7 +611,7 @@ class ConnectorService:
|
|||
visit_duration = metadata.get(
|
||||
"VisitedWebPageVisitDurationInMilliseconds", ""
|
||||
)
|
||||
browsing_session_id = metadata.get("BrowsingSessionId", "")
|
||||
_browsing_session_id = metadata.get("BrowsingSessionId", "")
|
||||
|
||||
# Create a more descriptive title for extension data
|
||||
title = webpage_title
|
||||
|
|
@ -622,7 +625,7 @@ class ConnectorService:
|
|||
else visit_date
|
||||
)
|
||||
title += f" (visited: {formatted_date})"
|
||||
except:
|
||||
except Exception:
|
||||
# Fallback if date parsing fails
|
||||
title += f" (visited: {visit_date})"
|
||||
|
||||
|
|
@ -642,7 +645,7 @@ class ConnectorService:
|
|||
|
||||
if description:
|
||||
description += f" | Duration: {duration_text}"
|
||||
except:
|
||||
except Exception:
|
||||
# Fallback if duration parsing fails
|
||||
pass
|
||||
|
||||
|
|
@ -1180,7 +1183,7 @@ class ConnectorService:
|
|||
|
||||
except Exception as e:
|
||||
# Log the error and return empty results
|
||||
print(f"Error searching with Linkup: {str(e)}")
|
||||
print(f"Error searching with Linkup: {e!s}")
|
||||
return {
|
||||
"id": 10,
|
||||
"name": "Linkup Search",
|
||||
|
|
@ -1239,7 +1242,7 @@ class ConnectorService:
|
|||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
async with self.counter_lock:
|
||||
for i, chunk in enumerate(discord_chunks):
|
||||
for _, chunk in enumerate(discord_chunks):
|
||||
# Extract document metadata
|
||||
document = chunk.get("document", {})
|
||||
metadata = document.get("metadata", {})
|
||||
|
|
|
|||
|
|
@ -5,15 +5,16 @@ SSL-safe implementation with pre-downloaded models
|
|||
"""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DoclingService:
|
||||
"""Docling service for enhanced document processing with SSL fixes."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Docling service with SSL, model fixes, and GPU acceleration."""
|
||||
self.converter = None
|
||||
|
|
@ -21,30 +22,32 @@ class DoclingService:
|
|||
self._configure_ssl_environment()
|
||||
self._check_wsl2_gpu_support()
|
||||
self._initialize_docling()
|
||||
|
||||
|
||||
def _configure_ssl_environment(self):
|
||||
"""Configure SSL environment for secure model downloads."""
|
||||
try:
|
||||
# Set SSL context for downloads
|
||||
ssl._create_default_https_context = ssl._create_unverified_context
|
||||
|
||||
|
||||
# Set SSL environment variables if not already set
|
||||
if not os.environ.get('SSL_CERT_FILE'):
|
||||
if not os.environ.get("SSL_CERT_FILE"):
|
||||
try:
|
||||
import certifi
|
||||
os.environ['SSL_CERT_FILE'] = certifi.where()
|
||||
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
|
||||
|
||||
os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
logger.info("🔐 SSL environment configured for model downloads")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ SSL configuration warning: {e}")
|
||||
|
||||
|
||||
def _check_wsl2_gpu_support(self):
|
||||
"""Check and configure GPU support for WSL2 environment."""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
|
||||
|
|
@ -60,34 +63,34 @@ class DoclingService:
|
|||
except Exception as e:
|
||||
logger.warning(f"⚠️ GPU detection failed: {e}, falling back to CPU")
|
||||
self.use_gpu = False
|
||||
|
||||
|
||||
def _initialize_docling(self):
|
||||
"""Initialize Docling with version-safe configuration."""
|
||||
try:
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
||||
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
|
||||
logger.info("🔧 Initializing Docling with version-safe configuration...")
|
||||
|
||||
|
||||
# Create pipeline options with version-safe attribute checking
|
||||
pipeline_options = PdfPipelineOptions()
|
||||
|
||||
|
||||
# Disable OCR (user request)
|
||||
if hasattr(pipeline_options, 'do_ocr'):
|
||||
if hasattr(pipeline_options, "do_ocr"):
|
||||
pipeline_options.do_ocr = False
|
||||
logger.info("⚠️ OCR disabled by user request")
|
||||
else:
|
||||
logger.warning("⚠️ OCR attribute not available in this Docling version")
|
||||
|
||||
|
||||
# Enable table structure if available
|
||||
if hasattr(pipeline_options, 'do_table_structure'):
|
||||
if hasattr(pipeline_options, "do_table_structure"):
|
||||
pipeline_options.do_table_structure = True
|
||||
logger.info("✅ Table structure detection enabled")
|
||||
|
||||
|
||||
# Configure GPU acceleration for WSL2 if available
|
||||
if hasattr(pipeline_options, 'accelerator_device'):
|
||||
if hasattr(pipeline_options, "accelerator_device"):
|
||||
if self.use_gpu:
|
||||
try:
|
||||
pipeline_options.accelerator_device = "cuda"
|
||||
|
|
@ -99,164 +102,180 @@ class DoclingService:
|
|||
pipeline_options.accelerator_device = "cpu"
|
||||
logger.info("🖥️ Using CPU acceleration")
|
||||
else:
|
||||
logger.info("ℹ️ Accelerator device attribute not available in this Docling version")
|
||||
|
||||
logger.info(
|
||||
"⚠️ Accelerator device attribute not available in this Docling version"
|
||||
)
|
||||
|
||||
# Create PDF format option with backend
|
||||
pdf_format_option = PdfFormatOption(
|
||||
pipeline_options=pipeline_options,
|
||||
backend=PyPdfiumDocumentBackend
|
||||
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
|
||||
)
|
||||
|
||||
|
||||
# Initialize DocumentConverter
|
||||
self.converter = DocumentConverter(
|
||||
format_options={
|
||||
InputFormat.PDF: pdf_format_option
|
||||
}
|
||||
format_options={InputFormat.PDF: pdf_format_option}
|
||||
)
|
||||
|
||||
|
||||
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
|
||||
logger.info(f"✅ Docling initialized successfully with {acceleration_type} acceleration")
|
||||
|
||||
logger.info(
|
||||
f"✅ Docling initialized successfully with {acceleration_type} acceleration"
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"❌ Docling not installed: {e}")
|
||||
raise RuntimeError(f"Docling not available: {e}")
|
||||
raise RuntimeError(f"Docling not available: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Docling initialization failed: {e}")
|
||||
raise RuntimeError(f"Docling initialization failed: {e}")
|
||||
|
||||
raise RuntimeError(f"Docling initialization failed: {e}") from e
|
||||
|
||||
def _configure_easyocr_local_models(self):
|
||||
"""Configure EasyOCR to use pre-downloaded local models."""
|
||||
try:
|
||||
import easyocr
|
||||
import os
|
||||
|
||||
|
||||
import easyocr
|
||||
|
||||
# Set SSL environment for EasyOCR downloads
|
||||
os.environ['CURL_CA_BUNDLE'] = ''
|
||||
os.environ['REQUESTS_CA_BUNDLE'] = ''
|
||||
|
||||
os.environ["CURL_CA_BUNDLE"] = ""
|
||||
os.environ["REQUESTS_CA_BUNDLE"] = ""
|
||||
|
||||
# Try to use local models first, fallback to download if needed
|
||||
try:
|
||||
reader = easyocr.Reader(['en'],
|
||||
download_enabled=False,
|
||||
model_storage_directory="/root/.EasyOCR/model")
|
||||
reader = easyocr.Reader(
|
||||
["en"],
|
||||
download_enabled=False,
|
||||
model_storage_directory="/root/.EasyOCR/model",
|
||||
)
|
||||
logger.info("✅ EasyOCR configured for local models")
|
||||
return reader
|
||||
except:
|
||||
except Exception:
|
||||
# If local models fail, allow download with SSL bypass
|
||||
logger.info("🔄 Local models failed, attempting download with SSL bypass...")
|
||||
reader = easyocr.Reader(['en'],
|
||||
download_enabled=True,
|
||||
model_storage_directory="/root/.EasyOCR/model")
|
||||
logger.info(
|
||||
"🔄 Local models failed, attempting download with SSL bypass..."
|
||||
)
|
||||
reader = easyocr.Reader(
|
||||
["en"],
|
||||
download_enabled=True,
|
||||
model_storage_directory="/root/.EasyOCR/model",
|
||||
)
|
||||
logger.info("✅ EasyOCR configured with downloaded models")
|
||||
return reader
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ EasyOCR configuration failed: {e}")
|
||||
return None
|
||||
|
||||
async def process_document(self, file_path: str, filename: str = None) -> Dict[str, Any]:
|
||||
|
||||
async def process_document(
|
||||
self, file_path: str, filename: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Process document with Docling using pre-downloaded models."""
|
||||
|
||||
|
||||
if self.converter is None:
|
||||
raise RuntimeError("Docling converter not initialized")
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"🔄 Processing {filename} with Docling (using local models)...")
|
||||
|
||||
logger.info(
|
||||
f"🔄 Processing {filename} with Docling (using local models)..."
|
||||
)
|
||||
|
||||
# Process document with local models
|
||||
result = self.converter.convert(file_path)
|
||||
|
||||
|
||||
# Extract content using version-safe methods
|
||||
content = None
|
||||
if hasattr(result, 'document') and result.document:
|
||||
if hasattr(result, "document") and result.document:
|
||||
# Try different export methods (version compatibility)
|
||||
if hasattr(result.document, 'export_to_markdown'):
|
||||
if hasattr(result.document, "export_to_markdown"):
|
||||
content = result.document.export_to_markdown()
|
||||
logger.info("📄 Used export_to_markdown method")
|
||||
elif hasattr(result.document, 'to_markdown'):
|
||||
elif hasattr(result.document, "to_markdown"):
|
||||
content = result.document.to_markdown()
|
||||
logger.info("📄 Used to_markdown method")
|
||||
elif hasattr(result.document, 'text'):
|
||||
elif hasattr(result.document, "text"):
|
||||
content = result.document.text
|
||||
logger.info("📄 Used text property")
|
||||
elif hasattr(result.document, '__str__'):
|
||||
elif hasattr(result.document, "__str__"):
|
||||
content = str(result.document)
|
||||
logger.info("📄 Used string conversion")
|
||||
|
||||
|
||||
if content:
|
||||
logger.info(f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)")
|
||||
|
||||
logger.info(
|
||||
f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)"
|
||||
)
|
||||
|
||||
return {
|
||||
'content': content,
|
||||
'full_text': content,
|
||||
'service_used': 'docling',
|
||||
'status': 'success',
|
||||
'processing_notes': 'Processed with Docling using pre-downloaded models'
|
||||
"content": content,
|
||||
"full_text": content,
|
||||
"service_used": "docling",
|
||||
"status": "success",
|
||||
"processing_notes": "Processed with Docling using pre-downloaded models",
|
||||
}
|
||||
else:
|
||||
raise ValueError("No content could be extracted from document")
|
||||
else:
|
||||
raise ValueError("No document object returned by Docling")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Docling processing failed for {filename}: {e}")
|
||||
# Log the full error for debugging
|
||||
import traceback
|
||||
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
raise RuntimeError(f"Docling processing failed: {e}")
|
||||
|
||||
raise RuntimeError(f"Docling processing failed: {e}") from e
|
||||
|
||||
async def process_large_document_summary(
|
||||
self,
|
||||
content: str,
|
||||
llm,
|
||||
document_title: str = "Document"
|
||||
self, content: str, llm, document_title: str = "Document"
|
||||
) -> str:
|
||||
"""
|
||||
Process large documents using chunked LLM summarization.
|
||||
|
||||
|
||||
Args:
|
||||
content: The full document content
|
||||
llm: The language model to use for summarization
|
||||
document_title: Title of the document for context
|
||||
|
||||
|
||||
Returns:
|
||||
Final summary of the document
|
||||
"""
|
||||
# Large document threshold (100K characters ≈ 25K tokens)
|
||||
LARGE_DOCUMENT_THRESHOLD = 100_000
|
||||
|
||||
if len(content) <= LARGE_DOCUMENT_THRESHOLD:
|
||||
large_document_threshold = 100_000
|
||||
|
||||
if len(content) <= large_document_threshold:
|
||||
# For smaller documents, use direct processing
|
||||
logger.info(f"📄 Document size: {len(content)} chars - using direct processing")
|
||||
logger.info(
|
||||
f"📄 Document size: {len(content)} chars - using direct processing"
|
||||
)
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
||||
result = await summary_chain.ainvoke({"document": content})
|
||||
return result.content
|
||||
|
||||
logger.info(f"📚 Large document detected: {len(content)} chars - using chunked processing")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"📚 Large document detected: {len(content)} chars - using chunked processing"
|
||||
)
|
||||
|
||||
# Import chunker from config
|
||||
from app.config import config
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
# Create LLM-optimized chunks (8K tokens max for safety)
|
||||
from chonkie import RecursiveChunker, OverlapRefinery
|
||||
from chonkie import OverlapRefinery, RecursiveChunker
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
llm_chunker = RecursiveChunker(
|
||||
chunk_size=8000 # Conservative for most LLMs
|
||||
)
|
||||
|
||||
|
||||
# Apply overlap refinery for context preservation (10% overlap = 800 tokens)
|
||||
overlap_refinery = OverlapRefinery(
|
||||
context_size=0.1, # 10% overlap for context preservation
|
||||
method="suffix" # Add next chunk context to current chunk
|
||||
method="suffix", # Add next chunk context to current chunk
|
||||
)
|
||||
|
||||
|
||||
# First chunk the content, then apply overlap refinery
|
||||
initial_chunks = llm_chunker.chunk(content)
|
||||
chunks = overlap_refinery.refine(initial_chunks)
|
||||
total_chunks = len(chunks)
|
||||
|
||||
|
||||
logger.info(f"📄 Split into {total_chunks} chunks for LLM processing")
|
||||
|
||||
|
||||
# Template for chunk processing
|
||||
chunk_template = PromptTemplate(
|
||||
input_variables=["chunk", "chunk_number", "total_chunks"],
|
||||
|
|
@ -274,34 +293,38 @@ Chunk {chunk_number}/{total_chunks}:
|
|||
<document_chunk>
|
||||
{chunk}
|
||||
</document_chunk>
|
||||
</INSTRUCTIONS>"""
|
||||
</INSTRUCTIONS>""",
|
||||
)
|
||||
|
||||
|
||||
# Process each chunk individually
|
||||
chunk_summaries = []
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
try:
|
||||
logger.info(f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)")
|
||||
|
||||
logger.info(
|
||||
f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)"
|
||||
)
|
||||
|
||||
chunk_chain = chunk_template | llm
|
||||
chunk_result = await chunk_chain.ainvoke({
|
||||
"chunk": chunk.text,
|
||||
"chunk_number": i,
|
||||
"total_chunks": total_chunks
|
||||
})
|
||||
|
||||
chunk_result = await chunk_chain.ainvoke(
|
||||
{
|
||||
"chunk": chunk.text,
|
||||
"chunk_number": i,
|
||||
"total_chunks": total_chunks,
|
||||
}
|
||||
)
|
||||
|
||||
chunk_summary = chunk_result.content
|
||||
chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}")
|
||||
|
||||
|
||||
logger.info(f"✅ Completed chunk {i}/{total_chunks}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to process chunk {i}/{total_chunks}: {e}")
|
||||
chunk_summaries.append(f"=== Section {i} ===\n[Processing failed]")
|
||||
|
||||
|
||||
# Combine summaries into final document summary
|
||||
logger.info(f"🔄 Combining {len(chunk_summaries)} chunk summaries")
|
||||
|
||||
|
||||
try:
|
||||
combine_template = PromptTemplate(
|
||||
input_variables=["summaries", "document_title"],
|
||||
|
|
@ -318,22 +341,23 @@ Ensure:
|
|||
<section_summaries>
|
||||
{summaries}
|
||||
</section_summaries>
|
||||
</INSTRUCTIONS>"""
|
||||
</INSTRUCTIONS>""",
|
||||
)
|
||||
|
||||
|
||||
combined_summaries = "\n\n".join(chunk_summaries)
|
||||
combine_chain = combine_template | llm
|
||||
|
||||
final_result = await combine_chain.ainvoke({
|
||||
"summaries": combined_summaries,
|
||||
"document_title": document_title
|
||||
})
|
||||
|
||||
|
||||
final_result = await combine_chain.ainvoke(
|
||||
{"summaries": combined_summaries, "document_title": document_title}
|
||||
)
|
||||
|
||||
final_summary = final_result.content
|
||||
logger.info(f"✅ Large document processing complete: {len(final_summary)} chars summary")
|
||||
|
||||
logger.info(
|
||||
f"✅ Large document processing complete: {len(final_summary)} chars summary"
|
||||
)
|
||||
|
||||
return final_summary
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to combine summaries: {e}")
|
||||
# Fallback: return concatenated chunk summaries
|
||||
|
|
@ -341,6 +365,7 @@ Ensure:
|
|||
logger.warning("⚠️ Using fallback combined summary")
|
||||
return fallback_summary
|
||||
|
||||
|
||||
def create_docling_service() -> DoclingService:
|
||||
"""Create a Docling service instance."""
|
||||
return DoclingService()
|
||||
return DoclingService()
|
||||
|
|
|
|||
|
|
@ -1,45 +1,43 @@
|
|||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
import logging
|
||||
|
||||
from app.db import User, LLMConfig
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import LLMConfig, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMRole:
|
||||
LONG_CONTEXT = "long_context"
|
||||
FAST = "fast"
|
||||
STRATEGIC = "strategic"
|
||||
|
||||
|
||||
async def get_user_llm_instance(
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
role: str
|
||||
) -> Optional[ChatLiteLLM]:
|
||||
session: AsyncSession, user_id: str, role: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""
|
||||
Get a ChatLiteLLM instance for a specific user and role.
|
||||
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
role: LLM role ('long_context', 'fast', or 'strategic')
|
||||
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None if not found
|
||||
"""
|
||||
try:
|
||||
# Get user with their LLM preferences
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
|
||||
|
||||
if not user:
|
||||
logger.error(f"User {user_id} not found")
|
||||
return None
|
||||
|
||||
|
||||
# Get the appropriate LLM config ID based on role
|
||||
llm_config_id = None
|
||||
if role == LLMRole.LONG_CONTEXT:
|
||||
|
|
@ -51,24 +49,23 @@ async def get_user_llm_instance(
|
|||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
||||
|
||||
if not llm_config_id:
|
||||
logger.error(f"No {role} LLM configured for user {user_id}")
|
||||
return None
|
||||
|
||||
|
||||
# Get the LLM configuration
|
||||
result = await session.execute(
|
||||
select(LLMConfig).where(
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.user_id == user_id
|
||||
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
|
||||
|
||||
if not llm_config:
|
||||
logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
|
||||
return None
|
||||
|
||||
|
||||
# Build the model string for litellm
|
||||
if llm_config.custom_provider:
|
||||
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
|
||||
|
|
@ -76,7 +73,7 @@ async def get_user_llm_instance(
|
|||
# Map provider enum to litellm format
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
|
|
@ -84,37 +81,48 @@ async def get_user_llm_instance(
|
|||
"MISTRAL": "mistral",
|
||||
# Add more mappings as needed
|
||||
}
|
||||
provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
|
||||
provider_prefix = provider_map.get(
|
||||
llm_config.provider.value, llm_config.provider.value.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{llm_config.model_name}"
|
||||
|
||||
|
||||
# Create ChatLiteLLM instance
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": llm_config.api_key,
|
||||
}
|
||||
|
||||
|
||||
# Add optional parameters
|
||||
if llm_config.api_base:
|
||||
litellm_kwargs["api_base"] = llm_config.api_base
|
||||
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if llm_config.litellm_params:
|
||||
litellm_kwargs.update(llm_config.litellm_params)
|
||||
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting LLM instance for user {user_id}, role {role}: {e!s}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
||||
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession, user_id: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's long context LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
|
||||
|
||||
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
||||
|
||||
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
|
||||
"""Get user's fast LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
|
||||
|
||||
async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
||||
|
||||
async def get_user_strategic_llm(
|
||||
session: AsyncSession, user_id: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's strategic LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import datetime
|
||||
from langchain.schema import HumanMessage, SystemMessage, AIMessage
|
||||
from app.config import config
|
||||
from app.services.llm_service import get_user_strategic_llm
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from app.services.llm_service import get_user_strategic_llm
|
||||
|
||||
|
||||
class QueryService:
|
||||
|
|
@ -13,13 +14,13 @@ class QueryService:
|
|||
|
||||
@staticmethod
|
||||
async def reformulate_query_with_chat_history(
|
||||
user_query: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
chat_history_str: Optional[str] = None
|
||||
user_query: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
chat_history_str: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Reformulate the user query using the user's strategic LLM to make it more
|
||||
Reformulate the user query using the user's strategic LLM to make it more
|
||||
effective for information retrieval and research purposes.
|
||||
|
||||
Args:
|
||||
|
|
@ -38,7 +39,9 @@ class QueryService:
|
|||
# Get the user's strategic LLM instance
|
||||
llm = await get_user_strategic_llm(session, user_id)
|
||||
if not llm:
|
||||
print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.")
|
||||
print(
|
||||
f"Warning: No strategic LLM configured for user {user_id}. Using original query."
|
||||
)
|
||||
return user_query
|
||||
|
||||
# Create system message with instructions
|
||||
|
|
@ -92,14 +95,13 @@ class QueryService:
|
|||
print(f"Error reformulating query: {e}")
|
||||
return user_query
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def langchain_chat_history_to_str(chat_history: List[Any]) -> str:
|
||||
async def langchain_chat_history_to_str(chat_history: list[Any]) -> str:
|
||||
"""
|
||||
Convert a list of chat history messages to a string.
|
||||
"""
|
||||
chat_history_str = "<chat_history>\n"
|
||||
|
||||
|
||||
for chat_message in chat_history:
|
||||
if isinstance(chat_message, HumanMessage):
|
||||
chat_history_str += f"<user>{chat_message.content}</user>\n"
|
||||
|
|
@ -107,6 +109,6 @@ class QueryService:
|
|||
chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
|
||||
elif isinstance(chat_message, SystemMessage):
|
||||
chat_history_str += f"<system>{chat_message.content}</system>\n"
|
||||
|
||||
|
||||
chat_history_str += "</chat_history>"
|
||||
return chat_history_str
|
||||
|
|
|
|||
|
|
@ -1,35 +1,39 @@
|
|||
import logging
|
||||
from typing import Any, Optional

from rerankers import Document as RerankerDocument


class RerankerService:
    """
    Service for reranking documents using a configured reranker
    """

    def __init__(self, reranker_instance=None):
        """
        Initialize the reranker service

        Args:
            reranker_instance: The reranker instance to use for reranking
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(
        self, query_text: str, documents: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Rerank documents using the configured reranker

        Args:
            query_text: The query text to use for reranking
            documents: List of document dictionaries to rerank

        Returns:
            List[Dict[str, Any]]: Reranked documents
        """
        if not self.reranker_instance or not documents:
            return documents

        try:
            # Create Document objects for the rerankers library
            reranker_docs = []

@@ -38,58 +42,63 @@ class RerankerService:
                content = doc.get("content", "")
                score = doc.get("score", 0.0)
                document_info = doc.get("document", {})

                reranker_docs.append(
                    RerankerDocument(
                        text=content,
                        doc_id=chunk_id,
                        metadata={
                            "document_id": document_info.get("id", ""),
                            "document_title": document_info.get("title", ""),
                            "document_type": document_info.get("document_type", ""),
                            "rrf_score": score,
                        },
                    )
                )

            # Rerank using the configured reranker
            reranking_results = self.reranker_instance.rank(
                query=query_text, docs=reranker_docs
            )

            # Process the results from the reranker
            # Convert to serializable dictionaries
            serialized_results = []
            for result in reranking_results.results:
                # Find the original document by id
                original_doc = next(
                    (
                        doc
                        for doc in documents
                        if doc.get("chunk_id") == result.document.doc_id
                    ),
                    None,
                )
                if original_doc:
                    # Create a new document with the reranked score
                    reranked_doc = original_doc.copy()
                    reranked_doc["score"] = float(result.score)
                    reranked_doc["rank"] = result.rank
                    serialized_results.append(reranked_doc)

            return serialized_results

        except Exception as e:
            # Log the error
            logging.error(f"Error during reranking: {e!s}")
            # Fall back to original documents without reranking
            return documents

    @staticmethod
    def get_reranker_instance() -> Optional["RerankerService"]:
        """
        Get a reranker service instance from the global configuration.

        Returns:
            Optional[RerankerService]: A reranker service instance if configured, None otherwise
        """
        from app.config import config

        if hasattr(config, "reranker_instance") and config.reranker_instance:
            return RerankerService(config.reranker_instance)
        return None
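For orientation, a minimal usage sketch of the service above; the sample hybrid-search results and query are illustrative placeholders, not taken from this diff:

    reranker_service = RerankerService.get_reranker_instance()

    # Shape expected by rerank_documents: one dict per retrieved chunk.
    candidate_docs = [
        {
            "chunk_id": "chunk-1",
            "content": "FastAPI supports async SQLAlchemy sessions ...",
            "score": 0.42,  # RRF score from hybrid search
            "document": {"id": 7, "title": "Backend notes", "document_type": "CRAWLED_URL"},
        },
    ]

    if reranker_service:
        # Returns the same dicts with "score" replaced by the reranker score and a "rank" added;
        # with no reranker configured the list is returned unchanged.
        candidate_docs = reranker_service.rerank_documents("async sqlalchemy", candidate_docs)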
@@ -1,27 +1,15 @@
import json
from typing import Any


class StreamingService:
    def __init__(self):
        self.terminal_idx = 1
        self.message_annotations = [
            {"type": "TERMINAL_INFO", "content": []},
            {"type": "SOURCES", "content": []},
            {"type": "ANSWER", "content": []},
            {"type": "FURTHER_QUESTIONS", "content": []},
        ]

    # DEPRECATED: This sends the full annotation array every time (inefficient)

@@ -35,7 +23,7 @@ class StreamingService:
        Returns:
            str: The formatted annotations string
        """
        return f"8:{json.dumps(self.message_annotations)}\n"

    def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
        """

@@ -58,7 +46,7 @@ class StreamingService:
        annotation = {"type": "TERMINAL_INFO", "content": [message]}
        return f"8:[{json.dumps(annotation)}]\n"

    def format_sources_delta(self, sources: list[dict[str, Any]]) -> str:
        """
        Format sources as a delta annotation

@@ -95,7 +83,7 @@ class StreamingService:
        annotation = {"type": "ANSWER", "content": [answer_chunk]}
        return f"8:[{json.dumps(annotation)}]\n"

    def format_answer_annotation(self, answer_lines: list[str]) -> str:
        """
        Format the complete answer as a replacement annotation

@@ -113,7 +101,7 @@ class StreamingService:
        return f"8:[{json.dumps(annotation)}]\n"

    def format_further_questions_delta(
        self, further_questions: list[dict[str, Any]]
    ) -> str:
        """
        Format further questions as a delta annotation

@@ -155,14 +143,16 @@ class StreamingService:
        """
        return f"3:{json.dumps(error_message)}\n"

    def format_completion(
        self, prompt_tokens: int = 156, completion_tokens: int = 204
    ) -> str:
        """
        Format a completion message

        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            str: The formatted completion string
        """

@@ -172,7 +162,7 @@ class StreamingService:
            "usage": {
                "promptTokens": prompt_tokens,
                "completionTokens": completion_tokens,
                "totalTokens": total_tokens,
            },
        }
        return f"d:{json.dumps(completion_data)}\n"
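For orientation, a minimal sketch of how the formatter above can be driven from a streaming endpoint; the generator wiring and sample values are illustrative, not taken from this diff:

    streaming_service = StreamingService()

    async def stream_response():
        # Each helper returns one already-framed protocol line, e.g. an "8:[...]"
        # annotation delta or the final "d:{...}" completion payload.
        yield streaming_service.format_terminal_info_delta("Searching connectors...")
        yield streaming_service.format_answer_annotation(["Final answer, line by line."])
        yield streaming_service.format_completion(prompt_tokens=120, completion_tokens=80)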
@@ -1,111 +1,116 @@
import json
import logging
from datetime import datetime
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Log, LogLevel, LogStatus

logger = logging.getLogger(__name__)


class TaskLoggingService:
    """Service for logging background tasks using the database Log model"""

    def __init__(self, session: AsyncSession, search_space_id: int):
        self.session = session
        self.search_space_id = search_space_id

    async def log_task_start(
        self,
        task_name: str,
        source: str,
        message: str,
        metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Log the start of a task with IN_PROGRESS status

        Args:
            task_name: Name/identifier of the task
            source: Source service/component (e.g., 'document_processor', 'slack_indexer')
            message: Human-readable message about the task
            metadata: Additional context data

        Returns:
            Log: The created log entry
        """
        log_metadata = metadata or {}
        log_metadata.update(
            {"task_name": task_name, "started_at": datetime.utcnow().isoformat()}
        )

        log_entry = Log(
            level=LogLevel.INFO,
            status=LogStatus.IN_PROGRESS,
            message=message,
            source=source,
            log_metadata=log_metadata,
            search_space_id=self.search_space_id,
        )

        self.session.add(log_entry)
        await self.session.commit()
        await self.session.refresh(log_entry)

        logger.info(f"Started task {task_name}: {message}")
        return log_entry

    async def log_task_success(
        self,
        log_entry: Log,
        message: str,
        additional_metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Update a log entry to SUCCESS status

        Args:
            log_entry: The original log entry to update
            message: Success message
            additional_metadata: Additional metadata to merge

        Returns:
            Log: The updated log entry
        """
        # Update the existing log entry
        log_entry.status = LogStatus.SUCCESS
        log_entry.message = message

        # Merge additional metadata
        if additional_metadata:
            if log_entry.log_metadata is None:
                log_entry.log_metadata = {}
            log_entry.log_metadata.update(additional_metadata)
            log_entry.log_metadata["completed_at"] = datetime.utcnow().isoformat()

        await self.session.commit()
        await self.session.refresh(log_entry)

        task_name = (
            log_entry.log_metadata.get("task_name", "unknown")
            if log_entry.log_metadata
            else "unknown"
        )
        logger.info(f"Completed task {task_name}: {message}")
        return log_entry

    async def log_task_failure(
        self,
        log_entry: Log,
        error_message: str,
        error_details: str | None = None,
        additional_metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Update a log entry to FAILED status

        Args:
            log_entry: The original log entry to update
            error_message: Error message
            error_details: Detailed error information
            additional_metadata: Additional metadata to merge

        Returns:
            Log: The updated log entry
        """

@@ -113,77 +118,86 @@ class TaskLoggingService:
        log_entry.status = LogStatus.FAILED
        log_entry.level = LogLevel.ERROR
        log_entry.message = error_message

        # Merge additional metadata
        if log_entry.log_metadata is None:
            log_entry.log_metadata = {}

        log_entry.log_metadata.update(
            {"failed_at": datetime.utcnow().isoformat(), "error_details": error_details}
        )

        if additional_metadata:
            log_entry.log_metadata.update(additional_metadata)

        await self.session.commit()
        await self.session.refresh(log_entry)

        task_name = (
            log_entry.log_metadata.get("task_name", "unknown")
            if log_entry.log_metadata
            else "unknown"
        )
        logger.error(f"Failed task {task_name}: {error_message}")
        if error_details:
            logger.error(f"Error details: {error_details}")

        return log_entry

    async def log_task_progress(
        self,
        log_entry: Log,
        progress_message: str,
        progress_metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Update a log entry with progress information while keeping IN_PROGRESS status

        Args:
            log_entry: The log entry to update
            progress_message: Progress update message
            progress_metadata: Additional progress metadata

        Returns:
            Log: The updated log entry
        """
        log_entry.message = progress_message

        if progress_metadata:
            if log_entry.log_metadata is None:
                log_entry.log_metadata = {}
            log_entry.log_metadata.update(progress_metadata)
            log_entry.log_metadata["last_progress_update"] = (
                datetime.utcnow().isoformat()
            )

        await self.session.commit()
        await self.session.refresh(log_entry)

        task_name = (
            log_entry.log_metadata.get("task_name", "unknown")
            if log_entry.log_metadata
            else "unknown"
        )
        logger.info(f"Progress update for task {task_name}: {progress_message}")
        return log_entry

    async def log_simple_event(
        self,
        level: LogLevel,
        source: str,
        message: str,
        metadata: dict[str, Any] | None = None,
    ) -> Log:
        """
        Log a simple event (not a long-running task)

        Args:
            level: Log level
            source: Source service/component
            message: Log message
            metadata: Additional context data

        Returns:
            Log: The created log entry
        """

@@ -193,12 +207,12 @@ class TaskLoggingService:
            message=message,
            source=source,
            log_metadata=metadata or {},
            search_space_id=self.search_space_id,
        )

        self.session.add(log_entry)
        await self.session.commit()
        await self.session.refresh(log_entry)

        logger.info(f"Logged event from {source}: {message}")
        return log_entry
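For orientation, a minimal lifecycle sketch for the logging service above; the task body and metadata values are illustrative, not taken from this diff:

    task_logger = TaskLoggingService(session, search_space_id)
    log_entry = await task_logger.log_task_start(
        task_name="example_task",
        source="background_task",
        message="Starting example task",
        metadata={"user_id": "123"},
    )
    try:
        # Progress updates keep the entry IN_PROGRESS and merge metadata.
        await task_logger.log_task_progress(log_entry, "Halfway there", {"stage": "processing"})
        await task_logger.log_task_success(log_entry, "Example task finished", {"items": 42})
    except Exception as e:
        await task_logger.log_task_failure(log_entry, "Example task failed", str(e))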
@@ -1,46 +1,49 @@
import logging
from urllib.parse import parse_qs, urlparse

import aiohttp
import validators
from langchain_community.document_loaders import AsyncChromiumLoader, FireCrawlLoader
from langchain_community.document_transformers import MarkdownifyTransformer
from langchain_core.documents import Document as LangChainDocument
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from youtube_transcript_api import YouTubeTranscriptApi

from app.config import config
from app.db import Chunk, Document, DocumentType
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.schemas import ExtensionDocumentContent
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
    convert_document_to_markdown,
    generate_content_hash,
)

md = MarkdownifyTransformer()


async def add_crawled_url_document(
    session: AsyncSession, url: str, search_space_id: int, user_id: str
) -> Document | None:
    task_logger = TaskLoggingService(session, search_space_id)

    # Log task start
    log_entry = await task_logger.log_task_start(
        task_name="crawl_url_document",
        source="background_task",
        message=f"Starting URL crawling process for: {url}",
        metadata={"url": url, "user_id": str(user_id)},
    )

    try:
        # URL validation step
        await task_logger.log_task_progress(
            log_entry, f"Validating URL: {url}", {"stage": "validation"}
        )

        if not validators.url(url):
            raise ValueError(f"Url {url} is not a valid URL address")

@@ -48,7 +51,10 @@ async def add_crawled_url_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Setting up crawler for URL: {url}",
            {
                "stage": "crawler_setup",
                "firecrawl_available": bool(config.FIRECRAWL_API_KEY),
            },
        )

        if config.FIRECRAWL_API_KEY:

@@ -68,21 +74,21 @@ async def add_crawled_url_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Crawling URL content: {url}",
            {"stage": "crawling", "crawler_type": type(crawl_loader).__name__},
        )

        url_crawled = await crawl_loader.aload()

        if isinstance(crawl_loader, FireCrawlLoader):
            content_in_markdown = url_crawled[0].page_content
        elif isinstance(crawl_loader, AsyncChromiumLoader):
            content_in_markdown = md.transform_documents(url_crawled)[0].page_content

        # Format document
        await task_logger.log_task_progress(
            log_entry,
            f"Processing crawled content from: {url}",
            {"stage": "content_processing", "content_length": len(content_in_markdown)},
        )

        # Format document metadata in a more maintainable way

@@ -117,7 +123,7 @@ async def add_crawled_url_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Checking for duplicate content: {url}",
            {"stage": "duplicate_check", "content_hash": content_hash},
        )

        # Check if document with this content hash already exists

@@ -125,21 +131,26 @@ async def add_crawled_url_document(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            await task_logger.log_task_success(
                log_entry,
                f"Document already exists for URL: {url}",
                {
                    "duplicate_detected": True,
                    "existing_document_id": existing_document.id,
                },
            )
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # Get LLM for summary generation
        await task_logger.log_task_progress(
            log_entry,
            f"Preparing for summary generation: {url}",
            {"stage": "llm_setup"},
        )

        # Get user's long context LLM

@@ -151,7 +162,7 @@ async def add_crawled_url_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Generating summary for URL content: {url}",
            {"stage": "summary_generation"},
        )

        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm

@@ -165,7 +176,7 @@ async def add_crawled_url_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Processing content chunks for URL: {url}",
            {"stage": "chunk_processing"},
        )

        chunks = [

@@ -180,13 +191,13 @@ async def add_crawled_url_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Creating document in database for URL: {url}",
            {"stage": "document_creation", "chunks_count": len(chunks)},
        )

        document = Document(
            search_space_id=search_space_id,
            title=url_crawled[0].metadata["title"]
            if isinstance(crawl_loader, FireCrawlLoader)
            else url_crawled[0].metadata["source"],
            document_type=DocumentType.CRAWLED_URL,
            document_metadata=url_crawled[0].metadata,

@@ -209,8 +220,8 @@ async def add_crawled_url_document(
                "title": document.title,
                "content_hash": content_hash,
                "chunks_count": len(chunks),
                "summary_length": len(summary_content),
            },
        )

        return document

@@ -221,7 +232,7 @@ async def add_crawled_url_document(
            log_entry,
            f"Database error while processing URL: {url}",
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        raise db_error
    except Exception as e:

@@ -230,14 +241,17 @@ async def add_crawled_url_document(
            log_entry,
            f"Failed to crawl URL: {url}",
            str(e),
            {"error_type": type(e).__name__},
        )
        raise RuntimeError(f"Failed to crawl URL: {e!s}") from e


async def add_extension_received_document(
    session: AsyncSession,
    content: ExtensionDocumentContent,
    search_space_id: int,
    user_id: str,
) -> Document | None:
    """
    Process and store document content received from the SurfSense Extension.

@@ -250,7 +264,7 @@ async def add_extension_received_document(
        Document object if successful, None if failed
    """
    task_logger = TaskLoggingService(session, search_space_id)

    # Log task start
    log_entry = await task_logger.log_task_start(
        task_name="extension_document",

@@ -259,10 +273,10 @@ async def add_extension_received_document(
        metadata={
            "url": content.metadata.VisitedWebPageURL,
            "title": content.metadata.VisitedWebPageTitle,
            "user_id": str(user_id),
        },
    )

    try:
        # Format document metadata in a more maintainable way
        metadata_sections = [

@@ -301,14 +315,19 @@ async def add_extension_received_document(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            await task_logger.log_task_success(
                log_entry,
                f"Extension document already exists: {content.metadata.VisitedWebPageTitle}",
                {
                    "duplicate_detected": True,
                    "existing_document_id": existing_document.id,
                },
            )
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # Get user's long context LLM

@@ -356,8 +375,8 @@ async def add_extension_received_document(
            {
                "document_id": document.id,
                "content_hash": content_hash,
                "url": content.metadata.VisitedWebPageURL,
            },
        )

        return document

@@ -368,7 +387,7 @@ async def add_extension_received_document(
            log_entry,
            f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}",
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        raise db_error
    except Exception as e:

@@ -377,24 +396,32 @@ async def add_extension_received_document(
            log_entry,
            f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}",
            str(e),
            {"error_type": type(e).__name__},
        )
        raise RuntimeError(f"Failed to process extension document: {e!s}") from e


async def add_received_markdown_file_document(
    session: AsyncSession,
    file_name: str,
    file_in_markdown: str,
    search_space_id: int,
    user_id: str,
) -> Document | None:
    task_logger = TaskLoggingService(session, search_space_id)

    # Log task start
    log_entry = await task_logger.log_task_start(
        task_name="markdown_file_document",
        source="background_task",
        message=f"Processing markdown file: {file_name}",
        metadata={
            "filename": file_name,
            "user_id": str(user_id),
            "content_length": len(file_in_markdown),
        },
    )

    try:
        content_hash = generate_content_hash(file_in_markdown, search_space_id)

@@ -403,14 +430,19 @@ async def add_received_markdown_file_document(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            await task_logger.log_task_success(
                log_entry,
                f"Markdown file document already exists: {file_name}",
                {
                    "duplicate_detected": True,
                    "existing_document_id": existing_document.id,
                },
            )
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # Get user's long context LLM

@@ -459,8 +491,8 @@ async def add_received_markdown_file_document(
                "document_id": document.id,
                "content_hash": content_hash,
                "chunks_count": len(chunks),
                "summary_length": len(summary_content),
            },
        )

        return document

@@ -470,7 +502,7 @@ async def add_received_markdown_file_document(
            log_entry,
            f"Database error processing markdown file: {file_name}",
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        raise db_error
    except Exception as e:

@@ -479,18 +511,18 @@ async def add_received_markdown_file_document(
            log_entry,
            f"Failed to process markdown file: {file_name}",
            str(e),
            {"error_type": type(e).__name__},
        )
        raise RuntimeError(f"Failed to process file document: {e!s}") from e


async def add_received_file_document_using_unstructured(
    session: AsyncSession,
    file_name: str,
    unstructured_processed_elements: list[LangChainDocument],
    search_space_id: int,
    user_id: str,
) -> Document | None:
    try:
        file_in_markdown = await convert_document_to_markdown(
            unstructured_processed_elements

@@ -503,9 +535,11 @@ async def add_received_file_document_using_unstructured(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # TODO: Check if file_markdown exceeds token limit of embedding model

@@ -555,7 +589,7 @@ async def add_received_file_document_using_unstructured(
        raise db_error
    except Exception as e:
        await session.rollback()
        raise RuntimeError(f"Failed to process file document: {e!s}") from e


async def add_received_file_document_using_llamacloud(

@@ -564,7 +598,7 @@ async def add_received_file_document_using_llamacloud(
    llamacloud_markdown_document: str,
    search_space_id: int,
    user_id: str,
) -> Document | None:
    """
    Process and store document content parsed by LlamaCloud.

@@ -588,9 +622,11 @@ async def add_received_file_document_using_llamacloud(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # Get user's long context LLM

@@ -638,7 +674,9 @@ async def add_received_file_document_using_llamacloud(
        raise db_error
    except Exception as e:
        await session.rollback()
        raise RuntimeError(
            f"Failed to process file document using LlamaCloud: {e!s}"
        ) from e


async def add_received_file_document_using_docling(

@@ -647,7 +685,7 @@ async def add_received_file_document_using_docling(
    docling_markdown_document: str,
    search_space_id: int,
    user_id: str,
) -> Document | None:
    """
    Process and store document content parsed by Docling.

@@ -671,9 +709,11 @@ async def add_received_file_document_using_docling(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # Get user's long context LLM

@@ -683,12 +723,11 @@ async def add_received_file_document_using_docling(
        # Generate summary using chunked processing for large documents
        from app.services.docling_service import create_docling_service

        docling_service = create_docling_service()

        summary_content = await docling_service.process_large_document_summary(
            content=file_in_markdown, llm=user_llm, document_title=file_name
        )
        summary_embedding = config.embedding_model_instance.embed(summary_content)

@@ -726,7 +765,9 @@ async def add_received_file_document_using_docling(
        raise db_error
    except Exception as e:
        await session.rollback()
        raise RuntimeError(
            f"Failed to process file document using Docling: {e!s}"
        ) from e


async def add_youtube_video_document(

@@ -749,23 +790,23 @@ async def add_youtube_video_document(
        RuntimeError: If the video processing fails
    """
    task_logger = TaskLoggingService(session, search_space_id)

    # Log task start
    log_entry = await task_logger.log_task_start(
        task_name="youtube_video_document",
        source="background_task",
        message=f"Starting YouTube video processing for: {url}",
        metadata={"url": url, "user_id": str(user_id)},
    )

    try:
        # Extract video ID from URL
        await task_logger.log_task_progress(
            log_entry,
            f"Extracting video ID from URL: {url}",
            {"stage": "video_id_extraction"},
        )

        def get_youtube_video_id(url: str):
            parsed_url = urlparse(url)
            hostname = parsed_url.hostname

@@ -790,14 +831,14 @@ async def add_youtube_video_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Video ID extracted: {video_id}",
            {"stage": "video_id_extracted", "video_id": video_id},
        )

        # Get video metadata
        await task_logger.log_task_progress(
            log_entry,
            f"Fetching video metadata for: {video_id}",
            {"stage": "metadata_fetch"},
        )

        params = {

@@ -806,21 +847,27 @@ async def add_youtube_video_document(
        }
        oembed_url = "https://www.youtube.com/oembed"

        async with (
            aiohttp.ClientSession() as http_session,
            http_session.get(oembed_url, params=params) as response,
        ):
            video_data = await response.json()

        await task_logger.log_task_progress(
            log_entry,
            f"Video metadata fetched: {video_data.get('title', 'Unknown')}",
            {
                "stage": "metadata_fetched",
                "title": video_data.get("title"),
                "author": video_data.get("author_name"),
            },
        )

        # Get video transcript
        await task_logger.log_task_progress(
            log_entry,
            f"Fetching transcript for video: {video_id}",
            {"stage": "transcript_fetch"},
        )

        try:

@@ -834,25 +881,29 @@ async def add_youtube_video_document(
                timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]"
                transcript_segments.append(f"{timestamp} {text}")
            transcript_text = "\n".join(transcript_segments)

            await task_logger.log_task_progress(
                log_entry,
                f"Transcript fetched successfully: {len(captions)} segments",
                {
                    "stage": "transcript_fetched",
                    "segments_count": len(captions),
                    "transcript_length": len(transcript_text),
                },
            )
        except Exception as e:
            transcript_text = f"No captions available for this video. Error: {e!s}"
            await task_logger.log_task_progress(
                log_entry,
                f"No transcript available for video: {video_id}",
                {"stage": "transcript_unavailable", "error": str(e)},
            )

        # Format document
        await task_logger.log_task_progress(
            log_entry,
            f"Processing video content: {video_data.get('title', 'YouTube Video')}",
            {"stage": "content_processing"},
        )

        # Format document metadata in a more maintainable way

@@ -890,7 +941,7 @@ async def add_youtube_video_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Checking for duplicate video content: {video_id}",
            {"stage": "duplicate_check", "content_hash": content_hash},
        )

        # Check if document with this content hash already exists

@@ -898,21 +949,27 @@ async def add_youtube_video_document(
            select(Document).where(Document.content_hash == content_hash)
        )
        existing_document = existing_doc_result.scalars().first()

        if existing_document:
            await task_logger.log_task_success(
                log_entry,
                f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}",
                {
                    "duplicate_detected": True,
                    "existing_document_id": existing_document.id,
                    "video_id": video_id,
                },
            )
            logging.info(
                f"Document with content hash {content_hash} already exists. Skipping processing."
            )
            return existing_document

        # Get LLM for summary generation
        await task_logger.log_task_progress(
            log_entry,
            f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}",
            {"stage": "llm_setup"},
        )

        # Get user's long context LLM

@@ -924,7 +981,7 @@ async def add_youtube_video_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Generating summary for video: {video_data.get('title', 'YouTube Video')}",
            {"stage": "summary_generation"},
        )

        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm

@@ -938,7 +995,7 @@ async def add_youtube_video_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}",
            {"stage": "chunk_processing"},
        )

        chunks = [

@@ -953,7 +1010,7 @@ async def add_youtube_video_document(
        await task_logger.log_task_progress(
            log_entry,
            f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}",
            {"stage": "document_creation", "chunks_count": len(chunks)},
        )

        document = Document(

@@ -988,8 +1045,8 @@ async def add_youtube_video_document(
                "content_hash": content_hash,
                "chunks_count": len(chunks),
                "summary_length": len(summary_content),
                "has_transcript": "No captions available" not in transcript_text,
            },
        )

        return document

@@ -999,7 +1056,10 @@ async def add_youtube_video_document(
            log_entry,
            f"Database error while processing YouTube video: {url}",
            str(db_error),
            {
                "error_type": "SQLAlchemyError",
                "video_id": video_id if "video_id" in locals() else None,
            },
        )
        raise db_error
    except Exception as e:

@@ -1008,7 +1068,10 @@ async def add_youtube_video_document(
            log_entry,
            f"Failed to process YouTube video: {url}",
            str(e),
            {
                "error_type": type(e).__name__,
                "video_id": video_id if "video_id" in locals() else None,
            },
        )
        logging.error(f"Failed to process YouTube video: {e!s}")
        raise
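For orientation, a minimal sketch of how one of the background tasks above might be invoked; the session factory plumbing is an assumption for illustration, not taken from this diff:

    from sqlalchemy.ext.asyncio import async_sessionmaker  # assumed session plumbing

    async def crawl_in_background(url: str, search_space_id: int, user_id: str, session_maker):
        async with session_maker() as session:
            # Returns the stored Document, the existing one on a duplicate content hash,
            # or raises RuntimeError if crawling or processing fails.
            return await add_crawled_url_document(session, url, search_space_id, user_id)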
@@ -1,7 +1,11 @@
import asyncio
import logging
from datetime import UTC, datetime, timedelta

from slack_sdk.errors import SlackApiError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.config import config
from app.connectors.discord_connector import DiscordConnector

@@ -21,10 +25,6 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import generate_content_hash

# Set up logging
logger = logging.getLogger(__name__)

@@ -35,10 +35,10 @@ async def index_slack_messages(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:
    """
    Index Slack messages from all accessible channels.

@@ -192,7 +192,7 @@ async def index_slack_messages(
            str(e),
            {"error_type": "ChannelFetchError"},
        )
        return 0, f"Failed to get Slack channels: {e!s}"

    if not channels:
        await task_logger.log_task_success(

@@ -400,13 +400,13 @@ async def index_slack_messages(

            except SlackApiError as slack_error:
                logger.error(
                    f"Slack API error for channel {channel_name}: {slack_error!s}"
                )
                skipped_channels.append(f"{channel_name} (Slack API error)")
                documents_skipped += 1
                continue  # Skip this channel and continue with others
            except Exception as e:
                logger.error(f"Error processing channel {channel_name}: {e!s}")
                skipped_channels.append(f"{channel_name} (processing error)")
                documents_skipped += 1
                continue  # Skip this channel and continue with others

@@ -453,8 +453,8 @@ async def index_slack_messages(
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(f"Database error: {db_error!s}")
        return 0, f"Database error: {db_error!s}"
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(

@@ -463,8 +463,8 @@ async def index_slack_messages(
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index Slack messages: {e!s}")
        return 0, f"Failed to index Slack messages: {e!s}"


async def index_notion_pages(

@@ -472,10 +472,10 @@ async def index_notion_pages(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:
    """
    Index Notion pages from all accessible pages.

@@ -611,8 +611,8 @@ async def index_notion_pages(
            str(e),
            {"error_type": "PageFetchError"},
        )
        logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True)
        return 0, f"Failed to get Notion pages: {e!s}"

    if not pages:
        await task_logger.log_task_success(

@@ -799,7 +799,7 @@ async def index_notion_pages(

            except Exception as e:
                logger.error(
                    f"Error processing Notion page {page.get('title', 'Unknown')}: {e!s}",
                    exc_info=True,
                )
                skipped_pages.append(

@@ -852,9 +852,9 @@ async def index_notion_pages(
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(
            f"Database error during Notion indexing: {db_error!s}", exc_info=True
        )
        return 0, f"Database error: {db_error!s}"
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(

@@ -863,8 +863,8 @@ async def index_notion_pages(
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
        return 0, f"Failed to index Notion pages: {e!s}"


async def index_github_repos(

@@ -872,10 +872,10 @@ async def index_github_repos(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:
    """
    Index code and documentation files from accessible GitHub repositories.

@@ -978,7 +978,7 @@ async def index_github_repos(
            str(e),
            {"error_type": "ClientInitializationError"},
        )
        return 0, f"Failed to initialize GitHub client: {e!s}"

    # 4. Validate selected repositories
    # For simplicity, we'll proceed with the list provided.

@@ -1097,7 +1097,7 @@ async def index_github_repos(
                    "url": file_url,
                    "sha": file_sha,
                    "type": file_type,
                    "indexed_at": datetime.now(UTC).isoformat(),
                }

                # Create new document

@@ -1175,10 +1175,10 @@ async def index_linear_issues(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:
    """
    Index Linear issues and comments.

@@ -1339,8 +1339,8 @@ async def index_linear_issues(
            logger.info(f"Retrieved {len(issues)} issues from Linear API")

        except Exception as e:
            logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
            return 0, f"Failed to get Linear issues: {e!s}"

        if not issues:
            logger.info("No Linear issues found for the specified date range")

@@ -1481,7 +1481,7 @@ async def index_linear_issues(

            except Exception as e:
                logger.error(
                    f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}",
                    exc_info=True,
                )
                skipped_issues.append(

@@ -1528,8 +1528,8 @@ async def index_linear_issues(
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(f"Database error: {db_error!s}", exc_info=True)
        return 0, f"Database error: {db_error!s}"
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(

@@ -1538,8 +1538,8 @@ async def index_linear_issues(
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
        return 0, f"Failed to index Linear issues: {e!s}"


async def index_discord_messages(

@@ -1547,10 +1547,10 @@ async def index_discord_messages(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:
    """
    Index Discord messages from all accessible channels.

@@ -1632,13 +1632,11 @@ async def index_discord_messages(
        # Calculate date range
        if start_date is None or end_date is None:
            # Fall back to calculating dates based on last_indexed_at
            calculated_end_date = datetime.now(UTC)

            # Use last_indexed_at as start date if available, otherwise use 365 days ago
            if connector.last_indexed_at:
                calculated_start_date = connector.last_indexed_at.replace(tzinfo=UTC)
                logger.info(
                    f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
                )

@@ -1655,7 +1653,7 @@ async def index_discord_messages(
                # Convert YYYY-MM-DD to ISO format
                start_date_iso = (
                    datetime.strptime(start_date, "%Y-%m-%d")
                    .replace(tzinfo=UTC)
                    .isoformat()
                )

@@ -1665,20 +1663,18 @@ async def index_discord_messages(
                # Convert YYYY-MM-DD to ISO format
                end_date_iso = (
                    datetime.strptime(end_date, "%Y-%m-%d")
                    .replace(tzinfo=UTC)
                    .isoformat()
                )
        else:
            # Convert provided dates to ISO format for Discord API
            start_date_iso = (
                datetime.strptime(start_date, "%Y-%m-%d")
                .replace(tzinfo=UTC)
                .isoformat()
            )
            end_date_iso = (
                datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat()
            )

        logger.info(

@@ -1710,9 +1706,9 @@ async def index_discord_messages(
            str(e),
            {"error_type": "GuildFetchError"},
        )
        logger.error(f"Failed to get Discord guilds: {e!s}", exc_info=True)
        await discord_client.close_bot()
        return 0, f"Failed to get Discord guilds: {e!s}"
    if not guilds:
        await task_logger.log_task_success(
            log_entry,

@@ -1754,7 +1750,7 @@ async def index_discord_messages(
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to get messages for channel {channel_name}: {e!s}"
                    )
                    skipped_channels.append(
                        f"{guild_name}#{channel_name} (fetch error)"

@@ -1886,7 +1882,9 @@ async def index_discord_messages(

                    chunks = [
                        Chunk(content=raw_chunk.text, embedding=embedding)
                        for raw_chunk, embedding in zip(
                            raw_chunks, chunk_embeddings, strict=False
                        )
                    ]

                    # Create and store new document

@@ -1902,7 +1900,7 @@ async def index_discord_messages(
                            "message_count": len(formatted_messages),
                            "start_date": start_date_iso,
                            "end_date": end_date_iso,
                            "indexed_at": datetime.now(UTC).strftime(
                                "%Y-%m-%d %H:%M:%S"
                            ),
                        },

@@ -1920,14 +1918,14 @@ async def index_discord_messages(

            except Exception as e:
                logger.error(
                    f"Error processing guild {guild_name}: {e!s}", exc_info=True
                )
                skipped_channels.append(f"{guild_name} (processing error)")
                documents_skipped += 1
                continue

        if update_last_indexed and documents_indexed > 0:
            connector.last_indexed_at = datetime.now(UTC)
            logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")

        await session.commit()

@@ -1968,9 +1966,9 @@ async def index_discord_messages(
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(
            f"Database error during Discord indexing: {db_error!s}", exc_info=True
        )
        return 0, f"Database error: {db_error!s}"
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(

@@ -1979,8 +1977,8 @@ async def index_discord_messages(
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index Discord messages: {e!s}", exc_info=True)
        return 0, f"Failed to index Discord messages: {e!s}"


async def index_jira_issues(

@@ -1988,10 +1986,10 @@ async def index_jira_issues(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str | None = None,
    end_date: str | None = None,
    update_last_indexed: bool = True,
) -> tuple[int, str | None]:
    """
    Index Jira issues and comments.

@@ -2161,8 +2159,8 @@ async def index_jira_issues(
            logger.info(f"Retrieved {len(issues)} issues from Jira API")

        except Exception as e:
            logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
            return 0, f"Error fetching Jira issues: {e!s}"

        # Process and index each issue
        documents_indexed = 0

@@ -2272,7 +2270,7 @@ async def index_jira_issues(

            except Exception as e:
                logger.error(
                    f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}",
                    exc_info=True,
                )
                skipped_issues.append(

@@ -2319,8 +2317,8 @@ async def index_jira_issues(
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(f"Database error: {db_error!s}", exc_info=True)
        return 0, f"Database error: {db_error!s}"
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(

@@ -2329,5 +2327,5 @@ async def index_jira_issues(
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
        return 0, f"Failed to index JIRA issues: {e!s}"
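For orientation, a minimal sketch of the shared return convention of the indexers above; the calling code and date values are illustrative placeholders, not taken from this diff:

    documents_indexed, error = await index_slack_messages(
        session,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        start_date="2024-01-01",
        end_date="2024-01-31",
    )
    if error:
        # Indexers report failures as (0, message) instead of raising to the caller.
        logger.error(f"Slack indexing failed: {error}")
    else:
        logger.info(f"Indexed {documents_indexed} Slack documents")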
@@ -1,33 +1,29 @@
+from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State
from app.db import Chat, Podcast
from app.services.task_logging_service import TaskLoggingService
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.exc import SQLAlchemyError


async def generate_document_podcast(
-session: AsyncSession,
-document_id: int,
-search_space_id: int,
-user_id: int
+session: AsyncSession, document_id: int, search_space_id: int, user_id: int
):
# TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
pass


async def generate_chat_podcast(
session: AsyncSession,
chat_id: int,
search_space_id: int,
podcast_title: str,
-user_id: int
+user_id: int,
):
task_logger = TaskLoggingService(session, search_space_id)

# Log task start
log_entry = await task_logger.log_task_start(
task_name="generate_chat_podcast",
@@ -37,44 +33,43 @@ async def generate_chat_podcast(
"chat_id": chat_id,
"search_space_id": search_space_id,
"podcast_title": podcast_title,
-"user_id": str(user_id)
-}
+"user_id": str(user_id),
+},
)

try:
# Fetch the chat with the specified ID
await task_logger.log_task_progress(
-log_entry,
-f"Fetching chat {chat_id} from database",
-{"stage": "fetch_chat"}
+log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
)

query = select(Chat).filter(
-Chat.id == chat_id,
-Chat.search_space_id == search_space_id
+Chat.id == chat_id, Chat.search_space_id == search_space_id
)

result = await session.execute(query)
chat = result.scalars().first()

if not chat:
await task_logger.log_task_failure(
log_entry,
f"Chat with id {chat_id} not found in search space {search_space_id}",
"Chat not found",
-{"error_type": "ChatNotFound"}
+{"error_type": "ChatNotFound"},
)
-raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}")
+raise ValueError(
+f"Chat with id {chat_id} not found in search space {search_space_id}"
+)

# Create chat history structure
await task_logger.log_task_progress(
log_entry,
f"Processing chat history for chat {chat_id}",
-{"stage": "process_chat_history", "message_count": len(chat.messages)}
+{"stage": "process_chat_history", "message_count": len(chat.messages)},
)

chat_history_str = "<chat_history>"

processed_messages = 0
for message in chat.messages:
if message["role"] == "user":
@@ -89,18 +84,24 @@ async def generate_chat_podcast(
# If content is a list, join it into a single string
if isinstance(answer_text, list):
answer_text = "\n".join(answer_text)
-chat_history_str += f"<assistant_message>{answer_text}</assistant_message>"
+chat_history_str += (
+f"<assistant_message>{answer_text}</assistant_message>"
+)
processed_messages += 1

chat_history_str += "</chat_history>"

# Pass it to the SurfSense Podcaster
await task_logger.log_task_progress(
log_entry,
f"Initializing podcast generation for chat {chat_id}",
-{"stage": "initialize_podcast_generation", "processed_messages": processed_messages, "content_length": len(chat_history_str)}
+{
+"stage": "initialize_podcast_generation",
+"processed_messages": processed_messages,
+"content_length": len(chat_history_str),
+},
)

config = {
"configurable": {
"podcast_title": "SurfSense",
@@ -108,53 +109,55 @@ async def generate_chat_podcast(
}
}
# Initialize state with database session and streaming service
-initial_state = State(
-source_content=chat_history_str,
-db_session=session
-)
+initial_state = State(source_content=chat_history_str, db_session=session)

# Run the graph directly
await task_logger.log_task_progress(
log_entry,
f"Running podcast generation graph for chat {chat_id}",
-{"stage": "run_podcast_graph"}
+{"stage": "run_podcast_graph"},
)

result = await podcaster_graph.ainvoke(initial_state, config=config)

# Convert podcast transcript entries to serializable format
await task_logger.log_task_progress(
log_entry,
f"Processing podcast transcript for chat {chat_id}",
-{"stage": "process_transcript", "transcript_entries": len(result["podcast_transcript"])}
+{
+"stage": "process_transcript",
+"transcript_entries": len(result["podcast_transcript"]),
+},
)

serializable_transcript = []
for entry in result["podcast_transcript"]:
-serializable_transcript.append({
-"speaker_id": entry.speaker_id,
-"dialog": entry.dialog
-})
+serializable_transcript.append(
+{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
+)

# Create a new podcast entry
await task_logger.log_task_progress(
log_entry,
f"Creating podcast database entry for chat {chat_id}",
-{"stage": "create_podcast_entry", "file_location": result.get("final_podcast_file_path")}
+{
+"stage": "create_podcast_entry",
+"file_location": result.get("final_podcast_file_path"),
+},
)

podcast = Podcast(
title=f"{podcast_title}",
podcast_transcript=serializable_transcript,
file_location=result["final_podcast_file_path"],
-search_space_id=search_space_id
+search_space_id=search_space_id,
)

# Add to session and commit
session.add(podcast)
await session.commit()
await session.refresh(podcast)

# Log success
await task_logger.log_task_success(
log_entry,
@@ -165,10 +168,10 @@ async def generate_chat_podcast(
"transcript_entries": len(serializable_transcript),
"file_location": result.get("final_podcast_file_path"),
"processed_messages": processed_messages,
-"content_length": len(chat_history_str)
-}
+"content_length": len(chat_history_str),
+},
)

return podcast

except ValueError as ve:
@@ -178,7 +181,7 @@ async def generate_chat_podcast(
log_entry,
f"Value error during podcast generation for chat {chat_id}",
str(ve),
-{"error_type": "ValueError"}
+{"error_type": "ValueError"},
)
raise ve
except SQLAlchemyError as db_error:
@@ -187,7 +190,7 @@ async def generate_chat_podcast(
log_entry,
f"Database error during podcast generation for chat {chat_id}",
str(db_error),
-{"error_type": "SQLAlchemyError"}
+{"error_type": "SQLAlchemyError"},
)
raise db_error
except Exception as e:
@@ -196,7 +199,8 @@ async def generate_chat_podcast(
log_entry,
f"Unexpected error during podcast generation for chat {chat_id}",
str(e),
-{"error_type": type(e).__name__}
+{"error_type": type(e).__name__},
)
-raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {str(e)}")
+raise RuntimeError(
+f"Failed to generate podcast for chat {chat_id}: {e!s}"
+) from e

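Beyond the formatting, the last hunk above changes the re-raise to `raise RuntimeError(...) from e`. Explicit chaining records the original exception as `__cause__`, so the underlying traceback is kept rather than replaced by the new error. A small standalone sketch of the behavior (not code from this commit):

def risky() -> None:
    raise ValueError("original failure")


try:
    try:
        risky()
    except ValueError as e:
        # Explicit chaining keeps the ValueError attached as __cause__.
        raise RuntimeError("podcast generation failed") from e
except RuntimeError as err:
    assert isinstance(err.__cause__, ValueError)
    print(f"{err} (caused by: {err.__cause__!s})")
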
@@ -1,28 +1,29 @@
-from typing import Any, AsyncGenerator, List, Union
+from collections.abc import AsyncGenerator
+from typing import Any
from uuid import UUID

-from app.agents.researcher.graph import graph as researcher_graph
-from app.agents.researcher.state import State
-from app.services.streaming_service import StreamingService
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.researcher.configuration import SearchMode
+from app.agents.researcher.graph import graph as researcher_graph
+from app.agents.researcher.state import State
+from app.services.streaming_service import StreamingService


async def stream_connector_search_results(
-user_query: str,
-user_id: Union[str, UUID],
-search_space_id: int,
-session: AsyncSession,
-research_mode: str,
-selected_connectors: List[str],
-langchain_chat_history: List[Any],
+user_query: str,
+user_id: str | UUID,
+search_space_id: int,
+session: AsyncSession,
+research_mode: str,
+selected_connectors: list[str],
+langchain_chat_history: list[Any],
search_mode_str: str,
-document_ids_to_add_in_context: List[int]
+document_ids_to_add_in_context: list[int],
) -> AsyncGenerator[str, None]:
"""
Stream connector search results to the client

Args:
user_query: The user's query
user_id: The user's ID (can be UUID object or string)
@@ -30,61 +31,60 @@ async def stream_connector_search_results(
session: The database session
research_mode: The research mode
selected_connectors: List of selected connectors

Yields:
str: Formatted response strings
"""
streaming_service = StreamingService()

if research_mode == "REPORT_GENERAL":
-NUM_SECTIONS = 1
+num_sections = 1
elif research_mode == "REPORT_DEEP":
-NUM_SECTIONS = 3
+num_sections = 3
elif research_mode == "REPORT_DEEPER":
-NUM_SECTIONS = 6
+num_sections = 6
else:
# Default fallback
-NUM_SECTIONS = 1
+num_sections = 1

# Convert UUID to string if needed
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id

if search_mode_str == "CHUNKS":
search_mode = SearchMode.CHUNKS
elif search_mode_str == "DOCUMENTS":
search_mode = SearchMode.DOCUMENTS

# Sample configuration
config = {
"configurable": {
"user_query": user_query,
-"num_sections": NUM_SECTIONS,
+"num_sections": num_sections,
"connectors_to_search": selected_connectors,
"user_id": user_id_str,
"search_space_id": search_space_id,
"search_mode": search_mode,
"research_mode": research_mode,
-"document_ids_to_add_in_context": document_ids_to_add_in_context
+"document_ids_to_add_in_context": document_ids_to_add_in_context,
}
}
# Initialize state with database session and streaming service
initial_state = State(
db_session=session,
streaming_service=streaming_service,
-chat_history=langchain_chat_history
+chat_history=langchain_chat_history,
)

# Run the graph directly
print("\nRunning the complete researcher workflow...")

# Use streaming with config parameter
async for chunk in researcher_graph.astream(
initial_state,
config=config,
stream_mode="custom",
):
-if isinstance(chunk, dict):
-if "yield_value" in chunk:
-yield chunk["yield_value"]
+if isinstance(chunk, dict) and "yield_value" in chunk:
+yield chunk["yield_value"]

yield streaming_service.format_completion()

@@ -1,8 +1,7 @@
-from typing import Optional
import uuid

from fastapi import Depends, Request, Response
-from fastapi.responses import RedirectResponse
+from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
from fastapi_users.authentication import (
AuthenticationBackend,
@@ -10,21 +9,23 @@ from fastapi_users.authentication import (
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
-from fastapi.responses import JSONResponse
from fastapi_users.schemas import model_dump
+from pydantic import BaseModel

from app.config import config
from app.db import User, get_user_db
-from pydantic import BaseModel


class BearerResponse(BaseModel):
access_token: str
token_type: str


SECRET = config.SECRET_KEY

if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2

google_oauth_client = GoogleOAuth2(
config.GOOGLE_OAUTH_CLIENT_ID,
config.GOOGLE_OAUTH_CLIENT_SECRET,
@@ -35,27 +36,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = SECRET
verification_token_secret = SECRET

-async def on_after_register(self, user: User, request: Optional[Request] = None):
+async def on_after_register(self, user: User, request: Request | None = None):
print(f"User {user.id} has registered.")

async def on_after_forgot_password(
-self, user: User, token: str, request: Optional[Request] = None
+self, user: User, token: str, request: Request | None = None
):
print(f"User {user.id} has forgot their password. Reset token: {token}")

async def on_after_request_verify(
-self, user: User, token: str, request: Optional[Request] = None
+self, user: User, token: str, request: Request | None = None
):
-print(
-f"Verification requested for user {user.id}. Verification token: {token}")
+print(f"Verification requested for user {user.id}. Verification token: {token}")


async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db)


def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
-return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24)
+return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)


# # COOKIE AUTH | Uncomment if you want to use cookie auth.
@@ -77,6 +77,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# get_strategy=get_jwt_strategy,
# )


# BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response:
@@ -87,6 +88,7 @@ class CustomBearerTransport(BearerTransport):
else:
return JSONResponse(model_dump(bearer_response))


bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")


@@ -98,4 +100,4 @@ auth_backend = AuthenticationBackend(

fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])

-current_active_user = fastapi_users.current_user(active=True)
+current_active_user = fastapi_users.current_user(active=True)

@@ -1,12 +1,19 @@
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import User


# Helper function to check user ownership
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
-item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
+item = await session.execute(
+select(model).filter(model.id == item_id, model.user_id == user.id)
+)
item = item.scalars().first()
if not item:
-raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
-return item
+raise HTTPException(
+status_code=404,
+detail="Item not found or you don't have permission to access it",
+)
+return item

@@ -32,7 +32,7 @@ async def convert_element_to_markdown(element) -> str:
"Footer": lambda x: f"*{x}*\n\n",
"CodeSnippet": lambda x: f"```\n{x}\n```",
"PageNumber": lambda x: f"*Page {x}*\n\n",
-"UncategorizedText": lambda x: f"{x}\n\n"
+"UncategorizedText": lambda x: f"{x}\n\n",
}

converter = markdown_mapping.get(element_category, lambda x: x)
@@ -74,7 +74,7 @@ def convert_chunks_to_langchain_documents(chunks):
except ImportError:
raise ImportError(
"LangChain is not installed. Please install it with `pip install langchain langchain-core`"
-)
+) from None

langchain_docs = []

@@ -92,17 +92,20 @@ def convert_chunks_to_langchain_documents(chunks):
# Add document information to metadata
if "document" in chunk:
doc = chunk["document"]
-metadata.update({
-"document_id": doc.get("id"),
-"document_title": doc.get("title"),
-"document_type": doc.get("document_type"),
-})
+metadata.update(
+{
+"document_id": doc.get("id"),
+"document_title": doc.get("title"),
+"document_type": doc.get("document_type"),
+}
+)

# Add document metadata if available
if "metadata" in doc:
# Prefix document metadata keys to avoid conflicts
-doc_metadata = {f"doc_meta_{k}": v for k,
-v in doc.get("metadata", {}).items()}
+doc_metadata = {
+f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()
+}
metadata.update(doc_metadata)

# Add source URL if available in metadata
@@ -131,10 +134,7 @@ def convert_chunks_to_langchain_documents(chunks):
"""

# Create LangChain Document
-langchain_doc = LangChainDocument(
-page_content=new_content,
-metadata=metadata
-)
+langchain_doc = LangChainDocument(page_content=new_content, metadata=metadata)

langchain_docs.append(langchain_doc)

@@ -144,4 +144,4 @@ def convert_chunks_to_langchain_documents(chunks):
def generate_content_hash(content: str, search_space_id: int) -> str:
"""Generate SHA-256 hash for the given content combined with search space ID."""
combined_data = f"{search_space_id}:{content}"
-return hashlib.sha256(combined_data.encode('utf-8')).hexdigest()
+return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()

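The last hunk above only switches the quote style, but it shows the whole helper: the search-space id is mixed into the digest, so identical content indexed into different search spaces produces different hashes. A minimal standalone sketch of how such a hash could back a de-duplication check (hypothetical usage, not code from this commit):

import hashlib


def generate_content_hash(content: str, search_space_id: int) -> str:
    """Mirror of the helper above: SHA-256 over the content scoped to a search space."""
    combined_data = f"{search_space_id}:{content}"
    return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()


# Skip re-indexing when an identical document already exists in the same space.
seen_hashes = {generate_content_hash("hello world", search_space_id=1)}
candidate = generate_content_hash("hello world", search_space_id=1)
print("duplicate" if candidate in seen_hashes else "new document")

# The same content in a different search space hashes differently.
print(generate_content_hash("hello world", 1) == generate_content_hash("hello world", 2))
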
@@ -1,7 +1,9 @@
-import uvicorn
import argparse
import logging

+import uvicorn
from dotenv import load_dotenv

from app.config.uvicorn import load_uvicorn_config

logging.basicConfig(

@@ -36,3 +36,97 @@ dependencies = [
"validators>=0.34.0",
"youtube-transcript-api>=1.0.3",
]
+
+[dependency-groups]
+dev = [
+"ruff>=0.12.5",
+]
+
+[tool.ruff]
+# Exclude a variety of commonly ignored directories.
+exclude = [
+".bzr",
+".direnv",
+".eggs",
+".git",
+".git-rewrite",
+".hg",
+".ipynb_checkpoints",
+".mypy_cache",
+".nox",
+".pants.d",
+".pyenv",
+".pytest_cache",
+".pytype",
+".ruff_cache",
+".svn",
+".tox",
+".venv",
+".vscode",
+"__pypackages__",
+"_build",
+"buck-out",
+"build",
+"dist",
+"node_modules",
+"site-packages",
+"venv",
+]
+
+line-length = 88
+indent-width = 4
+
+# Python 3.12
+target-version = "py312"
+
+[tool.ruff.lint]
+select = [
+"E4", # pycodestyle errors
+"E7", # pycodestyle errors
+"E9", # pycodestyle errors
+"F", # Pyflakes
+"I", # isort
+"N", # pep8-naming
+"UP", # pyupgrade
+"B", # flake8-bugbear
+"C4", # flake8-comprehensions
+"T20", # flake8-print
+"SIM", # flake8-simplify
+"RUF", # Ruff-specific rules
+]
+
+ignore = [
+"E501", # Line too long (handled by formatter)
+"B008", # Do not perform function calls in argument defaults
+"T201", # Print found (allow print statements)
+"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar`
+]
+
+extend-select = ["I"]
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[tool.ruff.format]
+# Use double quotes for strings.
+quote-style = "double"
+
+# Indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Automatically detect the appropriate line ending.
+line-ending = "auto"
+
+
+[tool.ruff.lint.isort]
+# Group imports by type
+known-first-party = ["app"]
+force-single-line = false
+combine-as-imports = true

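The new [tool.ruff] tables above are the configuration that ruff-based hooks in this repository would pick up. Assuming the dev dependency group is installed, the same checks can be reproduced locally; a minimal sketch driving the two ruff commands from Python (illustrative only, the project normally runs ruff through pre-commit and CI):

import subprocess

# "ruff check" applies the [tool.ruff.lint] rule selection; "ruff format --check"
# verifies formatting against [tool.ruff.format] without rewriting any files.
for cmd in (["ruff", "check", "."], ["ruff", "format", "--check", "."]):
    completed = subprocess.run(cmd, capture_output=True, text=True)
    status = "ok" if completed.returncode == 0 else "issues found"
    print(" ".join(cmd), "->", status)
    if completed.stdout:
        print(completed.stdout)
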
surfsense_backend/uv.lock (generated): 4507 changed lines; file diff suppressed because it is too large.