mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 07:42:39 +02:00
commit
617a7a34b5
92 changed files with 6274 additions and 5163 deletions
224
.github/workflows/code-quality.yml
vendored
Normal file
224
.github/workflows/code-quality.yml
vendored
Normal file
|
|
@ -0,0 +1,224 @@
|
||||||
|
name: Code Quality Checks
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches: [main, dev]
|
||||||
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
file-quality:
|
||||||
|
name: File Quality Checks
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Fetch base branch
|
||||||
|
run: |
|
||||||
|
# Ensure we have the base branch reference for comparison
|
||||||
|
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install pre-commit
|
||||||
|
run: pip install pre-commit
|
||||||
|
|
||||||
|
- name: Cache pre-commit hooks
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pre-commit
|
||||||
|
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
restore-keys: |
|
||||||
|
pre-commit-
|
||||||
|
|
||||||
|
- name: Install hook environments (cache)
|
||||||
|
run: pre-commit install-hooks
|
||||||
|
|
||||||
|
- name: Run file quality checks on changed files
|
||||||
|
run: |
|
||||||
|
# Get list of changed files and run specific hooks on them
|
||||||
|
if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
|
||||||
|
BASE_REF="${{ github.base_ref }}"
|
||||||
|
elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
|
||||||
|
BASE_REF="origin/${{ github.base_ref }}"
|
||||||
|
else
|
||||||
|
echo "Base branch reference not found, running file quality hooks on all files"
|
||||||
|
pre-commit run --all-files check-yaml check-json check-toml check-merge-conflict check-added-large-files debug-statements check-case-conflict
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Running file quality hooks on changed files against $BASE_REF"
|
||||||
|
|
||||||
|
# Run each hook individually on changed files
|
||||||
|
SKIP=detect-secrets,bandit,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
|
||||||
|
pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
|
||||||
|
|
||||||
|
# Exit with the same code as pre-commit
|
||||||
|
exit ${exit_code:-0}
|
||||||
|
|
||||||
|
security-scan:
|
||||||
|
name: Security Scan
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Fetch base branch
|
||||||
|
run: |
|
||||||
|
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install pre-commit
|
||||||
|
run: pip install pre-commit
|
||||||
|
|
||||||
|
- name: Cache pre-commit hooks
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pre-commit
|
||||||
|
key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
restore-keys: |
|
||||||
|
pre-commit-security-
|
||||||
|
|
||||||
|
- name: Install hook environments (cache)
|
||||||
|
run: pre-commit install-hooks
|
||||||
|
|
||||||
|
- name: Run security scans on changed files
|
||||||
|
run: |
|
||||||
|
# Get base ref for comparison
|
||||||
|
if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
|
||||||
|
BASE_REF="${{ github.base_ref }}"
|
||||||
|
elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
|
||||||
|
BASE_REF="origin/${{ github.base_ref }}"
|
||||||
|
else
|
||||||
|
echo "Base branch reference not found, running security scans on all files"
|
||||||
|
echo "⚠️ This may take longer than normal"
|
||||||
|
pre-commit run --all-files detect-secrets bandit
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Running security scans on changed files against $BASE_REF"
|
||||||
|
|
||||||
|
# Run only security hooks on changed files
|
||||||
|
SKIP=check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
|
||||||
|
pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
|
||||||
|
|
||||||
|
# Exit with the same code as pre-commit
|
||||||
|
exit ${exit_code:-0}
|
||||||
|
|
||||||
|
python-backend:
|
||||||
|
name: Python Backend Quality
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install UV
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
|
||||||
|
- name: Check if backend files changed
|
||||||
|
id: backend-changes
|
||||||
|
uses: dorny/paths-filter@v3
|
||||||
|
with:
|
||||||
|
filters: |
|
||||||
|
backend:
|
||||||
|
- 'surfsense_backend/**'
|
||||||
|
|
||||||
|
- name: Cache dependencies
|
||||||
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cache/uv
|
||||||
|
surfsense_backend/.venv
|
||||||
|
key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
|
working-directory: surfsense_backend
|
||||||
|
run: uv sync
|
||||||
|
|
||||||
|
- name: Install pre-commit for backend checks
|
||||||
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
|
run: pip install pre-commit
|
||||||
|
|
||||||
|
- name: Cache pre-commit hooks
|
||||||
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pre-commit
|
||||||
|
key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||||
|
restore-keys: |
|
||||||
|
pre-commit-backend-
|
||||||
|
|
||||||
|
- name: Install hook environments (cache)
|
||||||
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
|
run: pre-commit install-hooks
|
||||||
|
|
||||||
|
- name: Run Python backend quality checks
|
||||||
|
if: steps.backend-changes.outputs.backend == 'true'
|
||||||
|
run: |
|
||||||
|
# Get base ref for comparison
|
||||||
|
if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
|
||||||
|
BASE_REF="${{ github.base_ref }}"
|
||||||
|
elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
|
||||||
|
BASE_REF="origin/${{ github.base_ref }}"
|
||||||
|
else
|
||||||
|
echo "Base branch reference not found, running Python backend checks on all files"
|
||||||
|
pre-commit run --all-files ruff ruff-format
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Running Python backend checks on changed files against $BASE_REF"
|
||||||
|
|
||||||
|
# Run only ruff hooks on changed Python files
|
||||||
|
SKIP=detect-secrets,bandit,check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
|
||||||
|
pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
|
||||||
|
|
||||||
|
# Exit with the same code as pre-commit
|
||||||
|
exit ${exit_code:-0}
|
||||||
|
|
||||||
|
quality-gate:
|
||||||
|
name: Quality Gate
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [file-quality, security-scan, python-backend]
|
||||||
|
if: always()
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check all jobs status
|
||||||
|
run: |
|
||||||
|
if [[ "${{ needs.file-quality.result }}" == "failure" ||
|
||||||
|
"${{ needs.security-scan.result }}" == "failure" ||
|
||||||
|
"${{ needs.python-backend.result }}" == "failure" ]]; then
|
||||||
|
echo "❌ Code quality checks failed"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "✅ All code quality checks passed"
|
||||||
|
fi
|
||||||
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
|
|
@ -3,7 +3,7 @@ name: pre-commit
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main, dev]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
|
|
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,3 +1,5 @@
|
||||||
.flashrank_cache*
|
.flashrank_cache*
|
||||||
podcasts/
|
podcasts/
|
||||||
.env
|
.env
|
||||||
|
|
||||||
|
.ruff_cache/
|
||||||
|
|
@ -6,8 +6,6 @@ repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v5.0.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
|
||||||
exclude: '\.md$'
|
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
args: [--multi, --unsafe]
|
args: [--multi, --unsafe]
|
||||||
- id: check-json
|
- id: check-json
|
||||||
|
|
@ -31,52 +29,36 @@ repos:
|
||||||
.*\.env\.template|
|
.*\.env\.template|
|
||||||
.*/tests/.*|
|
.*/tests/.*|
|
||||||
.*test.*\.py|
|
.*test.*\.py|
|
||||||
|
test_.*\.py|
|
||||||
.github/workflows/.*\.yml|
|
.github/workflows/.*\.yml|
|
||||||
.github/workflows/.*\.yaml|
|
.github/workflows/.*\.yaml|
|
||||||
.*pnpm-lock\.yaml|
|
.*pnpm-lock\.yaml|
|
||||||
.*alembic\.ini|
|
.*alembic\.ini|
|
||||||
|
.*alembic/versions/.*\.py|
|
||||||
.*\.mdx$
|
.*\.mdx$
|
||||||
)$
|
)$
|
||||||
|
|
||||||
# Python Backend Hooks (surfsense_backend)
|
# Python Backend Hooks (surfsense_backend) - Using Ruff for linting and formatting
|
||||||
- repo: https://github.com/psf/black
|
|
||||||
rev: 25.1.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
files: ^surfsense_backend/
|
|
||||||
language_version: python3
|
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/isort
|
|
||||||
rev: 6.0.1
|
|
||||||
hooks:
|
|
||||||
- id: isort
|
|
||||||
files: ^surfsense_backend/
|
|
||||||
args: ["--profile", "black", "--line-length", "88"]
|
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.12.4
|
rev: v0.12.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
|
name: ruff-check
|
||||||
files: ^surfsense_backend/
|
files: ^surfsense_backend/
|
||||||
args: [--fix, --exit-non-zero-on-fix]
|
exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
|
||||||
|
args: [--fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
name: ruff-format
|
||||||
files: ^surfsense_backend/
|
files: ^surfsense_backend/
|
||||||
|
exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
||||||
rev: v1.17.0
|
|
||||||
hooks:
|
|
||||||
- id: mypy
|
|
||||||
files: ^surfsense_backend/
|
|
||||||
additional_dependencies: ['types-requests']
|
|
||||||
args: [--ignore-missing-imports, --disallow-untyped-defs]
|
|
||||||
|
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.6
|
rev: 1.8.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
files: ^surfsense_backend/
|
files: ^surfsense_backend/
|
||||||
args: ['-r', '-f', 'json']
|
args: ['-f', 'json', '--severity-level', 'high', '--confidence-level', 'high']
|
||||||
exclude: ^surfsense_backend/(tests/|alembic/)
|
exclude: ^surfsense_backend/(tests/|test_.*\.py|.*test.*\.py|alembic/)
|
||||||
|
|
||||||
# Frontend/Extension Hooks (TypeScript/JavaScript)
|
# Frontend/Extension Hooks (TypeScript/JavaScript)
|
||||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from logging.config import fileConfig
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import pool
|
||||||
from sqlalchemy.engine import Connection
|
from sqlalchemy.engine import Connection
|
||||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
@ -11,7 +11,7 @@ from alembic import context
|
||||||
|
|
||||||
# Ensure the app directory is in the Python path
|
# Ensure the app directory is in the Python path
|
||||||
# This allows Alembic to find your models
|
# This allows Alembic to find your models
|
||||||
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
# Import your models base
|
# Import your models base
|
||||||
from app.db import Base # Assuming your Base is defined in app.db
|
from app.db import Base # Assuming your Base is defined in app.db
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,15 @@ Revision ID: 10
|
||||||
Revises: 9
|
Revises: 9
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "10"
|
revision: str = "10"
|
||||||
down_revision: Union[str, None] = "9"
|
down_revision: str | None = "9"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
# Define the ENUM type name
|
# Define the ENUM type name
|
||||||
CHAT_TYPE_ENUM = "chattype"
|
CHAT_TYPE_ENUM = "chattype"
|
||||||
|
|
@ -27,12 +25,7 @@ def upgrade() -> None:
|
||||||
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
||||||
|
|
||||||
# New enum values
|
# New enum values
|
||||||
new_values = (
|
new_values = ("QNA", "REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER")
|
||||||
"QNA",
|
|
||||||
"REPORT_GENERAL",
|
|
||||||
"REPORT_DEEP",
|
|
||||||
"REPORT_DEEPER"
|
|
||||||
)
|
|
||||||
new_values_sql = ", ".join([f"'{v}'" for v in new_values])
|
new_values_sql = ", ".join([f"'{v}'" for v in new_values])
|
||||||
|
|
||||||
# Table and column info
|
# Table and column info
|
||||||
|
|
@ -46,19 +39,31 @@ def upgrade() -> None:
|
||||||
op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({new_values_sql})")
|
op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({new_values_sql})")
|
||||||
|
|
||||||
# Step 3: Add a temporary column with the new type
|
# Step 3: Add a temporary column with the new type
|
||||||
op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
|
op.execute(
|
||||||
|
f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 4: Update the temporary column with mapped values
|
# Step 4: Update the temporary column with mapped values
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'")
|
op.execute(
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'")
|
f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'"
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'")
|
)
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'")
|
op.execute(
|
||||||
|
f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 5: Drop the old column
|
# Step 5: Drop the old column
|
||||||
op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")
|
op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")
|
||||||
|
|
||||||
# Step 6: Rename the new column to the original name
|
# Step 6: Rename the new column to the original name
|
||||||
op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
|
op.execute(
|
||||||
|
f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 7: Drop the old enum type
|
# Step 7: Drop the old enum type
|
||||||
op.execute(f"DROP TYPE {old_enum_name}")
|
op.execute(f"DROP TYPE {old_enum_name}")
|
||||||
|
|
@ -71,12 +76,7 @@ def downgrade() -> None:
|
||||||
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
||||||
|
|
||||||
# Original enum values
|
# Original enum values
|
||||||
original_values = (
|
original_values = ("GENERAL", "DEEP", "DEEPER", "DEEPEST")
|
||||||
"GENERAL",
|
|
||||||
"DEEP",
|
|
||||||
"DEEPER",
|
|
||||||
"DEEPEST"
|
|
||||||
)
|
|
||||||
original_values_sql = ", ".join([f"'{v}'" for v in original_values])
|
original_values_sql = ", ".join([f"'{v}'" for v in original_values])
|
||||||
|
|
||||||
# Table and column info
|
# Table and column info
|
||||||
|
|
@ -90,19 +90,31 @@ def downgrade() -> None:
|
||||||
op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({original_values_sql})")
|
op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({original_values_sql})")
|
||||||
|
|
||||||
# Step 3: Add a temporary column with the original type
|
# Step 3: Add a temporary column with the original type
|
||||||
op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
|
op.execute(
|
||||||
|
f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 4: Update the temporary column with mapped values back to old values
|
# Step 4: Update the temporary column with mapped values back to old values
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'")
|
op.execute(
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'")
|
f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'"
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'")
|
)
|
||||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'")
|
op.execute(
|
||||||
|
f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 5: Drop the old column
|
# Step 5: Drop the old column
|
||||||
op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")
|
op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")
|
||||||
|
|
||||||
# Step 6: Rename the new column to the original name
|
# Step 6: Rename the new column to the original name
|
||||||
op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
|
op.execute(
|
||||||
|
f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
|
||||||
|
)
|
||||||
|
|
||||||
# Step 7: Drop the old enum type
|
# Step 7: Drop the old enum type
|
||||||
op.execute(f"DROP TYPE {old_enum_name}")
|
op.execute(f"DROP TYPE {old_enum_name}")
|
||||||
|
|
@ -4,16 +4,17 @@ Revision ID: 11
|
||||||
Revises: 10
|
Revises: 10
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "11"
|
revision: str = "11"
|
||||||
down_revision: Union[str, None] = "10"
|
down_revision: str | None = "10"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,17 @@ Revision ID: 12
|
||||||
Revises: 11
|
Revises: 11
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
from sqlalchemy import inspect
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "12"
|
revision: str = "12"
|
||||||
down_revision: Union[str, None] = "11"
|
down_revision: str | None = "11"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,15 @@ Revision ID: 13
|
||||||
Revises: 12
|
Revises: 12
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "13"
|
revision: str = "13"
|
||||||
down_revision: Union[str, None] = "12"
|
down_revision: str | None = "12"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ Revises:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
|
|
@ -15,9 +15,9 @@ from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "1"
|
revision: str = "1"
|
||||||
down_revision: Union[str, None] = None
|
down_revision: str | None = None
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
@ -63,11 +63,9 @@ def downgrade() -> None:
|
||||||
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR')"
|
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR')"
|
||||||
)
|
)
|
||||||
op.execute(
|
op.execute(
|
||||||
(
|
|
||||||
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
||||||
"connector_type::text::searchsourceconnectortype"
|
"connector_type::text::searchsourceconnectortype"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
op.execute("DROP TYPE searchsourceconnectortype_old")
|
op.execute("DROP TYPE searchsourceconnectortype_old")
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,15 @@ Revises: e55302644c51
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "2"
|
revision: str = "2"
|
||||||
down_revision: Union[str, None] = "e55302644c51"
|
down_revision: str | None = "e55302644c51"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
@ -49,11 +49,9 @@ def downgrade() -> None:
|
||||||
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR')"
|
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR')"
|
||||||
)
|
)
|
||||||
op.execute(
|
op.execute(
|
||||||
(
|
|
||||||
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
||||||
"connector_type::text::searchsourceconnectortype"
|
"connector_type::text::searchsourceconnectortype"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
op.execute("DROP TYPE searchsourceconnectortype_old")
|
op.execute("DROP TYPE searchsourceconnectortype_old")
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,15 @@ Revises: 2
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "3"
|
revision: str = "3"
|
||||||
down_revision: Union[str, None] = "2"
|
down_revision: str | None = "2"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
# Define the ENUM type name and the new value
|
# Define the ENUM type name and the new value
|
||||||
ENUM_NAME = "documenttype" # Make sure this matches the name in your DB (usually lowercase class name)
|
ENUM_NAME = "documenttype" # Make sure this matches the name in your DB (usually lowercase class name)
|
||||||
|
|
|
||||||
|
|
@ -5,37 +5,26 @@ Revises: 3
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "4"
|
revision: str = "4"
|
||||||
down_revision: Union[str, None] = "3"
|
down_revision: str | None = "3"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
ENUM_NAME = "searchsourceconnectortype"
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
NEW_VALUE = "LINKUP_API"
|
|
||||||
|
|
||||||
op.execute(
|
# Manually add the command to add the enum value
|
||||||
f"""
|
op.execute("ALTER TYPE searchsourceconnectortype ADD VALUE 'LINKUP_API'")
|
||||||
DO $$
|
|
||||||
BEGIN
|
# Pass for the rest, as autogenerate didn't run to add other schema details
|
||||||
IF NOT EXISTS (
|
pass
|
||||||
SELECT 1 FROM pg_enum
|
# ### end Alembic commands ###
|
||||||
WHERE enumlabel = '{NEW_VALUE}'
|
|
||||||
AND enumtypid = (
|
|
||||||
SELECT oid FROM pg_type WHERE typname = '{ENUM_NAME}'
|
|
||||||
)
|
|
||||||
) THEN
|
|
||||||
ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
|
|
||||||
END IF;
|
|
||||||
END$$;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
|
|
@ -49,11 +38,9 @@ def downgrade() -> None:
|
||||||
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR', 'LINEAR_CONNECTOR')"
|
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR', 'LINEAR_CONNECTOR')"
|
||||||
)
|
)
|
||||||
op.execute(
|
op.execute(
|
||||||
(
|
|
||||||
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
||||||
"connector_type::text::searchsourceconnectortype"
|
"connector_type::text::searchsourceconnectortype"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
op.execute("DROP TYPE searchsourceconnectortype_old")
|
op.execute("DROP TYPE searchsourceconnectortype_old")
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -4,54 +4,73 @@ Revision ID: 5
|
||||||
Revises: 4
|
Revises: 4
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '5'
|
revision: str = "5"
|
||||||
down_revision: Union[str, None] = '4'
|
down_revision: str | None = "4"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Alter Chat table
|
# Alter Chat table
|
||||||
op.alter_column('chats', 'title',
|
op.alter_column(
|
||||||
|
"chats",
|
||||||
|
"title",
|
||||||
existing_type=sa.String(200),
|
existing_type=sa.String(200),
|
||||||
type_=sa.String(),
|
type_=sa.String(),
|
||||||
existing_nullable=False)
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Alter Document table
|
# Alter Document table
|
||||||
op.alter_column('documents', 'title',
|
op.alter_column(
|
||||||
|
"documents",
|
||||||
|
"title",
|
||||||
existing_type=sa.String(200),
|
existing_type=sa.String(200),
|
||||||
type_=sa.String(),
|
type_=sa.String(),
|
||||||
existing_nullable=False)
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Alter Podcast table
|
# Alter Podcast table
|
||||||
op.alter_column('podcasts', 'title',
|
op.alter_column(
|
||||||
|
"podcasts",
|
||||||
|
"title",
|
||||||
existing_type=sa.String(200),
|
existing_type=sa.String(200),
|
||||||
type_=sa.String(),
|
type_=sa.String(),
|
||||||
existing_nullable=False)
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Revert Chat table
|
# Revert Chat table
|
||||||
op.alter_column('chats', 'title',
|
op.alter_column(
|
||||||
|
"chats",
|
||||||
|
"title",
|
||||||
existing_type=sa.String(),
|
existing_type=sa.String(),
|
||||||
type_=sa.String(200),
|
type_=sa.String(200),
|
||||||
existing_nullable=False)
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Revert Document table
|
# Revert Document table
|
||||||
op.alter_column('documents', 'title',
|
op.alter_column(
|
||||||
|
"documents",
|
||||||
|
"title",
|
||||||
existing_type=sa.String(),
|
existing_type=sa.String(),
|
||||||
type_=sa.String(200),
|
type_=sa.String(200),
|
||||||
existing_nullable=False)
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Revert Podcast table
|
# Revert Podcast table
|
||||||
op.alter_column('podcasts', 'title',
|
op.alter_column(
|
||||||
|
"podcasts",
|
||||||
|
"title",
|
||||||
existing_type=sa.String(),
|
existing_type=sa.String(),
|
||||||
type_=sa.String(200),
|
type_=sa.String(200),
|
||||||
existing_nullable=False)
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,19 @@ Revises: 5
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
from sqlalchemy.dialects.postgresql import JSON
|
from sqlalchemy.dialects.postgresql import JSON
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "6"
|
revision: str = "6"
|
||||||
down_revision: Union[str, None] = "5"
|
down_revision: str | None = "5"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -5,17 +5,18 @@ Revises: 6
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "7"
|
revision: str = "7"
|
||||||
down_revision: Union[str, None] = "6"
|
down_revision: str | None = "6"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,18 @@ Revision ID: 8
|
||||||
Revises: 7
|
Revises: 7
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "8"
|
revision: str = "8"
|
||||||
down_revision: Union[str, None] = "7"
|
down_revision: str | None = "7"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,15 @@ Revision ID: 9
|
||||||
Revises: 8
|
Revises: 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "9"
|
revision: str = "9"
|
||||||
down_revision: Union[str, None] = "8"
|
down_revision: str | None = "8"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
# Define the ENUM type name and the new value
|
# Define the ENUM type name and the new value
|
||||||
CONNECTOR_ENUM = "searchsourceconnectortype"
|
CONNECTOR_ENUM = "searchsourceconnectortype"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "e55302644c51"
|
revision: str = "e55302644c51"
|
||||||
down_revision: Union[str, None] = "1"
|
down_revision: str | None = "1"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
# Define the ENUM type name and the new value
|
# Define the ENUM type name and the new value
|
||||||
ENUM_NAME = "documenttype"
|
ENUM_NAME = "documenttype"
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
|
@ -21,7 +20,7 @@ class Configuration:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
cls, config: Optional[RunnableConfig] = None
|
cls, config: RunnableConfig | None = None
|
||||||
) -> Configuration:
|
) -> Configuration:
|
||||||
"""Create a Configuration instance from a RunnableConfig object."""
|
"""Create a Configuration instance from a RunnableConfig object."""
|
||||||
configurable = (config.get("configurable") or {}) if config else {}
|
configurable = (config.get("configurable") or {}) if config else {}
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,11 @@
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
from .configuration import Configuration
|
from .configuration import Configuration
|
||||||
|
from .nodes import create_merged_podcast_audio, create_podcast_transcript
|
||||||
from .state import State
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
from .nodes import create_merged_podcast_audio, create_podcast_transcript
|
|
||||||
|
|
||||||
|
|
||||||
def build_graph():
|
def build_graph():
|
||||||
|
|
||||||
# Define a new graph
|
# Define a new graph
|
||||||
workflow = StateGraph(State, config_schema=Configuration)
|
workflow = StateGraph(State, config_schema=Configuration)
|
||||||
|
|
||||||
|
|
@ -27,5 +24,6 @@ def build_graph():
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
# Compile the graph once when the module is loaded
|
# Compile the graph once when the module is loaded
|
||||||
graph = build_graph()
|
graph = build_graph()
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,26 @@
|
||||||
from typing import Any, Dict
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
from typing import Any
|
||||||
|
|
||||||
|
from ffmpeg.asyncio import FFmpeg
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from litellm import aspeech
|
from litellm import aspeech
|
||||||
from ffmpeg.asyncio import FFmpeg
|
|
||||||
|
|
||||||
from .configuration import Configuration
|
|
||||||
from .state import PodcastTranscriptEntry, State, PodcastTranscripts
|
|
||||||
from .prompts import get_podcast_generation_prompt
|
|
||||||
from app.config import config as app_config
|
from app.config import config as app_config
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
|
|
||||||
|
from .configuration import Configuration
|
||||||
|
from .prompts import get_podcast_generation_prompt
|
||||||
|
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
|
||||||
|
|
||||||
async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
|
async def create_podcast_transcript(
|
||||||
|
state: State, config: RunnableConfig
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Each node does work."""
|
"""Each node does work."""
|
||||||
|
|
||||||
# Get configuration from runnable config
|
# Get configuration from runnable config
|
||||||
|
|
@ -37,7 +40,9 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
|
||||||
# Create the messages
|
# Create the messages
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=prompt),
|
SystemMessage(content=prompt),
|
||||||
HumanMessage(content=f"<source_content>{state.source_content}</source_content>")
|
HumanMessage(
|
||||||
|
content=f"<source_content>{state.source_content}</source_content>"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Generate the podcast transcript
|
# Generate the podcast transcript
|
||||||
|
|
@ -45,9 +50,11 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
|
||||||
|
|
||||||
# First try the direct approach
|
# First try the direct approach
|
||||||
try:
|
try:
|
||||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(llm_response.content))
|
podcast_transcript = PodcastTranscripts.model_validate(
|
||||||
|
json.loads(llm_response.content)
|
||||||
|
)
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
print(f"Direct JSON parsing failed, trying fallback approach: {str(e)}")
|
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||||
|
|
||||||
# Fallback: Parse the JSON response manually
|
# Fallback: Parse the JSON response manually
|
||||||
try:
|
try:
|
||||||
|
|
@ -55,8 +62,8 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
|
||||||
content = llm_response.content
|
content = llm_response.content
|
||||||
|
|
||||||
# Find the JSON in the content (handle case where LLM might add additional text)
|
# Find the JSON in the content (handle case where LLM might add additional text)
|
||||||
json_start = content.find('{')
|
json_start = content.find("{")
|
||||||
json_end = content.rfind('}') + 1
|
json_end = content.rfind("}") + 1
|
||||||
if json_start >= 0 and json_end > json_start:
|
if json_start >= 0 and json_end > json_start:
|
||||||
json_str = content[json_start:json_end]
|
json_str = content[json_start:json_end]
|
||||||
|
|
||||||
|
|
@ -66,7 +73,7 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
|
||||||
# Convert to Pydantic model
|
# Convert to Pydantic model
|
||||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||||
|
|
||||||
print(f"Successfully parsed podcast transcript using fallback approach")
|
print("Successfully parsed podcast transcript using fallback approach")
|
||||||
else:
|
else:
|
||||||
# If JSON structure not found, raise a clear error
|
# If JSON structure not found, raise a clear error
|
||||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||||
|
|
@ -75,36 +82,35 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
|
||||||
|
|
||||||
except (json.JSONDecodeError, ValueError) as e2:
|
except (json.JSONDecodeError, ValueError) as e2:
|
||||||
# Log the error and re-raise it
|
# Log the error and re-raise it
|
||||||
error_message = f"Error parsing LLM response (fallback also failed): {str(e2)}"
|
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||||
print(f"Error parsing LLM response: {str(e2)}")
|
print(f"Error parsing LLM response: {e2!s}")
|
||||||
print(f"Raw response: {llm_response.content}")
|
print(f"Raw response: {llm_response.content}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return {
|
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||||
"podcast_transcript": podcast_transcript.podcast_transcripts
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
async def create_merged_podcast_audio(
|
||||||
|
state: State, config: RunnableConfig
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Generate audio for each transcript and merge them into a single podcast file."""
|
"""Generate audio for each transcript and merge them into a single podcast file."""
|
||||||
|
|
||||||
configuration = Configuration.from_runnable_config(config)
|
configuration = Configuration.from_runnable_config(config)
|
||||||
|
|
||||||
starting_transcript = PodcastTranscriptEntry(
|
starting_transcript = PodcastTranscriptEntry(
|
||||||
speaker_id=1,
|
speaker_id=1, dialog=f"Welcome to {configuration.podcast_title} Podcast."
|
||||||
dialog=f"Welcome to {configuration.podcast_title} Podcast."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
transcript = state.podcast_transcript
|
transcript = state.podcast_transcript
|
||||||
|
|
||||||
# Merge the starting transcript with the podcast transcript
|
# Merge the starting transcript with the podcast transcript
|
||||||
# Check if transcript is a PodcastTranscripts object or already a list
|
# Check if transcript is a PodcastTranscripts object or already a list
|
||||||
if hasattr(transcript, 'podcast_transcripts'):
|
if hasattr(transcript, "podcast_transcripts"):
|
||||||
transcript_entries = transcript.podcast_transcripts
|
transcript_entries = transcript.podcast_transcripts
|
||||||
else:
|
else:
|
||||||
transcript_entries = transcript
|
transcript_entries = transcript
|
||||||
|
|
||||||
merged_transcript = [starting_transcript] + transcript_entries
|
merged_transcript = [starting_transcript, *transcript_entries]
|
||||||
|
|
||||||
# Create a temporary directory for audio files
|
# Create a temporary directory for audio files
|
||||||
temp_dir = Path("temp_audio")
|
temp_dir = Path("temp_audio")
|
||||||
|
|
@ -130,7 +136,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
|
||||||
|
|
||||||
async def generate_speech_for_segment(segment, index):
|
async def generate_speech_for_segment(segment, index):
|
||||||
# Handle both dictionary and PodcastTranscriptEntry objects
|
# Handle both dictionary and PodcastTranscriptEntry objects
|
||||||
if hasattr(segment, 'speaker_id'):
|
if hasattr(segment, "speaker_id"):
|
||||||
speaker_id = segment.speaker_id
|
speaker_id = segment.speaker_id
|
||||||
dialog = segment.dialog
|
dialog = segment.dialog
|
||||||
else:
|
else:
|
||||||
|
|
@ -165,16 +171,19 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save the audio to a file - use proper streaming method
|
# Save the audio to a file - use proper streaming method
|
||||||
with open(filename, 'wb') as f:
|
with open(filename, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
|
||||||
return filename
|
return filename
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating speech for segment {index}: {str(e)}")
|
print(f"Error generating speech for segment {index}: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Generate all audio files concurrently
|
# Generate all audio files concurrently
|
||||||
tasks = [generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript)]
|
tasks = [
|
||||||
|
generate_speech_for_segment(segment, i)
|
||||||
|
for i, segment in enumerate(merged_transcript)
|
||||||
|
]
|
||||||
audio_files = await asyncio.gather(*tasks)
|
audio_files = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Merge audio files using ffmpeg
|
# Merge audio files using ffmpeg
|
||||||
|
|
@ -191,7 +200,9 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
|
||||||
for i in range(len(audio_files)):
|
for i in range(len(audio_files)):
|
||||||
filter_complex.append(f"[{i}:0]")
|
filter_complex.append(f"[{i}:0]")
|
||||||
|
|
||||||
filter_complex_str = "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
filter_complex_str = (
|
||||||
|
"".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||||
|
)
|
||||||
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
||||||
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
||||||
|
|
||||||
|
|
@ -201,17 +212,18 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
|
||||||
print(f"Successfully created podcast audio: {output_path}")
|
print(f"Successfully created podcast audio: {output_path}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error merging audio files: {str(e)}")
|
print(f"Error merging audio files: {e!s}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Clean up temporary files
|
# Clean up temporary files
|
||||||
for audio_file in audio_files:
|
for audio_file in audio_files:
|
||||||
try:
|
try:
|
||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(f"Error removing audio file {audio_file}: {e!s}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"podcast_transcript": merged_transcript,
|
"podcast_transcript": merged_transcript,
|
||||||
"final_podcast_file_path": output_path
|
"final_podcast_file_path": output_path,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,16 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
class PodcastTranscriptEntry(BaseModel):
|
class PodcastTranscriptEntry(BaseModel):
|
||||||
"""
|
"""
|
||||||
Represents a single entry in a podcast transcript.
|
Represents a single entry in a podcast transcript.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
|
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
|
||||||
dialog: str = Field(..., description="The dialog text spoken by the speaker")
|
dialog: str = Field(..., description="The dialog text spoken by the speaker")
|
||||||
|
|
||||||
|
|
@ -19,11 +21,12 @@ class PodcastTranscripts(BaseModel):
|
||||||
"""
|
"""
|
||||||
Represents the full podcast transcript structure.
|
Represents the full podcast transcript structure.
|
||||||
"""
|
"""
|
||||||
podcast_transcripts: List[PodcastTranscriptEntry] = Field(
|
|
||||||
...,
|
podcast_transcripts: list[PodcastTranscriptEntry] = Field(
|
||||||
description="List of transcript entries with alternating speakers"
|
..., description="List of transcript entries with alternating speakers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State:
|
class State:
|
||||||
"""Defines the input state for the agent, representing a narrower interface to the outside world.
|
"""Defines the input state for the agent, representing a narrower interface to the outside world.
|
||||||
|
|
@ -32,8 +35,9 @@ class State:
|
||||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||||
for more information.
|
for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Runtime context
|
# Runtime context
|
||||||
db_session: AsyncSession
|
db_session: AsyncSession
|
||||||
source_content: str
|
source_content: str
|
||||||
podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
|
podcast_transcript: list[PodcastTranscriptEntry] | None = None
|
||||||
final_podcast_file_path: Optional[str] = None
|
final_podcast_file_path: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,20 @@ from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, List, Any
|
|
||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
|
||||||
class SearchMode(Enum):
|
class SearchMode(Enum):
|
||||||
"""Enum defining the type of search mode."""
|
"""Enum defining the type of search mode."""
|
||||||
|
|
||||||
CHUNKS = "CHUNKS"
|
CHUNKS = "CHUNKS"
|
||||||
DOCUMENTS = "DOCUMENTS"
|
DOCUMENTS = "DOCUMENTS"
|
||||||
|
|
||||||
|
|
||||||
class ResearchMode(Enum):
|
class ResearchMode(Enum):
|
||||||
"""Enum defining the type of research mode."""
|
"""Enum defining the type of research mode."""
|
||||||
|
|
||||||
QNA = "QNA"
|
QNA = "QNA"
|
||||||
REPORT_GENERAL = "REPORT_GENERAL"
|
REPORT_GENERAL = "REPORT_GENERAL"
|
||||||
REPORT_DEEP = "REPORT_DEEP"
|
REPORT_DEEP = "REPORT_DEEP"
|
||||||
|
|
@ -28,16 +31,16 @@ class Configuration:
|
||||||
# Input parameters provided at invocation
|
# Input parameters provided at invocation
|
||||||
user_query: str
|
user_query: str
|
||||||
num_sections: int
|
num_sections: int
|
||||||
connectors_to_search: List[str]
|
connectors_to_search: list[str]
|
||||||
user_id: str
|
user_id: str
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
search_mode: SearchMode
|
search_mode: SearchMode
|
||||||
research_mode: ResearchMode
|
research_mode: ResearchMode
|
||||||
document_ids_to_add_in_context: List[int]
|
document_ids_to_add_in_context: list[int]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
cls, config: Optional[RunnableConfig] = None
|
cls, config: RunnableConfig | None = None
|
||||||
) -> Configuration:
|
) -> Configuration:
|
||||||
"""Create a Configuration instance from a RunnableConfig object."""
|
"""Create a Configuration instance from a RunnableConfig object."""
|
||||||
configurable = (config.get("configurable") or {}) if config else {}
|
configurable = (config.get("configurable") or {}) if config else {}
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,25 @@
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
from .state import State
|
|
||||||
from .nodes import reformulate_user_query, write_answer_outline, process_sections, handle_qna_workflow, generate_further_questions
|
|
||||||
from .configuration import Configuration, ResearchMode
|
from .configuration import Configuration, ResearchMode
|
||||||
from typing import TypedDict, List, Dict, Any, Optional
|
from .nodes import (
|
||||||
|
generate_further_questions,
|
||||||
|
handle_qna_workflow,
|
||||||
|
process_sections,
|
||||||
|
reformulate_user_query,
|
||||||
|
write_answer_outline,
|
||||||
|
)
|
||||||
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
# Define what keys are in our state dict
|
# Define what keys are in our state dict
|
||||||
class GraphState(TypedDict):
|
class GraphState(TypedDict):
|
||||||
# Intermediate data produced during workflow
|
# Intermediate data produced during workflow
|
||||||
answer_outline: Optional[Any]
|
answer_outline: Any | None
|
||||||
# Final output
|
# Final output
|
||||||
final_written_report: Optional[str]
|
final_written_report: str | None
|
||||||
|
|
||||||
|
|
||||||
def build_graph():
|
def build_graph():
|
||||||
"""
|
"""
|
||||||
|
|
@ -51,8 +61,8 @@ def build_graph():
|
||||||
route_after_reformulate,
|
route_after_reformulate,
|
||||||
{
|
{
|
||||||
"handle_qna_workflow": "handle_qna_workflow",
|
"handle_qna_workflow": "handle_qna_workflow",
|
||||||
"write_answer_outline": "write_answer_outline"
|
"write_answer_outline": "write_answer_outline",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__
|
# QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__
|
||||||
|
|
@ -71,5 +81,6 @@ def build_graph():
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
# Compile the graph once when the module is loaded
|
# Compile the graph once when the module is loaded
|
||||||
graph = build_graph()
|
graph = build_graph()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from app.db import Document, SearchSpace
|
|
||||||
from app.services.connector_service import ConnectorService
|
|
||||||
from app.services.query_service import QueryService
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.types import StreamWriter
|
from langgraph.types import StreamWriter
|
||||||
|
|
@ -13,6 +10,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
# Additional imports for document fetching
|
# Additional imports for document fetching
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import Document, SearchSpace
|
||||||
|
from app.services.connector_service import ConnectorService
|
||||||
|
from app.services.query_service import QueryService
|
||||||
|
|
||||||
from .configuration import Configuration, SearchMode
|
from .configuration import Configuration, SearchMode
|
||||||
from .prompts import (
|
from .prompts import (
|
||||||
get_answer_outline_system_prompt,
|
get_answer_outline_system_prompt,
|
||||||
|
|
@ -26,8 +27,8 @@ from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_na
|
||||||
|
|
||||||
|
|
||||||
async def fetch_documents_by_ids(
|
async def fetch_documents_by_ids(
|
||||||
document_ids: List[int], user_id: str, db_session: AsyncSession
|
document_ids: list[int], user_id: str, db_session: AsyncSession
|
||||||
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Fetch documents by their IDs with ownership check using DOCUMENTS mode approach.
|
Fetch documents by their IDs with ownership check using DOCUMENTS mode approach.
|
||||||
|
|
||||||
|
|
@ -358,13 +359,13 @@ async def fetch_documents_by_ids(
|
||||||
return source_objects, formatted_documents
|
return source_objects, formatted_documents
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error fetching documents by IDs: {str(e)}")
|
print(f"Error fetching documents by IDs: {e!s}")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
async def write_answer_outline(
|
async def write_answer_outline(
|
||||||
state: State, config: RunnableConfig, writer: StreamWriter
|
state: State, config: RunnableConfig, writer: StreamWriter
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create a structured answer outline based on the user query.
|
Create a structured answer outline based on the user query.
|
||||||
|
|
||||||
|
|
@ -502,27 +503,27 @@ async def write_answer_outline(
|
||||||
|
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
# Log the error and re-raise it
|
# Log the error and re-raise it
|
||||||
error_message = f"Error parsing LLM response: {str(e)}"
|
error_message = f"Error parsing LLM response: {e!s}"
|
||||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||||
|
|
||||||
print(f"Error parsing LLM response: {str(e)}")
|
print(f"Error parsing LLM response: {e!s}")
|
||||||
print(f"Raw response: {response.content}")
|
print(f"Raw response: {response.content}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def fetch_relevant_documents(
|
async def fetch_relevant_documents(
|
||||||
research_questions: List[str],
|
research_questions: list[str],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
connectors_to_search: List[str],
|
connectors_to_search: list[str],
|
||||||
writer: StreamWriter = None,
|
writer: StreamWriter = None,
|
||||||
state: State = None,
|
state: State = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
connector_service: ConnectorService = None,
|
connector_service: ConnectorService = None,
|
||||||
search_mode: SearchMode = SearchMode.CHUNKS,
|
search_mode: SearchMode = SearchMode.CHUNKS,
|
||||||
user_selected_sources: List[Dict[str, Any]] = None,
|
user_selected_sources: list[dict[str, Any]] | None = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch relevant documents for research questions using the provided connectors.
|
Fetch relevant documents for research questions using the provided connectors.
|
||||||
|
|
||||||
|
|
@ -833,7 +834,10 @@ async def fetch_relevant_documents(
|
||||||
elif connector == "LINKUP_API":
|
elif connector == "LINKUP_API":
|
||||||
linkup_mode = "standard"
|
linkup_mode = "standard"
|
||||||
|
|
||||||
source_object, linkup_chunks = await connector_service.search_linkup(
|
(
|
||||||
|
source_object,
|
||||||
|
linkup_chunks,
|
||||||
|
) = await connector_service.search_linkup(
|
||||||
user_query=reformulated_query,
|
user_query=reformulated_query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
mode=linkup_mode,
|
mode=linkup_mode,
|
||||||
|
|
@ -904,7 +908,7 @@ async def fetch_relevant_documents(
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error searching connector {connector}: {str(e)}"
|
error_message = f"Error searching connector {connector}: {e!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
|
|
||||||
# Stream error message
|
# Stream error message
|
||||||
|
|
@ -913,7 +917,7 @@ async def fetch_relevant_documents(
|
||||||
writer(
|
writer(
|
||||||
{
|
{
|
||||||
"yield_value": streaming_service.format_error(
|
"yield_value": streaming_service.format_error(
|
||||||
f"Error searching {friendly_name}: {str(e)}"
|
f"Error searching {friendly_name}: {e!s}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -948,36 +952,48 @@ async def fetch_relevant_documents(
|
||||||
|
|
||||||
if source_id and source_type:
|
if source_id and source_type:
|
||||||
source_key = f"{source_type}_{source_id}"
|
source_key = f"{source_type}_{source_id}"
|
||||||
current_sources_count = len(source_obj.get('sources', []))
|
current_sources_count = len(source_obj.get("sources", []))
|
||||||
|
|
||||||
if source_key not in seen_source_keys:
|
if source_key not in seen_source_keys:
|
||||||
seen_source_keys.add(source_key)
|
seen_source_keys.add(source_key)
|
||||||
deduplicated_sources.append(source_obj)
|
deduplicated_sources.append(source_obj)
|
||||||
print(f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}")
|
print(
|
||||||
|
f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Check if this source object has more sources than the existing one
|
# Check if this source object has more sources than the existing one
|
||||||
existing_index = None
|
existing_index = None
|
||||||
for i, existing_source in enumerate(deduplicated_sources):
|
for i, existing_source in enumerate(deduplicated_sources):
|
||||||
existing_id = existing_source.get('id')
|
existing_id = existing_source.get("id")
|
||||||
existing_type = existing_source.get('type')
|
existing_type = existing_source.get("type")
|
||||||
if existing_id == source_id and existing_type == source_type:
|
if existing_id == source_id and existing_type == source_type:
|
||||||
existing_index = i
|
existing_index = i
|
||||||
break
|
break
|
||||||
|
|
||||||
if existing_index is not None:
|
if existing_index is not None:
|
||||||
existing_sources_count = len(deduplicated_sources[existing_index].get('sources', []))
|
existing_sources_count = len(
|
||||||
|
deduplicated_sources[existing_index].get("sources", [])
|
||||||
|
)
|
||||||
if current_sources_count > existing_sources_count:
|
if current_sources_count > existing_sources_count:
|
||||||
# Replace the existing source object with the new one that has more sources
|
# Replace the existing source object with the new one that has more sources
|
||||||
deduplicated_sources[existing_index] = source_obj
|
deduplicated_sources[existing_index] = source_obj
|
||||||
print(f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}")
|
print(
|
||||||
|
f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}")
|
print(
|
||||||
|
f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)")
|
print(
|
||||||
|
f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# If there's no ID or type, just add it to be safe
|
# If there's no ID or type, just add it to be safe
|
||||||
deduplicated_sources.append(source_obj)
|
deduplicated_sources.append(source_obj)
|
||||||
print(f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}")
|
print(
|
||||||
|
f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}"
|
||||||
|
)
|
||||||
|
|
||||||
# Stream info about deduplicated sources
|
# Stream info about deduplicated sources
|
||||||
if streaming_service and writer:
|
if streaming_service and writer:
|
||||||
|
|
@ -1039,7 +1055,7 @@ async def fetch_relevant_documents(
|
||||||
|
|
||||||
async def process_sections(
|
async def process_sections(
|
||||||
state: State, config: RunnableConfig, writer: StreamWriter
|
state: State, config: RunnableConfig, writer: StreamWriter
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Process all sections in parallel and combine the results.
|
Process all sections in parallel and combine the results.
|
||||||
|
|
||||||
|
|
@ -1100,13 +1116,13 @@ async def process_sections(
|
||||||
)
|
)
|
||||||
|
|
||||||
if configuration.num_sections == 1:
|
if configuration.num_sections == 1:
|
||||||
TOP_K = 10
|
top_k = 10
|
||||||
elif configuration.num_sections == 3:
|
elif configuration.num_sections == 3:
|
||||||
TOP_K = 20
|
top_k = 20
|
||||||
elif configuration.num_sections == 6:
|
elif configuration.num_sections == 6:
|
||||||
TOP_K = 30
|
top_k = 30
|
||||||
else:
|
else:
|
||||||
TOP_K = 10
|
top_k = 10
|
||||||
|
|
||||||
relevant_documents = []
|
relevant_documents = []
|
||||||
user_selected_documents = []
|
user_selected_documents = []
|
||||||
|
|
@ -1155,13 +1171,13 @@ async def process_sections(
|
||||||
connectors_to_search=configuration.connectors_to_search,
|
connectors_to_search=configuration.connectors_to_search,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
state=state,
|
state=state,
|
||||||
top_k=TOP_K,
|
top_k=top_k,
|
||||||
connector_service=connector_service,
|
connector_service=connector_service,
|
||||||
search_mode=configuration.search_mode,
|
search_mode=configuration.search_mode,
|
||||||
user_selected_sources=user_selected_sources,
|
user_selected_sources=user_selected_sources,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error fetching relevant documents: {str(e)}"
|
error_message = f"Error fetching relevant documents: {e!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||||
# Log the error and continue with an empty list of documents
|
# Log the error and continue with an empty list of documents
|
||||||
|
|
@ -1251,7 +1267,7 @@ async def process_sections(
|
||||||
for i, result in enumerate(section_results):
|
for i, result in enumerate(section_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
section_title = answer_outline.answer_outline[i].section_title
|
section_title = answer_outline.answer_outline[i].section_title
|
||||||
error_message = f"Error processing section '{section_title}': {str(result)}"
|
error_message = f"Error processing section '{section_title}': {result!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||||
processed_results.append(error_message)
|
processed_results.append(error_message)
|
||||||
|
|
@ -1260,8 +1276,8 @@ async def process_sections(
|
||||||
|
|
||||||
# Combine the results into a final report with section titles
|
# Combine the results into a final report with section titles
|
||||||
final_report = []
|
final_report = []
|
||||||
for i, (section, content) in enumerate(
|
for _i, (section, content) in enumerate(
|
||||||
zip(answer_outline.answer_outline, processed_results)
|
zip(answer_outline.answer_outline, processed_results, strict=False)
|
||||||
):
|
):
|
||||||
# Skip adding the section header since the content already contains the title
|
# Skip adding the section header since the content already contains the title
|
||||||
final_report.append(content)
|
final_report.append(content)
|
||||||
|
|
@ -1299,15 +1315,15 @@ async def process_sections(
|
||||||
async def process_section_with_documents(
|
async def process_section_with_documents(
|
||||||
section_id: int,
|
section_id: int,
|
||||||
section_title: str,
|
section_title: str,
|
||||||
section_questions: List[str],
|
section_questions: list[str],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
relevant_documents: List[Dict[str, Any]],
|
relevant_documents: list[dict[str, Any]],
|
||||||
user_query: str,
|
user_query: str,
|
||||||
state: State = None,
|
state: State = None,
|
||||||
writer: StreamWriter = None,
|
writer: StreamWriter = None,
|
||||||
sub_section_type: SubSectionType = SubSectionType.MIDDLE,
|
sub_section_type: SubSectionType = SubSectionType.MIDDLE,
|
||||||
section_contents: Dict[int, Dict[str, Any]] = None,
|
section_contents: dict[int, dict[str, Any]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process a single section using pre-fetched documents.
|
Process a single section using pre-fetched documents.
|
||||||
|
|
@ -1388,7 +1404,7 @@ async def process_section_with_documents(
|
||||||
# Variables to track streaming state
|
# Variables to track streaming state
|
||||||
complete_content = "" # Tracks the complete content received so far
|
complete_content = "" # Tracks the complete content received so far
|
||||||
|
|
||||||
async for chunk_type, chunk in sub_section_writer_graph.astream(
|
async for _chunk_type, chunk in sub_section_writer_graph.astream(
|
||||||
sub_state, config, stream_mode=["values"]
|
sub_state, config, stream_mode=["values"]
|
||||||
):
|
):
|
||||||
if "final_answer" in chunk:
|
if "final_answer" in chunk:
|
||||||
|
|
@ -1448,24 +1464,24 @@ async def process_section_with_documents(
|
||||||
|
|
||||||
return complete_content
|
return complete_content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing section '{section_title}': {str(e)}")
|
print(f"Error processing section '{section_title}': {e!s}")
|
||||||
|
|
||||||
# Send error update via streaming if available
|
# Send error update via streaming if available
|
||||||
if state and state.streaming_service and writer:
|
if state and state.streaming_service and writer:
|
||||||
writer(
|
writer(
|
||||||
{
|
{
|
||||||
"yield_value": state.streaming_service.format_error(
|
"yield_value": state.streaming_service.format_error(
|
||||||
f'Error processing section "{section_title}": {str(e)}'
|
f'Error processing section "{section_title}": {e!s}'
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"Error processing section: {section_title}. Details: {str(e)}"
|
return f"Error processing section: {section_title}. Details: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
async def reformulate_user_query(
|
async def reformulate_user_query(
|
||||||
state: State, config: RunnableConfig, writer: StreamWriter
|
state: State, config: RunnableConfig, writer: StreamWriter
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Reforms the user query based on the chat history.
|
Reforms the user query based on the chat history.
|
||||||
"""
|
"""
|
||||||
|
|
@ -1490,7 +1506,7 @@ async def reformulate_user_query(
|
||||||
|
|
||||||
async def handle_qna_workflow(
|
async def handle_qna_workflow(
|
||||||
state: State, config: RunnableConfig, writer: StreamWriter
|
state: State, config: RunnableConfig, writer: StreamWriter
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Handle the QNA research workflow.
|
Handle the QNA research workflow.
|
||||||
|
|
||||||
|
|
@ -1532,7 +1548,7 @@ async def handle_qna_workflow(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
|
# Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
|
||||||
TOP_K = 15
|
top_k = 15
|
||||||
|
|
||||||
relevant_documents = []
|
relevant_documents = []
|
||||||
user_selected_documents = []
|
user_selected_documents = []
|
||||||
|
|
@ -1584,13 +1600,13 @@ async def handle_qna_workflow(
|
||||||
connectors_to_search=configuration.connectors_to_search,
|
connectors_to_search=configuration.connectors_to_search,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
state=state,
|
state=state,
|
||||||
top_k=TOP_K,
|
top_k=top_k,
|
||||||
connector_service=connector_service,
|
connector_service=connector_service,
|
||||||
search_mode=configuration.search_mode,
|
search_mode=configuration.search_mode,
|
||||||
user_selected_sources=user_selected_sources,
|
user_selected_sources=user_selected_sources,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error fetching relevant documents for QNA: {str(e)}"
|
error_message = f"Error fetching relevant documents for QNA: {e!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||||
# Continue with empty documents - the QNA agent will handle this gracefully
|
# Continue with empty documents - the QNA agent will handle this gracefully
|
||||||
|
|
@ -1688,16 +1704,16 @@ async def handle_qna_workflow(
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error generating QNA answer: {str(e)}"
|
error_message = f"Error generating QNA answer: {e!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||||
|
|
||||||
return {"final_written_report": f"Error generating answer: {str(e)}"}
|
return {"final_written_report": f"Error generating answer: {e!s}"}
|
||||||
|
|
||||||
|
|
||||||
async def generate_further_questions(
|
async def generate_further_questions(
|
||||||
state: State, config: RunnableConfig, writer: StreamWriter
|
state: State, config: RunnableConfig, writer: StreamWriter
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate contextually relevant follow-up questions based on chat history and available documents.
|
Generate contextually relevant follow-up questions based on chat history and available documents.
|
||||||
|
|
||||||
|
|
@ -1748,7 +1764,7 @@ async def generate_further_questions(
|
||||||
chat_history_xml += f"<assistant>{message.content}</assistant>\n"
|
chat_history_xml += f"<assistant>{message.content}</assistant>\n"
|
||||||
else:
|
else:
|
||||||
# Handle other message types if needed
|
# Handle other message types if needed
|
||||||
chat_history_xml += f"<message>{str(message)}</message>\n"
|
chat_history_xml += f"<message>{message!s}</message>\n"
|
||||||
chat_history_xml += "</chat_history>"
|
chat_history_xml += "</chat_history>"
|
||||||
|
|
||||||
# Format available documents for the prompt
|
# Format available documents for the prompt
|
||||||
|
|
@ -1868,7 +1884,7 @@ async def generate_further_questions(
|
||||||
|
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
# Log the error and return empty list
|
# Log the error and return empty list
|
||||||
error_message = f"Error parsing further questions response: {str(e)}"
|
error_message = f"Error parsing further questions response: {e!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
writer(
|
writer(
|
||||||
{"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
|
{"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
|
||||||
|
|
@ -1880,7 +1896,7 @@ async def generate_further_questions(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle any other errors
|
# Handle any other errors
|
||||||
error_message = f"Error generating further questions: {str(e)}"
|
error_message = f"Error generating further questions: {e!s}"
|
||||||
print(error_message)
|
print(error_message)
|
||||||
writer(
|
writer(
|
||||||
{"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
|
{"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
"""QnA Agent.
|
"""QnA Agent."""
|
||||||
"""
|
|
||||||
|
|
||||||
from .graph import graph
|
from .graph import graph
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from typing import Optional, List, Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
|
@ -15,13 +15,15 @@ class Configuration:
|
||||||
# Configuration parameters for the Q&A agent
|
# Configuration parameters for the Q&A agent
|
||||||
user_query: str # The user's question to answer
|
user_query: str # The user's question to answer
|
||||||
reformulated_query: str # The reformulated query
|
reformulated_query: str # The reformulated query
|
||||||
relevant_documents: List[Any] # Documents provided directly to the agent for answering
|
relevant_documents: list[
|
||||||
|
Any
|
||||||
|
] # Documents provided directly to the agent for answering
|
||||||
user_id: str # User identifier
|
user_id: str # User identifier
|
||||||
search_space_id: int # Search space identifier
|
search_space_id: int # Search space identifier
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
cls, config: Optional[RunnableConfig] = None
|
cls, config: RunnableConfig | None = None
|
||||||
) -> Configuration:
|
) -> Configuration:
|
||||||
"""Create a Configuration instance from a RunnableConfig object."""
|
"""Create a Configuration instance from a RunnableConfig object."""
|
||||||
configurable = (config.get("configurable") or {}) if config else {}
|
configurable = (config.get("configurable") or {}) if config else {}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
from .state import State
|
|
||||||
from .nodes import rerank_documents, answer_question
|
|
||||||
from .configuration import Configuration
|
from .configuration import Configuration
|
||||||
|
from .nodes import answer_question, rerank_documents
|
||||||
|
from .state import State
|
||||||
|
|
||||||
# Define a new graph
|
# Define a new graph
|
||||||
workflow = StateGraph(State, config_schema=Configuration)
|
workflow = StateGraph(State, config_schema=Configuration)
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,21 @@
|
||||||
from app.services.reranker_service import RerankerService
|
from typing import Any
|
||||||
from .configuration import Configuration
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from .state import State
|
|
||||||
from typing import Any, Dict
|
|
||||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from ..utils import (
|
|
||||||
optimize_documents_for_token_limit,
|
|
||||||
calculate_token_count,
|
|
||||||
format_documents_section
|
|
||||||
)
|
|
||||||
|
|
||||||
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
from app.services.reranker_service import RerankerService
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
calculate_token_count,
|
||||||
|
format_documents_section,
|
||||||
|
optimize_documents_for_token_limit,
|
||||||
|
)
|
||||||
|
from .configuration import Configuration
|
||||||
|
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||||
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
|
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Rerank the documents based on relevance to the user's question.
|
Rerank the documents based on relevance to the user's question.
|
||||||
|
|
||||||
|
|
@ -30,9 +34,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
|
|
||||||
# If no documents were provided, return empty list
|
# If no documents were provided, return empty list
|
||||||
if not documents or len(documents) == 0:
|
if not documents or len(documents) == 0:
|
||||||
return {
|
return {"reranked_documents": []}
|
||||||
"reranked_documents": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get reranker service from app config
|
# Get reranker service from app config
|
||||||
reranker_service = RerankerService.get_reranker_instance()
|
reranker_service = RerankerService.get_reranker_instance()
|
||||||
|
|
@ -51,28 +53,34 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
"document": {
|
"document": {
|
||||||
"id": doc.get("document", {}).get("id", ""),
|
"id": doc.get("document", {}).get("id", ""),
|
||||||
"title": doc.get("document", {}).get("title", ""),
|
"title": doc.get("document", {}).get("title", ""),
|
||||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
"document_type": doc.get("document", {}).get(
|
||||||
"metadata": doc.get("document", {}).get("metadata", {})
|
"document_type", ""
|
||||||
|
),
|
||||||
|
"metadata": doc.get("document", {}).get("metadata", {}),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
} for i, doc in enumerate(documents)
|
for i, doc in enumerate(documents)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Rerank documents using the user's query
|
# Rerank documents using the user's query
|
||||||
reranked_docs = reranker_service.rerank_documents(user_query + "\n" + reformulated_query, reranker_input_docs)
|
reranked_docs = reranker_service.rerank_documents(
|
||||||
|
user_query + "\n" + reformulated_query, reranker_input_docs
|
||||||
|
)
|
||||||
|
|
||||||
# Sort by score in descending order
|
# Sort by score in descending order
|
||||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||||
|
|
||||||
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
|
print(
|
||||||
|
f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during reranking: {str(e)}")
|
print(f"Error during reranking: {e!s}")
|
||||||
# Use original docs if reranking fails
|
# Use original docs if reranking fails
|
||||||
|
|
||||||
return {
|
return {"reranked_documents": reranked_docs}
|
||||||
"reranked_documents": reranked_docs
|
|
||||||
}
|
|
||||||
|
|
||||||
async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
|
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Answer the user's question using the provided documents.
|
Answer the user's question using the provided documents.
|
||||||
|
|
||||||
|
|
@ -117,14 +125,15 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
||||||
|
|
||||||
# Use initial system prompt for token calculation
|
# Use initial system prompt for token calculation
|
||||||
initial_system_prompt = get_qna_citation_system_prompt()
|
initial_system_prompt = get_qna_citation_system_prompt()
|
||||||
base_messages = state.chat_history + [
|
base_messages = [
|
||||||
|
*state.chat_history,
|
||||||
SystemMessage(content=initial_system_prompt),
|
SystemMessage(content=initial_system_prompt),
|
||||||
HumanMessage(content=base_human_message_template)
|
HumanMessage(content=base_human_message_template),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Optimize documents to fit within token limits
|
# Optimize documents to fit within token limits
|
||||||
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
|
optimized_documents, has_optimized_documents = (
|
||||||
documents, base_messages, llm.model
|
optimize_documents_for_token_limit(documents, base_messages, llm.model)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update state based on optimization result
|
# Update state based on optimization result
|
||||||
|
|
@ -134,19 +143,26 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
||||||
has_documents = False
|
has_documents = False
|
||||||
|
|
||||||
# Choose system prompt based on final document availability
|
# Choose system prompt based on final document availability
|
||||||
system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt()
|
system_prompt = (
|
||||||
|
get_qna_citation_system_prompt()
|
||||||
|
if has_documents
|
||||||
|
else get_qna_no_documents_system_prompt()
|
||||||
|
)
|
||||||
|
|
||||||
# Generate documents section
|
# Generate documents section
|
||||||
documents_text = format_documents_section(
|
documents_text = (
|
||||||
documents,
|
format_documents_section(
|
||||||
"Source material from your personal knowledge base"
|
documents, "Source material from your personal knowledge base"
|
||||||
) if has_documents else ""
|
)
|
||||||
|
if has_documents
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
# Create final human message content
|
# Create final human message content
|
||||||
instruction_text = (
|
instruction_text = (
|
||||||
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
|
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
|
||||||
if has_documents else
|
if has_documents
|
||||||
"Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
|
else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
|
||||||
)
|
)
|
||||||
|
|
||||||
human_message_content = f"""
|
human_message_content = f"""
|
||||||
|
|
@ -161,20 +177,18 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create final messages for the LLM
|
# Create final messages for the LLM
|
||||||
messages_with_chat_history = state.chat_history + [
|
messages_with_chat_history = [
|
||||||
|
*state.chat_history,
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(content=human_message_content)
|
HumanMessage(content=human_message_content),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Log final token count
|
# Log final token count
|
||||||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
||||||
print(f"Final token count: {total_tokens}")
|
print(f"Final token count: {total_tokens}")
|
||||||
|
|
||||||
|
|
||||||
# Call the LLM and get the response
|
# Call the LLM and get the response
|
||||||
response = await llm.ainvoke(messages_with_chat_history)
|
response = await llm.ainvoke(messages_with_chat_history)
|
||||||
final_answer = response.content
|
final_answer = response.content
|
||||||
|
|
||||||
return {
|
return {"final_answer": final_answer}
|
||||||
"final_answer": final_answer
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional, Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State:
|
class State:
|
||||||
"""Defines the dynamic state for the Q&A agent during execution.
|
"""Defines the dynamic state for the Q&A agent during execution.
|
||||||
|
|
@ -19,7 +21,7 @@ class State:
|
||||||
# Runtime context
|
# Runtime context
|
||||||
db_session: AsyncSession
|
db_session: AsyncSession
|
||||||
|
|
||||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
chat_history: list[Any] | None = field(default_factory=list)
|
||||||
# OUTPUT: Populated by agent nodes
|
# OUTPUT: Populated by agent nodes
|
||||||
reranked_documents: Optional[List[Any]] = None
|
reranked_documents: list[Any] | None = None
|
||||||
final_answer: Optional[str] = None
|
final_answer: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,13 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional, Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.services.streaming_service import StreamingService
|
from app.services.streaming_service import StreamingService
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State:
|
class State:
|
||||||
"""Defines the dynamic state for the agent during execution.
|
"""Defines the dynamic state for the agent during execution.
|
||||||
|
|
@ -15,23 +18,23 @@ class State:
|
||||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||||
for more information.
|
for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Runtime context (not part of actual graph state)
|
# Runtime context (not part of actual graph state)
|
||||||
db_session: AsyncSession
|
db_session: AsyncSession
|
||||||
|
|
||||||
# Streaming service
|
# Streaming service
|
||||||
streaming_service: StreamingService
|
streaming_service: StreamingService
|
||||||
|
|
||||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
chat_history: list[Any] | None = field(default_factory=list)
|
||||||
|
|
||||||
reformulated_query: Optional[str] = field(default=None)
|
reformulated_query: str | None = field(default=None)
|
||||||
# Using field to explicitly mark as part of state
|
# Using field to explicitly mark as part of state
|
||||||
answer_outline: Optional[Any] = field(default=None)
|
answer_outline: Any | None = field(default=None)
|
||||||
further_questions: Optional[Any] = field(default=None)
|
further_questions: Any | None = field(default=None)
|
||||||
|
|
||||||
# Temporary field to hold reranked documents from sub-agents for further question generation
|
# Temporary field to hold reranked documents from sub-agents for further question generation
|
||||||
reranked_documents: Optional[List[Any]] = field(default=None)
|
reranked_documents: list[Any] | None = field(default=None)
|
||||||
|
|
||||||
# OUTPUT: Populated by agent nodes
|
# OUTPUT: Populated by agent nodes
|
||||||
# Using field to explicitly mark as part of state
|
# Using field to explicitly mark as part of state
|
||||||
final_written_report: Optional[str] = field(default=None)
|
final_written_report: str | None = field(default=None)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, List, Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
|
||||||
class SubSectionType(Enum):
|
class SubSectionType(Enum):
|
||||||
"""Enum defining the type of sub-section."""
|
"""Enum defining the type of sub-section."""
|
||||||
|
|
||||||
START = "START"
|
START = "START"
|
||||||
MIDDLE = "MIDDLE"
|
MIDDLE = "MIDDLE"
|
||||||
END = "END"
|
END = "END"
|
||||||
|
|
@ -22,17 +23,16 @@ class Configuration:
|
||||||
|
|
||||||
# Input parameters provided at invocation
|
# Input parameters provided at invocation
|
||||||
sub_section_title: str
|
sub_section_title: str
|
||||||
sub_section_questions: List[str]
|
sub_section_questions: list[str]
|
||||||
sub_section_type: SubSectionType
|
sub_section_type: SubSectionType
|
||||||
user_query: str
|
user_query: str
|
||||||
relevant_documents: List[Any] # Documents provided directly to the agent
|
relevant_documents: list[Any] # Documents provided directly to the agent
|
||||||
user_id: str
|
user_id: str
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
cls, config: Optional[RunnableConfig] = None
|
cls, config: RunnableConfig | None = None
|
||||||
) -> Configuration:
|
) -> Configuration:
|
||||||
"""Create a Configuration instance from a RunnableConfig object."""
|
"""Create a Configuration instance from a RunnableConfig object."""
|
||||||
configurable = (config.get("configurable") or {}) if config else {}
|
configurable = (config.get("configurable") or {}) if config else {}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
from .state import State
|
|
||||||
from .nodes import write_sub_section, rerank_documents
|
|
||||||
from .configuration import Configuration
|
from .configuration import Configuration
|
||||||
|
from .nodes import rerank_documents, write_sub_section
|
||||||
|
from .state import State
|
||||||
|
|
||||||
# Define a new graph
|
# Define a new graph
|
||||||
workflow = StateGraph(State, config_schema=Configuration)
|
workflow = StateGraph(State, config_schema=Configuration)
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,21 @@
|
||||||
from .configuration import Configuration
|
from typing import Any
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from .state import State
|
|
||||||
from typing import Any, Dict
|
|
||||||
from app.services.reranker_service import RerankerService
|
|
||||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from .configuration import SubSectionType
|
|
||||||
from ..utils import (
|
|
||||||
optimize_documents_for_token_limit,
|
|
||||||
calculate_token_count,
|
|
||||||
format_documents_section
|
|
||||||
)
|
|
||||||
|
|
||||||
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
from app.services.reranker_service import RerankerService
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
calculate_token_count,
|
||||||
|
format_documents_section,
|
||||||
|
optimize_documents_for_token_limit,
|
||||||
|
)
|
||||||
|
from .configuration import Configuration, SubSectionType
|
||||||
|
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||||
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
|
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Rerank the documents based on relevance to the sub-section title.
|
Rerank the documents based on relevance to the sub-section title.
|
||||||
|
|
||||||
|
|
@ -30,9 +33,7 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
|
|
||||||
# If no documents were provided, return empty list
|
# If no documents were provided, return empty list
|
||||||
if not documents or len(documents) == 0:
|
if not documents or len(documents) == 0:
|
||||||
return {
|
return {"reranked_documents": []}
|
||||||
"reranked_documents": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get reranker service from app config
|
# Get reranker service from app config
|
||||||
reranker_service = RerankerService.get_reranker_instance()
|
reranker_service = RerankerService.get_reranker_instance()
|
||||||
|
|
@ -46,7 +47,9 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
# rerank_query = "\n".join(sub_section_questions)
|
# rerank_query = "\n".join(sub_section_questions)
|
||||||
# rerank_query = configuration.user_query
|
# rerank_query = configuration.user_query
|
||||||
|
|
||||||
rerank_query = configuration.user_query + "\n" + "\n".join(sub_section_questions)
|
rerank_query = (
|
||||||
|
configuration.user_query + "\n" + "\n".join(sub_section_questions)
|
||||||
|
)
|
||||||
|
|
||||||
# Convert documents to format expected by reranker if needed
|
# Convert documents to format expected by reranker if needed
|
||||||
reranker_input_docs = [
|
reranker_input_docs = [
|
||||||
|
|
@ -57,28 +60,34 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
"document": {
|
"document": {
|
||||||
"id": doc.get("document", {}).get("id", ""),
|
"id": doc.get("document", {}).get("id", ""),
|
||||||
"title": doc.get("document", {}).get("title", ""),
|
"title": doc.get("document", {}).get("title", ""),
|
||||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
"document_type": doc.get("document", {}).get(
|
||||||
"metadata": doc.get("document", {}).get("metadata", {})
|
"document_type", ""
|
||||||
|
),
|
||||||
|
"metadata": doc.get("document", {}).get("metadata", {}),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
} for i, doc in enumerate(documents)
|
for i, doc in enumerate(documents)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Rerank documents using the section title
|
# Rerank documents using the section title
|
||||||
reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
|
reranked_docs = reranker_service.rerank_documents(
|
||||||
|
rerank_query, reranker_input_docs
|
||||||
|
)
|
||||||
|
|
||||||
# Sort by score in descending order
|
# Sort by score in descending order
|
||||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||||
|
|
||||||
print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}")
|
print(
|
||||||
|
f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during reranking: {str(e)}")
|
print(f"Error during reranking: {e!s}")
|
||||||
# Use original docs if reranking fails
|
# Use original docs if reranking fails
|
||||||
|
|
||||||
return {
|
return {"reranked_documents": reranked_docs}
|
||||||
"reranked_documents": reranked_docs
|
|
||||||
}
|
|
||||||
|
|
||||||
async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
|
async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Write the sub-section using the provided documents.
|
Write the sub-section using the provided documents.
|
||||||
|
|
||||||
|
|
@ -118,7 +127,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
||||||
section_position_context_map = {
|
section_position_context_map = {
|
||||||
SubSectionType.START: "This is the INTRODUCTION section.",
|
SubSectionType.START: "This is the INTRODUCTION section.",
|
||||||
SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.",
|
SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.",
|
||||||
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure."
|
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure.",
|
||||||
}
|
}
|
||||||
section_position_context = section_position_context_map.get(sub_section_type, "")
|
section_position_context = section_position_context_map.get(sub_section_type, "")
|
||||||
|
|
||||||
|
|
@ -152,14 +161,15 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
||||||
|
|
||||||
# Use initial system prompt for token calculation
|
# Use initial system prompt for token calculation
|
||||||
initial_system_prompt = get_citation_system_prompt()
|
initial_system_prompt = get_citation_system_prompt()
|
||||||
base_messages = state.chat_history + [
|
base_messages = [
|
||||||
|
*state.chat_history,
|
||||||
SystemMessage(content=initial_system_prompt),
|
SystemMessage(content=initial_system_prompt),
|
||||||
HumanMessage(content=base_human_message_template)
|
HumanMessage(content=base_human_message_template),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Optimize documents to fit within token limits
|
# Optimize documents to fit within token limits
|
||||||
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
|
optimized_documents, has_optimized_documents = (
|
||||||
documents, base_messages, llm.model
|
optimize_documents_for_token_limit(documents, base_messages, llm.model)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update state based on optimization result
|
# Update state based on optimization result
|
||||||
|
|
@ -169,16 +179,22 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
||||||
has_documents = False
|
has_documents = False
|
||||||
|
|
||||||
# Choose system prompt based on final document availability
|
# Choose system prompt based on final document availability
|
||||||
system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt()
|
system_prompt = (
|
||||||
|
get_citation_system_prompt()
|
||||||
|
if has_documents
|
||||||
|
else get_no_documents_system_prompt()
|
||||||
|
)
|
||||||
|
|
||||||
# Generate documents section
|
# Generate documents section
|
||||||
documents_text = format_documents_section(documents, "Source material") if has_documents else ""
|
documents_text = (
|
||||||
|
format_documents_section(documents, "Source material") if has_documents else ""
|
||||||
|
)
|
||||||
|
|
||||||
# Create final human message content
|
# Create final human message content
|
||||||
instruction_text = (
|
instruction_text = (
|
||||||
"Please write content for this sub-section using the provided source material and cite all information appropriately."
|
"Please write content for this sub-section using the provided source material and cite all information appropriately."
|
||||||
if has_documents else
|
if has_documents
|
||||||
"Please write content for this sub-section based on our conversation history and your general knowledge."
|
else "Please write content for this sub-section based on our conversation history and your general knowledge."
|
||||||
)
|
)
|
||||||
|
|
||||||
human_message_content = f"""
|
human_message_content = f"""
|
||||||
|
|
@ -206,9 +222,10 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create final messages for the LLM
|
# Create final messages for the LLM
|
||||||
messages_with_chat_history = state.chat_history + [
|
messages_with_chat_history = [
|
||||||
|
*state.chat_history,
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(content=human_message_content)
|
HumanMessage(content=human_message_content),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Log final token count
|
# Log final token count
|
||||||
|
|
@ -219,7 +236,4 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
||||||
response = await llm.ainvoke(messages_with_chat_history)
|
response = await llm.ainvoke(messages_with_chat_history)
|
||||||
final_answer = response.content
|
final_answer = response.content
|
||||||
|
|
||||||
return {
|
return {"final_answer": final_answer}
|
||||||
"final_answer": final_answer
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional, Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class State:
|
class State:
|
||||||
"""Defines the dynamic state for the agent during execution.
|
"""Defines the dynamic state for the agent during execution.
|
||||||
|
|
@ -14,11 +16,11 @@ class State:
|
||||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||||
for more information.
|
for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Runtime context
|
# Runtime context
|
||||||
db_session: AsyncSession
|
db_session: AsyncSession
|
||||||
|
|
||||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
chat_history: list[Any] | None = field(default_factory=list)
|
||||||
# OUTPUT: Populated by agent nodes
|
# OUTPUT: Populated by agent nodes
|
||||||
reranked_documents: Optional[List[Any]] = None
|
reranked_documents: list[Any] | None = None
|
||||||
final_answer: Optional[str] = None
|
final_answer: str | None = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,33 @@
|
||||||
from typing import List, Dict, Any, Tuple, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from litellm import get_model_info, token_counter
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from litellm import token_counter, get_model_info
|
|
||||||
|
|
||||||
class Section(BaseModel):
|
class Section(BaseModel):
|
||||||
"""A section in the answer outline."""
|
"""A section in the answer outline."""
|
||||||
|
|
||||||
section_id: int = Field(..., description="The zero-based index of the section")
|
section_id: int = Field(..., description="The zero-based index of the section")
|
||||||
section_title: str = Field(..., description="The title of the section")
|
section_title: str = Field(..., description="The title of the section")
|
||||||
questions: List[str] = Field(..., description="Questions to research for this section")
|
questions: list[str] = Field(
|
||||||
|
..., description="Questions to research for this section"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AnswerOutline(BaseModel):
|
class AnswerOutline(BaseModel):
|
||||||
"""The complete answer outline with all sections."""
|
"""The complete answer outline with all sections."""
|
||||||
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
|
||||||
|
answer_outline: list[Section] = Field(
|
||||||
|
..., description="List of sections in the answer outline"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DocumentTokenInfo(NamedTuple):
|
class DocumentTokenInfo(NamedTuple):
|
||||||
"""Information about a document and its token cost."""
|
"""Information about a document and its token cost."""
|
||||||
|
|
||||||
index: int
|
index: int
|
||||||
document: Dict[str, Any]
|
document: dict[str, Any]
|
||||||
formatted_content: str
|
formatted_content: str
|
||||||
token_count: int
|
token_count: int
|
||||||
|
|
||||||
|
|
@ -36,7 +46,7 @@ def get_connector_emoji(connector_name: str) -> str:
|
||||||
"JIRA_CONNECTOR": "🎫",
|
"JIRA_CONNECTOR": "🎫",
|
||||||
"DISCORD_CONNECTOR": "🗨️",
|
"DISCORD_CONNECTOR": "🗨️",
|
||||||
"TAVILY_API": "🔍",
|
"TAVILY_API": "🔍",
|
||||||
"LINKUP_API": "🔗"
|
"LINKUP_API": "🔗",
|
||||||
}
|
}
|
||||||
return connector_emojis.get(connector_name, "🔎")
|
return connector_emojis.get(connector_name, "🔎")
|
||||||
|
|
||||||
|
|
@ -55,31 +65,26 @@ def get_connector_friendly_name(connector_name: str) -> str:
|
||||||
"JIRA_CONNECTOR": "Jira",
|
"JIRA_CONNECTOR": "Jira",
|
||||||
"DISCORD_CONNECTOR": "Discord",
|
"DISCORD_CONNECTOR": "Discord",
|
||||||
"TAVILY_API": "Tavily Search",
|
"TAVILY_API": "Tavily Search",
|
||||||
"LINKUP_API": "Linkup Search"
|
"LINKUP_API": "Linkup Search",
|
||||||
}
|
}
|
||||||
return connector_friendly_names.get(connector_name, connector_name)
|
return connector_friendly_names.get(connector_name, connector_name)
|
||||||
|
|
||||||
|
|
||||||
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
def convert_langchain_messages_to_dict(
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
"""Convert LangChain messages to format expected by token_counter."""
|
"""Convert LangChain messages to format expected by token_counter."""
|
||||||
role_mapping = {
|
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
|
||||||
'system': 'system',
|
|
||||||
'human': 'user',
|
|
||||||
'ai': 'assistant'
|
|
||||||
}
|
|
||||||
|
|
||||||
converted_messages = []
|
converted_messages = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
role = role_mapping.get(getattr(msg, 'type', None), 'user')
|
role = role_mapping.get(getattr(msg, "type", None), "user")
|
||||||
converted_messages.append({
|
converted_messages.append({"role": role, "content": str(msg.content)})
|
||||||
"role": role,
|
|
||||||
"content": str(msg.content)
|
|
||||||
})
|
|
||||||
|
|
||||||
return converted_messages
|
return converted_messages
|
||||||
|
|
||||||
|
|
||||||
def format_document_for_citation(document: Dict[str, Any]) -> str:
|
def format_document_for_citation(document: dict[str, Any]) -> str:
|
||||||
"""Format a single document for citation in the standard XML format."""
|
"""Format a single document for citation in the standard XML format."""
|
||||||
content = document.get("content", "")
|
content = document.get("content", "")
|
||||||
doc_info = document.get("document", {})
|
doc_info = document.get("document", {})
|
||||||
|
|
@ -97,7 +102,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str:
|
||||||
</document>"""
|
</document>"""
|
||||||
|
|
||||||
|
|
||||||
def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
|
def format_documents_section(
|
||||||
|
documents: list[dict[str, Any]], section_title: str = "Source material"
|
||||||
|
) -> str:
|
||||||
"""Format multiple documents into a complete documents section."""
|
"""Format multiple documents into a complete documents section."""
|
||||||
if not documents:
|
if not documents:
|
||||||
return ""
|
return ""
|
||||||
|
|
@ -110,7 +117,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str
|
||||||
</documents>"""
|
</documents>"""
|
||||||
|
|
||||||
|
|
||||||
def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
|
def calculate_document_token_costs(
|
||||||
|
documents: list[dict[str, Any]], model: str
|
||||||
|
) -> list[DocumentTokenInfo]:
|
||||||
"""Pre-calculate token costs for each document."""
|
"""Pre-calculate token costs for each document."""
|
||||||
document_token_info = []
|
document_token_info = []
|
||||||
|
|
||||||
|
|
@ -119,24 +128,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
|
||||||
|
|
||||||
# Calculate token count for this document
|
# Calculate token count for this document
|
||||||
token_count = token_counter(
|
token_count = token_counter(
|
||||||
messages=[{"role": "user", "content": formatted_doc}],
|
messages=[{"role": "user", "content": formatted_doc}], model=model
|
||||||
model=model
|
|
||||||
)
|
)
|
||||||
|
|
||||||
document_token_info.append(DocumentTokenInfo(
|
document_token_info.append(
|
||||||
|
DocumentTokenInfo(
|
||||||
index=i,
|
index=i,
|
||||||
document=doc,
|
document=doc,
|
||||||
formatted_content=formatted_doc,
|
formatted_content=formatted_doc,
|
||||||
token_count=token_count
|
token_count=token_count,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return document_token_info
|
return document_token_info
|
||||||
|
|
||||||
|
|
||||||
def find_optimal_documents_with_binary_search(
|
def find_optimal_documents_with_binary_search(
|
||||||
document_tokens: List[DocumentTokenInfo],
|
document_tokens: list[DocumentTokenInfo], available_tokens: int
|
||||||
available_tokens: int
|
) -> list[DocumentTokenInfo]:
|
||||||
) -> List[DocumentTokenInfo]:
|
|
||||||
"""Use binary search to find the maximum number of documents that fit within token limit."""
|
"""Use binary search to find the maximum number of documents that fit within token limit."""
|
||||||
if not document_tokens or available_tokens <= 0:
|
if not document_tokens or available_tokens <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
@ -147,8 +156,7 @@ def find_optimal_documents_with_binary_search(
|
||||||
while left <= right:
|
while left <= right:
|
||||||
mid = (left + right) // 2
|
mid = (left + right) // 2
|
||||||
current_docs = document_tokens[:mid]
|
current_docs = document_tokens[:mid]
|
||||||
current_token_sum = sum(
|
current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
|
||||||
doc_info.token_count for doc_info in current_docs)
|
|
||||||
|
|
||||||
if current_token_sum <= available_tokens:
|
if current_token_sum <= available_tokens:
|
||||||
optimal_docs = current_docs
|
optimal_docs = current_docs
|
||||||
|
|
@ -163,20 +171,18 @@ def get_model_context_window(model_name: str) -> int:
|
||||||
"""Get the total context window size for a model (input + output tokens)."""
|
"""Get the total context window size for a model (input + output tokens)."""
|
||||||
try:
|
try:
|
||||||
model_info = get_model_info(model_name)
|
model_info = get_model_info(model_name)
|
||||||
context_window = model_info.get(
|
context_window = model_info.get("max_input_tokens", 4096) # Default fallback
|
||||||
'max_input_tokens', 4096) # Default fallback
|
|
||||||
return context_window
|
return context_window
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
|
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
|
||||||
|
)
|
||||||
return 4096 # Conservative fallback
|
return 4096 # Conservative fallback
|
||||||
|
|
||||||
|
|
||||||
def optimize_documents_for_token_limit(
|
def optimize_documents_for_token_limit(
|
||||||
documents: List[Dict[str, Any]],
|
documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
|
||||||
base_messages: List[BaseMessage],
|
) -> tuple[list[dict[str, Any]], bool]:
|
||||||
model_name: str
|
|
||||||
) -> Tuple[List[Dict[str, Any]], bool]:
|
|
||||||
"""
|
"""
|
||||||
Optimize documents to fit within token limits using binary search.
|
Optimize documents to fit within token limits using binary search.
|
||||||
|
|
||||||
|
|
@ -201,7 +207,8 @@ def optimize_documents_for_token_limit(
|
||||||
available_tokens_for_docs = context_window - base_tokens
|
available_tokens_for_docs = context_window - base_tokens
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
|
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
|
||||||
|
)
|
||||||
|
|
||||||
if available_tokens_for_docs <= 0:
|
if available_tokens_for_docs <= 0:
|
||||||
print("No tokens available for documents after base content and output buffer")
|
print("No tokens available for documents after base content and output buffer")
|
||||||
|
|
@ -212,8 +219,7 @@ def optimize_documents_for_token_limit(
|
||||||
|
|
||||||
# Find optimal number of documents using binary search
|
# Find optimal number of documents using binary search
|
||||||
optimal_doc_info = find_optimal_documents_with_binary_search(
|
optimal_doc_info = find_optimal_documents_with_binary_search(
|
||||||
document_token_info,
|
document_token_info, available_tokens_for_docs
|
||||||
available_tokens_for_docs
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract the original document objects
|
# Extract the original document objects
|
||||||
|
|
@ -221,12 +227,13 @@ def optimize_documents_for_token_limit(
|
||||||
has_documents_remaining = len(optimized_documents) > 0
|
has_documents_remaining = len(optimized_documents) > 0
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
|
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
|
||||||
|
)
|
||||||
|
|
||||||
return optimized_documents, has_documents_remaining
|
return optimized_documents, has_documents_remaining
|
||||||
|
|
||||||
|
|
||||||
def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
|
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
|
||||||
"""Calculate token count for a list of LangChain messages."""
|
"""Calculate token count for a list of LangChain messages."""
|
||||||
model = model_name
|
model = model_name
|
||||||
messages_dict = convert_langchain_messages_to_dict(messages)
|
messages_dict = convert_langchain_messages_to_dict(messages)
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,13 @@ from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import User, create_db_and_tables, get_async_session
|
|
||||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
|
||||||
|
|
||||||
|
|
||||||
from app.routes import router as crud_router
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.db import User, create_db_and_tables, get_async_session
|
||||||
from app.users import (
|
from app.routes import router as crud_router
|
||||||
SECRET,
|
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||||
auth_backend,
|
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||||
fastapi_users,
|
|
||||||
current_active_user
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|
@ -64,12 +55,10 @@ app.include_router(
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
from app.users import google_oauth_client
|
from app.users import google_oauth_client
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
fastapi_users.get_oauth_router(
|
fastapi_users.get_oauth_router(
|
||||||
google_oauth_client,
|
google_oauth_client, auth_backend, SECRET, is_verified_by_default=True
|
||||||
auth_backend,
|
|
||||||
SECRET,
|
|
||||||
is_verified_by_default=True
|
|
||||||
),
|
),
|
||||||
prefix="/auth/google",
|
prefix="/auth/google",
|
||||||
tags=["auth"],
|
tags=["auth"],
|
||||||
|
|
@ -79,5 +68,8 @@ app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
||||||
|
|
||||||
|
|
||||||
@app.get("/verify-token")
|
@app.get("/verify-token")
|
||||||
async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
|
async def authenticated_route(
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
return {"message": "Token is valid"}
|
return {"message": "Token is valid"}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from rerankers import Reranker
|
from rerankers import Reranker
|
||||||
|
|
||||||
|
|
||||||
# Get the base directory of the project
|
# Get the base directory of the project
|
||||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||||
|
|
||||||
|
|
@ -25,30 +23,30 @@ def is_ffmpeg_installed():
|
||||||
return shutil.which("ffmpeg") is not None
|
return shutil.which("ffmpeg") is not None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
# Check if ffmpeg is installed
|
# Check if ffmpeg is installed
|
||||||
if not is_ffmpeg_installed():
|
if not is_ffmpeg_installed():
|
||||||
import static_ffmpeg
|
import static_ffmpeg
|
||||||
|
|
||||||
# ffmpeg installed on first call to add_paths(), threadsafe.
|
# ffmpeg installed on first call to add_paths(), threadsafe.
|
||||||
static_ffmpeg.add_paths()
|
static_ffmpeg.add_paths()
|
||||||
# check if ffmpeg is installed again
|
# check if ffmpeg is installed again
|
||||||
if not is_ffmpeg_installed():
|
if not is_ffmpeg_installed():
|
||||||
raise ValueError("FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster.")
|
raise ValueError(
|
||||||
|
"FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster."
|
||||||
|
)
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||||
|
|
||||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||||
|
|
||||||
|
|
||||||
# AUTH: Google OAuth
|
# AUTH: Google OAuth
|
||||||
AUTH_TYPE = os.getenv("AUTH_TYPE")
|
AUTH_TYPE = os.getenv("AUTH_TYPE")
|
||||||
if AUTH_TYPE == "GOOGLE":
|
if AUTH_TYPE == "GOOGLE":
|
||||||
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||||
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||||
|
|
||||||
|
|
||||||
# LLM instances are now managed per-user through the LLMConfig system
|
# LLM instances are now managed per-user through the LLMConfig system
|
||||||
# Legacy environment variables removed in favor of user-specific configurations
|
# Legacy environment variables removed in favor of user-specific configurations
|
||||||
|
|
||||||
|
|
@ -56,10 +54,10 @@ class Config:
|
||||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||||
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
||||||
chunker_instance = RecursiveChunker(
|
chunker_instance = RecursiveChunker(
|
||||||
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
|
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||||
)
|
)
|
||||||
code_chunker_instance = CodeChunker(
|
code_chunker_instance = CodeChunker(
|
||||||
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
|
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
# Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
||||||
|
|
@ -97,17 +95,18 @@ class Config:
|
||||||
STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
|
STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
|
||||||
STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")
|
STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
# Validation Checks
|
# Validation Checks
|
||||||
# Check embedding dimension
|
# Check embedding dimension
|
||||||
if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
|
if (
|
||||||
|
hasattr(embedding_model_instance, "dimension")
|
||||||
|
and embedding_model_instance.dimension > 2000
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
|
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
|
||||||
f"has {embedding_model_instance.dimension} dimensions, which "
|
f"has {embedding_model_instance.dimension} dimensions, which "
|
||||||
f"exceeds the maximum of 2000 allowed by PGVector."
|
f"exceeds the maximum of 2000 allowed by PGVector."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_settings(cls):
|
def get_settings(cls):
|
||||||
"""Get all settings as a dictionary."""
|
"""Get all settings as a dictionary."""
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,25 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def _parse_bool(value):
|
def _parse_bool(value):
|
||||||
"""Parse boolean value from string."""
|
"""Parse boolean value from string."""
|
||||||
return value.lower() == "true" if value else False
|
return value.lower() == "true" if value else False
|
||||||
|
|
||||||
|
|
||||||
def _parse_int(value, var_name):
|
def _parse_int(value, var_name):
|
||||||
"""Parse integer value with error handling."""
|
"""Parse integer value with error handling."""
|
||||||
try:
|
try:
|
||||||
return int(value)
|
return int(value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"Invalid integer value for {var_name}: {value}")
|
raise ValueError(f"Invalid integer value for {var_name}: {value}") from None
|
||||||
|
|
||||||
|
|
||||||
def _parse_headers(value):
|
def _parse_headers(value):
|
||||||
"""Parse headers from comma-separated string."""
|
"""Parse headers from comma-separated string."""
|
||||||
try:
|
try:
|
||||||
return [
|
return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h]
|
||||||
tuple(h.split(":", 1))
|
|
||||||
for h in value.split(",")
|
|
||||||
if ":" in h
|
|
||||||
]
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(f"Invalid headers format: {value}")
|
raise ValueError(f"Invalid headers format: {value}") from None
|
||||||
|
|
||||||
|
|
||||||
def load_uvicorn_config(args=None):
|
def load_uvicorn_config(args=None):
|
||||||
|
|
@ -28,14 +27,14 @@ def load_uvicorn_config(args=None):
|
||||||
Load Uvicorn configuration from environment variables and CLI args.
|
Load Uvicorn configuration from environment variables and CLI args.
|
||||||
Returns a dict suitable for passing to uvicorn.Config.
|
Returns a dict suitable for passing to uvicorn.Config.
|
||||||
"""
|
"""
|
||||||
config_kwargs = dict(
|
config_kwargs = {
|
||||||
app="app.app:app",
|
"app": "app.app:app",
|
||||||
host=os.getenv("UVICORN_HOST", "0.0.0.0"),
|
"host": os.getenv("UVICORN_HOST", "0.0.0.0"),
|
||||||
port=int(os.getenv("UVICORN_PORT", 8000)),
|
"port": int(os.getenv("UVICORN_PORT", 8000)),
|
||||||
log_level=os.getenv("UVICORN_LOG_LEVEL", "info"),
|
"log_level": os.getenv("UVICORN_LOG_LEVEL", "info"),
|
||||||
reload=args.reload if args else False,
|
"reload": args.reload if args else False,
|
||||||
reload_dirs=["app"] if (args and args.reload) else None,
|
"reload_dirs": ["app"] if (args and args.reload) else None,
|
||||||
)
|
}
|
||||||
|
|
||||||
# Configuration mapping for advanced options
|
# Configuration mapping for advanced options
|
||||||
config_mapping = {
|
config_mapping = {
|
||||||
|
|
@ -51,15 +50,33 @@ def load_uvicorn_config(args=None):
|
||||||
"UVICORN_LOG_CONFIG": ("log_config", str),
|
"UVICORN_LOG_CONFIG": ("log_config", str),
|
||||||
"UVICORN_SERVER_HEADER": ("server_header", _parse_bool),
|
"UVICORN_SERVER_HEADER": ("server_header", _parse_bool),
|
||||||
"UVICORN_DATE_HEADER": ("date_header", _parse_bool),
|
"UVICORN_DATE_HEADER": ("date_header", _parse_bool),
|
||||||
"UVICORN_LIMIT_CONCURRENCY": ("limit_concurrency", lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY")),
|
"UVICORN_LIMIT_CONCURRENCY": (
|
||||||
"UVICORN_LIMIT_MAX_REQUESTS": ("limit_max_requests", lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS")),
|
"limit_concurrency",
|
||||||
"UVICORN_TIMEOUT_KEEP_ALIVE": ("timeout_keep_alive", lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE")),
|
lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY"),
|
||||||
"UVICORN_TIMEOUT_NOTIFY": ("timeout_notify", lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY")),
|
),
|
||||||
|
"UVICORN_LIMIT_MAX_REQUESTS": (
|
||||||
|
"limit_max_requests",
|
||||||
|
lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS"),
|
||||||
|
),
|
||||||
|
"UVICORN_TIMEOUT_KEEP_ALIVE": (
|
||||||
|
"timeout_keep_alive",
|
||||||
|
lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE"),
|
||||||
|
),
|
||||||
|
"UVICORN_TIMEOUT_NOTIFY": (
|
||||||
|
"timeout_notify",
|
||||||
|
lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY"),
|
||||||
|
),
|
||||||
"UVICORN_SSL_KEYFILE": ("ssl_keyfile", str),
|
"UVICORN_SSL_KEYFILE": ("ssl_keyfile", str),
|
||||||
"UVICORN_SSL_CERTFILE": ("ssl_certfile", str),
|
"UVICORN_SSL_CERTFILE": ("ssl_certfile", str),
|
||||||
"UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str),
|
"UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str),
|
||||||
"UVICORN_SSL_VERSION": ("ssl_version", lambda x: _parse_int(x, "UVICORN_SSL_VERSION")),
|
"UVICORN_SSL_VERSION": (
|
||||||
"UVICORN_SSL_CERT_REQS": ("ssl_cert_reqs", lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS")),
|
"ssl_version",
|
||||||
|
lambda x: _parse_int(x, "UVICORN_SSL_VERSION"),
|
||||||
|
),
|
||||||
|
"UVICORN_SSL_CERT_REQS": (
|
||||||
|
"ssl_cert_reqs",
|
||||||
|
lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS"),
|
||||||
|
),
|
||||||
"UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str),
|
"UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str),
|
||||||
"UVICORN_SSL_CIPHERS": ("ssl_ciphers", str),
|
"UVICORN_SSL_CIPHERS": ("ssl_ciphers", str),
|
||||||
"UVICORN_HEADERS": ("headers", _parse_headers),
|
"UVICORN_HEADERS": ("headers", _parse_headers),
|
||||||
|
|
@ -76,7 +93,6 @@ def load_uvicorn_config(args=None):
|
||||||
try:
|
try:
|
||||||
config_kwargs[config_key] = parser(value)
|
config_kwargs[config_key] = parser(value)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(f"Configuration error for {env_var}: {e}")
|
raise ValueError(f"Configuration error for {env_var}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
return config_kwargs
|
return config_kwargs
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,12 @@ A module for interacting with Discord's HTTP API to retrieve guilds, channels, a
|
||||||
Requires a Discord bot token.
|
Requires a Discord bot token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import datetime
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
|
||||||
class DiscordConnector(commands.Bot):
|
class DiscordConnector(commands.Bot):
|
||||||
"""Class for retrieving guild, channel, and message history from Discord."""
|
"""Class for retrieving guild, channel, and message history from Discord."""
|
||||||
|
|
||||||
def __init__(self, token: str = None):
|
def __init__(self, token: str | None = None):
|
||||||
"""
|
"""
|
||||||
Initialize the DiscordConnector with a bot token.
|
Initialize the DiscordConnector with a bot token.
|
||||||
|
|
||||||
|
|
@ -30,7 +31,9 @@ class DiscordConnector(commands.Bot):
|
||||||
intents.messages = True # Required to fetch messages
|
intents.messages = True # Required to fetch messages
|
||||||
intents.message_content = True # Required to read message content
|
intents.message_content = True # Required to read message content
|
||||||
intents.members = True # Required to fetch member information
|
intents.members = True # Required to fetch member information
|
||||||
super().__init__(command_prefix="!", intents=intents) # command_prefix is required but not strictly used here
|
super().__init__(
|
||||||
|
command_prefix="!", intents=intents
|
||||||
|
) # command_prefix is required but not strictly used here
|
||||||
self.token = token
|
self.token = token
|
||||||
self._bot_task = None # Holds the async bot task
|
self._bot_task = None # Holds the async bot task
|
||||||
self._is_running = False # Flag to track if the bot is running
|
self._is_running = False # Flag to track if the bot is running
|
||||||
|
|
@ -63,17 +66,23 @@ class DiscordConnector(commands.Bot):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self._is_running:
|
if self._is_running:
|
||||||
logger.warning("Bot is already running. Use close_bot() to stop it before starting again.")
|
logger.warning(
|
||||||
|
"Bot is already running. Use close_bot() to stop it before starting again."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.start(self.token)
|
await self.start(self.token)
|
||||||
logger.info("Discord bot started successfully.")
|
logger.info("Discord bot started successfully.")
|
||||||
except discord.LoginFailure:
|
except discord.LoginFailure:
|
||||||
logger.error("Failed to log in: Invalid token was provided. Please check your bot token.")
|
logger.error(
|
||||||
|
"Failed to log in: Invalid token was provided. Please check your bot token."
|
||||||
|
)
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
raise
|
raise
|
||||||
except discord.PrivilegedIntentsRequired as e:
|
except discord.PrivilegedIntentsRequired as e:
|
||||||
logger.error(f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page.")
|
logger.error(
|
||||||
|
f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page."
|
||||||
|
)
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
raise
|
raise
|
||||||
except discord.ConnectionClosed as e:
|
except discord.ConnectionClosed as e:
|
||||||
|
|
@ -96,7 +105,6 @@ class DiscordConnector(commands.Bot):
|
||||||
else:
|
else:
|
||||||
logger.info("Bot is not running or already disconnected.")
|
logger.info("Bot is not running or already disconnected.")
|
||||||
|
|
||||||
|
|
||||||
def set_token(self, token: str) -> None:
|
def set_token(self, token: str) -> None:
|
||||||
"""
|
"""
|
||||||
Set the discord bot token.
|
Set the discord bot token.
|
||||||
|
|
@ -106,7 +114,9 @@ class DiscordConnector(commands.Bot):
|
||||||
"""
|
"""
|
||||||
logger.info("Setting Discord bot token.")
|
logger.info("Setting Discord bot token.")
|
||||||
self.token = token
|
self.token = token
|
||||||
logger.info("Token set successfully. You can now start the bot with start_bot().")
|
logger.info(
|
||||||
|
"Token set successfully. You can now start the bot with start_bot()."
|
||||||
|
)
|
||||||
|
|
||||||
async def _wait_until_ready(self):
|
async def _wait_until_ready(self):
|
||||||
"""Helper to wait until the bot is connected and ready."""
|
"""Helper to wait until the bot is connected and ready."""
|
||||||
|
|
@ -120,11 +130,15 @@ class DiscordConnector(commands.Bot):
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.wait_until_ready(), timeout=60.0)
|
await asyncio.wait_for(self.wait_until_ready(), timeout=60.0)
|
||||||
logger.info("Bot is ready.")
|
logger.info("Bot is ready.")
|
||||||
except asyncio.TimeoutError:
|
except TimeoutError:
|
||||||
logger.error(f"Bot did not become ready within 60 seconds. Connection may have failed.")
|
logger.error(
|
||||||
|
"Bot did not become ready within 60 seconds. Connection may have failed."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An unexpected error occurred while waiting for the bot to be ready: {e}")
|
logger.error(
|
||||||
|
f"An unexpected error occurred while waiting for the bot to be ready: {e}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_guilds(self) -> list[dict]:
|
async def get_guilds(self) -> list[dict]:
|
||||||
|
|
@ -143,7 +157,9 @@ class DiscordConnector(commands.Bot):
|
||||||
|
|
||||||
guilds_data = []
|
guilds_data = []
|
||||||
for guild in self.guilds:
|
for guild in self.guilds:
|
||||||
member_count = guild.member_count if guild.member_count is not None else "N/A"
|
member_count = (
|
||||||
|
guild.member_count if guild.member_count is not None else "N/A"
|
||||||
|
)
|
||||||
guilds_data.append(
|
guilds_data.append(
|
||||||
{
|
{
|
||||||
"id": str(guild.id),
|
"id": str(guild.id),
|
||||||
|
|
@ -184,14 +200,16 @@ class DiscordConnector(commands.Bot):
|
||||||
{"id": str(channel.id), "name": channel.name, "type": "text"}
|
{"id": str(channel.id), "name": channel.name, "type": "text"}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Fetched {len(channels_data)} text channels from guild {guild_id}.")
|
logger.info(
|
||||||
|
f"Fetched {len(channels_data)} text channels from guild {guild_id}."
|
||||||
|
)
|
||||||
return channels_data
|
return channels_data
|
||||||
|
|
||||||
async def get_channel_history(
|
async def get_channel_history(
|
||||||
self,
|
self,
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Fetch message history from a text channel.
|
Fetch message history from a text channel.
|
||||||
|
|
@ -227,20 +245,26 @@ class DiscordConnector(commands.Bot):
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
try:
|
try:
|
||||||
start_datetime = datetime.datetime.fromisoformat(start_date).replace(tzinfo=datetime.timezone.utc)
|
start_datetime = datetime.datetime.fromisoformat(start_date).replace(
|
||||||
|
tzinfo=datetime.UTC
|
||||||
|
)
|
||||||
after = start_datetime
|
after = start_datetime
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f"Invalid start_date format: {start_date}. Ignoring.")
|
logger.warning(f"Invalid start_date format: {start_date}. Ignoring.")
|
||||||
|
|
||||||
if end_date:
|
if end_date:
|
||||||
try:
|
try:
|
||||||
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(tzinfo=datetime.timezone.utc)
|
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(
|
||||||
|
tzinfo=datetime.UTC
|
||||||
|
)
|
||||||
before = end_datetime
|
before = end_datetime
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f"Invalid end_date format: {end_date}. Ignoring.")
|
logger.warning(f"Invalid end_date format: {end_date}. Ignoring.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for message in channel.history(limit=None, before=before, after=after):
|
async for message in channel.history(
|
||||||
|
limit=None, before=before, after=after
|
||||||
|
):
|
||||||
messages_data.append(
|
messages_data.append(
|
||||||
{
|
{
|
||||||
"id": str(message.id),
|
"id": str(message.id),
|
||||||
|
|
@ -251,7 +275,9 @@ class DiscordConnector(commands.Bot):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except discord.Forbidden:
|
except discord.Forbidden:
|
||||||
logger.error(f"Bot does not have permissions to read message history in channel {channel_id}.")
|
logger.error(
|
||||||
|
f"Bot does not have permissions to read message history in channel {channel_id}."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
except discord.HTTPException as e:
|
except discord.HTTPException as e:
|
||||||
logger.error(f"Failed to fetch messages from channel {channel_id}: {e}")
|
logger.error(f"Failed to fetch messages from channel {channel_id}: {e}")
|
||||||
|
|
@ -278,7 +304,9 @@ class DiscordConnector(commands.Bot):
|
||||||
permissions to view members.
|
permissions to view members.
|
||||||
"""
|
"""
|
||||||
await self._wait_until_ready()
|
await self._wait_until_ready()
|
||||||
logger.info(f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}")
|
logger.info(
|
||||||
|
f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}"
|
||||||
|
)
|
||||||
|
|
||||||
guild = self.get_guild(int(guild_id))
|
guild = self.get_guild(int(guild_id))
|
||||||
if not guild:
|
if not guild:
|
||||||
|
|
@ -294,7 +322,9 @@ class DiscordConnector(commands.Bot):
|
||||||
return {
|
return {
|
||||||
"id": str(member.id),
|
"id": str(member.id),
|
||||||
"name": member.name,
|
"name": member.name,
|
||||||
"joined_at": member.joined_at.isoformat() if member.joined_at else None,
|
"joined_at": member.joined_at.isoformat()
|
||||||
|
if member.joined_at
|
||||||
|
else None,
|
||||||
"roles": roles,
|
"roles": roles,
|
||||||
}
|
}
|
||||||
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
||||||
|
|
@ -303,8 +333,12 @@ class DiscordConnector(commands.Bot):
|
||||||
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
||||||
return None
|
return None
|
||||||
except discord.Forbidden:
|
except discord.Forbidden:
|
||||||
logger.error(f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled.")
|
logger.error(
|
||||||
|
f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled."
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
except discord.HTTPException as e:
|
except discord.HTTPException as e:
|
||||||
logger.error(f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}")
|
logger.error(
|
||||||
|
f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,54 +1,91 @@
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import Any
|
||||||
from github3 import login as github_login, exceptions as github_exceptions
|
|
||||||
from github3.repos.contents import Contents
|
from github3 import exceptions as github_exceptions, login as github_login
|
||||||
from github3.exceptions import ForbiddenError, NotFoundError
|
from github3.exceptions import ForbiddenError, NotFoundError
|
||||||
|
from github3.repos.contents import Contents
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# List of common code file extensions to target
|
# List of common code file extensions to target
|
||||||
CODE_EXTENSIONS = {
|
CODE_EXTENSIONS = {
|
||||||
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
|
".py",
|
||||||
'.cs', '.go', '.rb', '.php', '.swift', '.kt', '.scala', '.rs', '.m',
|
".js",
|
||||||
'.sh', '.bash', '.ps1', '.lua', '.pl', '.pm', '.r', '.dart', '.sql'
|
".jsx",
|
||||||
|
".ts",
|
||||||
|
".tsx",
|
||||||
|
".java",
|
||||||
|
".c",
|
||||||
|
".cpp",
|
||||||
|
".h",
|
||||||
|
".hpp",
|
||||||
|
".cs",
|
||||||
|
".go",
|
||||||
|
".rb",
|
||||||
|
".php",
|
||||||
|
".swift",
|
||||||
|
".kt",
|
||||||
|
".scala",
|
||||||
|
".rs",
|
||||||
|
".m",
|
||||||
|
".sh",
|
||||||
|
".bash",
|
||||||
|
".ps1",
|
||||||
|
".lua",
|
||||||
|
".pl",
|
||||||
|
".pm",
|
||||||
|
".r",
|
||||||
|
".dart",
|
||||||
|
".sql",
|
||||||
}
|
}
|
||||||
|
|
||||||
# List of common documentation/text file extensions
|
# List of common documentation/text file extensions
|
||||||
DOC_EXTENSIONS = {
|
DOC_EXTENSIONS = {
|
||||||
'.md', '.txt', '.rst', '.adoc', '.html', '.htm', '.xml', '.json', '.yaml', '.yml', '.toml'
|
".md",
|
||||||
|
".txt",
|
||||||
|
".rst",
|
||||||
|
".adoc",
|
||||||
|
".html",
|
||||||
|
".htm",
|
||||||
|
".xml",
|
||||||
|
".json",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".toml",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Maximum file size in bytes (e.g., 1MB)
|
# Maximum file size in bytes (e.g., 1MB)
|
||||||
MAX_FILE_SIZE = 1 * 1024 * 1024
|
MAX_FILE_SIZE = 1 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
class GitHubConnector:
|
class GitHubConnector:
|
||||||
"""Connector for interacting with the GitHub API."""
|
"""Connector for interacting with the GitHub API."""
|
||||||
|
|
||||||
# Directories to skip during file traversal
|
# Directories to skip during file traversal
|
||||||
SKIPPED_DIRS = {
|
SKIPPED_DIRS = {
|
||||||
# Version control
|
# Version control
|
||||||
'.git',
|
".git",
|
||||||
# Dependencies
|
# Dependencies
|
||||||
'node_modules',
|
"node_modules",
|
||||||
'vendor',
|
"vendor",
|
||||||
# Build artifacts / Caches
|
# Build artifacts / Caches
|
||||||
'build',
|
"build",
|
||||||
'dist',
|
"dist",
|
||||||
'target',
|
"target",
|
||||||
'__pycache__',
|
"__pycache__",
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
'venv',
|
"venv",
|
||||||
'.venv',
|
".venv",
|
||||||
'env',
|
"env",
|
||||||
# IDE/Editor config
|
# IDE/Editor config
|
||||||
'.vscode',
|
".vscode",
|
||||||
'.idea',
|
".idea",
|
||||||
'.project',
|
".project",
|
||||||
'.settings',
|
".settings",
|
||||||
# Temporary / Logs
|
# Temporary / Logs
|
||||||
'tmp',
|
"tmp",
|
||||||
'logs',
|
"logs",
|
||||||
# Add other project-specific irrelevant directories if needed
|
# Add other project-specific irrelevant directories if needed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -68,20 +105,21 @@ class GitHubConnector:
|
||||||
logger.info("Successfully authenticated with GitHub API.")
|
logger.info("Successfully authenticated with GitHub API.")
|
||||||
except (github_exceptions.AuthenticationFailed, ForbiddenError) as e:
|
except (github_exceptions.AuthenticationFailed, ForbiddenError) as e:
|
||||||
logger.error(f"GitHub authentication failed: {e}")
|
logger.error(f"GitHub authentication failed: {e}")
|
||||||
raise ValueError("Invalid GitHub token or insufficient permissions.")
|
raise ValueError("Invalid GitHub token or insufficient permissions.") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize GitHub client: {e}")
|
logger.error(f"Failed to initialize GitHub client: {e}")
|
||||||
raise
|
raise e
|
||||||
|
|
||||||
def get_user_repositories(self) -> List[Dict[str, Any]]:
|
def get_user_repositories(self) -> list[dict[str, Any]]:
|
||||||
"""Fetches repositories accessible by the authenticated user."""
|
"""Fetches repositories accessible by the authenticated user."""
|
||||||
repos_data = []
|
repos_data = []
|
||||||
try:
|
try:
|
||||||
# type='owner' fetches repos owned by the user
|
# type='owner' fetches repos owned by the user
|
||||||
# type='member' fetches repos the user is a collaborator on (including orgs)
|
# type='member' fetches repos the user is a collaborator on (including orgs)
|
||||||
# type='all' fetches both
|
# type='all' fetches both
|
||||||
for repo in self.gh.repositories(type='all', sort='updated'):
|
for repo in self.gh.repositories(type="all", sort="updated"):
|
||||||
repos_data.append({
|
repos_data.append(
|
||||||
|
{
|
||||||
"id": repo.id,
|
"id": repo.id,
|
||||||
"name": repo.name,
|
"name": repo.name,
|
||||||
"full_name": repo.full_name,
|
"full_name": repo.full_name,
|
||||||
|
|
@ -89,14 +127,17 @@ class GitHubConnector:
|
||||||
"url": repo.html_url,
|
"url": repo.html_url,
|
||||||
"description": repo.description or "",
|
"description": repo.description or "",
|
||||||
"last_updated": repo.updated_at if repo.updated_at else None,
|
"last_updated": repo.updated_at if repo.updated_at else None,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
logger.info(f"Fetched {len(repos_data)} repositories.")
|
logger.info(f"Fetched {len(repos_data)} repositories.")
|
||||||
return repos_data
|
return repos_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to fetch GitHub repositories: {e}")
|
logger.error(f"Failed to fetch GitHub repositories: {e}")
|
||||||
return [] # Return empty list on error
|
return [] # Return empty list on error
|
||||||
|
|
||||||
def get_repository_files(self, repo_full_name: str, path: str = '') -> List[Dict[str, Any]]:
|
def get_repository_files(
|
||||||
|
self, repo_full_name: str, path: str = ""
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Recursively fetches details of relevant files (code, docs) within a repository path.
|
Recursively fetches details of relevant files (code, docs) within a repository path.
|
||||||
|
|
||||||
|
|
@ -110,54 +151,72 @@ class GitHubConnector:
|
||||||
"""
|
"""
|
||||||
files_list = []
|
files_list = []
|
||||||
try:
|
try:
|
||||||
owner, repo_name = repo_full_name.split('/')
|
owner, repo_name = repo_full_name.split("/")
|
||||||
repo = self.gh.repository(owner, repo_name)
|
repo = self.gh.repository(owner, repo_name)
|
||||||
if not repo:
|
if not repo:
|
||||||
logger.warning(f"Repository '{repo_full_name}' not found.")
|
logger.warning(f"Repository '{repo_full_name}' not found.")
|
||||||
return []
|
return []
|
||||||
contents = repo.directory_contents(directory_path=path) # Use directory_contents for clarity
|
contents = repo.directory_contents(
|
||||||
|
directory_path=path
|
||||||
|
) # Use directory_contents for clarity
|
||||||
|
|
||||||
# contents returns a list of tuples (name, content_obj)
|
# contents returns a list of tuples (name, content_obj)
|
||||||
for item_name, content_item in contents:
|
for _item_name, content_item in contents:
|
||||||
if not isinstance(content_item, Contents):
|
if not isinstance(content_item, Contents):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if content_item.type == 'dir':
|
if content_item.type == "dir":
|
||||||
# Check if the directory name is in the skipped list
|
# Check if the directory name is in the skipped list
|
||||||
if content_item.name in self.SKIPPED_DIRS:
|
if content_item.name in self.SKIPPED_DIRS:
|
||||||
logger.debug(f"Skipping directory: {content_item.path}")
|
logger.debug(f"Skipping directory: {content_item.path}")
|
||||||
continue # Skip recursion for this directory
|
continue # Skip recursion for this directory
|
||||||
|
|
||||||
# Recursively fetch contents of subdirectory
|
# Recursively fetch contents of subdirectory
|
||||||
files_list.extend(self.get_repository_files(repo_full_name, path=content_item.path))
|
files_list.extend(
|
||||||
elif content_item.type == 'file':
|
self.get_repository_files(
|
||||||
|
repo_full_name, path=content_item.path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif content_item.type == "file":
|
||||||
# Check if the file extension is relevant and size is within limits
|
# Check if the file extension is relevant and size is within limits
|
||||||
file_extension = '.' + content_item.name.split('.')[-1].lower() if '.' in content_item.name else ''
|
file_extension = (
|
||||||
|
"." + content_item.name.split(".")[-1].lower()
|
||||||
|
if "." in content_item.name
|
||||||
|
else ""
|
||||||
|
)
|
||||||
is_code = file_extension in CODE_EXTENSIONS
|
is_code = file_extension in CODE_EXTENSIONS
|
||||||
is_doc = file_extension in DOC_EXTENSIONS
|
is_doc = file_extension in DOC_EXTENSIONS
|
||||||
|
|
||||||
if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE:
|
if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE:
|
||||||
files_list.append({
|
files_list.append(
|
||||||
|
{
|
||||||
"path": content_item.path,
|
"path": content_item.path,
|
||||||
"sha": content_item.sha,
|
"sha": content_item.sha,
|
||||||
"url": content_item.html_url,
|
"url": content_item.html_url,
|
||||||
"size": content_item.size,
|
"size": content_item.size,
|
||||||
"type": "code" if is_code else "doc"
|
"type": "code" if is_code else "doc",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
elif content_item.size > MAX_FILE_SIZE:
|
elif content_item.size > MAX_FILE_SIZE:
|
||||||
logger.debug(f"Skipping large file: {content_item.path} ({content_item.size} bytes)")
|
logger.debug(
|
||||||
|
f"Skipping large file: {content_item.path} ({content_item.size} bytes)"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Skipping irrelevant file type: {content_item.path}")
|
logger.debug(
|
||||||
|
f"Skipping irrelevant file type: {content_item.path}"
|
||||||
|
)
|
||||||
|
|
||||||
except (NotFoundError, ForbiddenError) as e:
|
except (NotFoundError, ForbiddenError) as e:
|
||||||
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
|
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get files for {repo_full_name} at path '{path}': {e}")
|
logger.error(
|
||||||
|
f"Failed to get files for {repo_full_name} at path '{path}': {e}"
|
||||||
|
)
|
||||||
# Return what we have collected so far in case of partial failure
|
# Return what we have collected so far in case of partial failure
|
||||||
|
|
||||||
return files_list
|
return files_list
|
||||||
|
|
||||||
def get_file_content(self, repo_full_name: str, file_path: str) -> Optional[str]:
|
def get_file_content(self, repo_full_name: str, file_path: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Fetches the decoded content of a specific file.
|
Fetches the decoded content of a specific file.
|
||||||
|
|
||||||
|
|
@ -169,43 +228,69 @@ class GitHubConnector:
|
||||||
The decoded file content as a string, or None if fetching fails or file is too large.
|
The decoded file content as a string, or None if fetching fails or file is too large.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
owner, repo_name = repo_full_name.split('/')
|
owner, repo_name = repo_full_name.split("/")
|
||||||
repo = self.gh.repository(owner, repo_name)
|
repo = self.gh.repository(owner, repo_name)
|
||||||
if not repo:
|
if not repo:
|
||||||
logger.warning(f"Repository '{repo_full_name}' not found when fetching file '{file_path}'.")
|
logger.warning(
|
||||||
|
f"Repository '{repo_full_name}' not found when fetching file '{file_path}'."
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
content_item = repo.file_contents(path=file_path) # Use file_contents for clarity
|
content_item = repo.file_contents(
|
||||||
|
path=file_path
|
||||||
|
) # Use file_contents for clarity
|
||||||
|
|
||||||
if not content_item or not isinstance(content_item, Contents) or content_item.type != 'file':
|
if (
|
||||||
logger.warning(f"File '{file_path}' not found or is not a file in '{repo_full_name}'.")
|
not content_item
|
||||||
|
or not isinstance(content_item, Contents)
|
||||||
|
or content_item.type != "file"
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"File '{file_path}' not found or is not a file in '{repo_full_name}'."
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if content_item.size > MAX_FILE_SIZE:
|
if content_item.size > MAX_FILE_SIZE:
|
||||||
logger.warning(f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch.")
|
logger.warning(
|
||||||
|
f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch."
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Content is base64 encoded
|
# Content is base64 encoded
|
||||||
if content_item.content:
|
if content_item.content:
|
||||||
try:
|
try:
|
||||||
decoded_content = base64.b64decode(content_item.content).decode('utf-8')
|
decoded_content = base64.b64decode(content_item.content).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
return decoded_content
|
return decoded_content
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
logger.warning(f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'.")
|
logger.warning(
|
||||||
|
f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'."
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# Try a fallback encoding
|
# Try a fallback encoding
|
||||||
decoded_content = base64.b64decode(content_item.content).decode('latin-1')
|
decoded_content = base64.b64decode(content_item.content).decode(
|
||||||
|
"latin-1"
|
||||||
|
)
|
||||||
return decoded_content
|
return decoded_content
|
||||||
except Exception as decode_err:
|
except Exception as decode_err:
|
||||||
logger.error(f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}")
|
logger.error(
|
||||||
|
f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}"
|
||||||
|
)
|
||||||
return None # Give up if fallback fails
|
return None # Give up if fallback fails
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty.")
|
logger.warning(
|
||||||
|
f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty."
|
||||||
|
)
|
||||||
return "" # Return empty string for empty files
|
return "" # Return empty string for empty files
|
||||||
|
|
||||||
except (NotFoundError, ForbiddenError) as e:
|
except (NotFoundError, ForbiddenError) as e:
|
||||||
logger.warning(f"Cannot access file '{file_path}' in '{repo_full_name}': {e}")
|
logger.warning(
|
||||||
|
f"Cannot access file '{file_path}' in '{repo_full_name}': {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}")
|
logger.error(
|
||||||
|
f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ Allows fetching issue lists and their comments, projects and more.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
@ -17,9 +17,9 @@ class JiraConnector:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: Optional[str] = None,
|
base_url: str | None = None,
|
||||||
email: Optional[str] = None,
|
email: str | None = None,
|
||||||
api_token: Optional[str] = None,
|
api_token: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the JiraConnector class.
|
Initialize the JiraConnector class.
|
||||||
|
|
@ -65,7 +65,7 @@ class JiraConnector:
|
||||||
"""
|
"""
|
||||||
self.api_token = api_token
|
self.api_token = api_token
|
||||||
|
|
||||||
def get_headers(self) -> Dict[str, str]:
|
def get_headers(self) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Get headers for Jira API requests using Basic Authentication.
|
Get headers for Jira API requests using Basic Authentication.
|
||||||
|
|
||||||
|
|
@ -92,8 +92,8 @@ class JiraConnector:
|
||||||
}
|
}
|
||||||
|
|
||||||
def make_api_request(
|
def make_api_request(
|
||||||
self, endpoint: str, params: Optional[Dict[str, Any]] = None
|
self, endpoint: str, params: dict[str, Any] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Make a request to the Jira API.
|
Make a request to the Jira API.
|
||||||
|
|
||||||
|
|
@ -138,7 +138,7 @@ class JiraConnector:
|
||||||
"""
|
"""
|
||||||
return self.make_api_request("project/search")
|
return self.make_api_request("project/search")
|
||||||
|
|
||||||
def get_all_issues(self, project_key: Optional[str] = None) -> List[Dict[str, Any]]:
|
def get_all_issues(self, project_key: str | None = None) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch all issues from Jira.
|
Fetch all issues from Jira.
|
||||||
|
|
||||||
|
|
@ -204,8 +204,8 @@ class JiraConnector:
|
||||||
start_date: str,
|
start_date: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
include_comments: bool = True,
|
include_comments: bool = True,
|
||||||
project_key: Optional[str] = None,
|
project_key: str | None = None,
|
||||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
) -> tuple[list[dict[str, Any]], str | None]:
|
||||||
"""
|
"""
|
||||||
Fetch issues within a date range.
|
Fetch issues within a date range.
|
||||||
|
|
||||||
|
|
@ -226,9 +226,9 @@ class JiraConnector:
|
||||||
)
|
)
|
||||||
# TODO : This JQL needs some improvement to work as expected
|
# TODO : This JQL needs some improvement to work as expected
|
||||||
|
|
||||||
jql = f"{date_filter}"
|
_jql = f"{date_filter}"
|
||||||
if project_key:
|
if project_key:
|
||||||
jql = (
|
_jql = (
|
||||||
f'project = "{project_key}" AND {date_filter} ORDER BY created DESC'
|
f'project = "{project_key}" AND {date_filter} ORDER BY created DESC'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -283,9 +283,9 @@ class JiraConnector:
|
||||||
return all_issues, None
|
return all_issues, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return [], f"Error fetching issues: {str(e)}"
|
return [], f"Error fetching issues: {e!s}"
|
||||||
|
|
||||||
def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
|
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Format an issue for easier consumption.
|
Format an issue for easier consumption.
|
||||||
|
|
||||||
|
|
@ -401,7 +401,7 @@ class JiraConnector:
|
||||||
|
|
||||||
return formatted
|
return formatted
|
||||||
|
|
||||||
def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
|
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
|
||||||
"""
|
"""
|
||||||
Convert an issue to markdown format.
|
Convert an issue to markdown format.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,16 @@ A module for retrieving issues and comments from Linear.
|
||||||
Allows fetching issue lists and their comments with date range filtering.
|
Allows fetching issue lists and their comments with date range filtering.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import requests
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional, Tuple, Any, Union
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
class LinearConnector:
|
class LinearConnector:
|
||||||
"""Class for retrieving issues and comments from Linear."""
|
"""Class for retrieving issues and comments from Linear."""
|
||||||
|
|
||||||
def __init__(self, token: str = None):
|
def __init__(self, token: str | None = None):
|
||||||
"""
|
"""
|
||||||
Initialize the LinearConnector class.
|
Initialize the LinearConnector class.
|
||||||
|
|
||||||
|
|
@ -32,7 +33,7 @@ class LinearConnector:
|
||||||
"""
|
"""
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
def get_headers(self) -> Dict[str, str]:
|
def get_headers(self) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Get headers for Linear API requests.
|
Get headers for Linear API requests.
|
||||||
|
|
||||||
|
|
@ -45,12 +46,11 @@ class LinearConnector:
|
||||||
if not self.token:
|
if not self.token:
|
||||||
raise ValueError("Linear token not initialized. Call set_token() first.")
|
raise ValueError("Linear token not initialized. Call set_token() first.")
|
||||||
|
|
||||||
return {
|
return {"Content-Type": "application/json", "Authorization": self.token}
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': self.token
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_graphql_query(self, query: str, variables: Dict[str, Any] = None) -> Dict[str, Any]:
|
def execute_graphql_query(
|
||||||
|
self, query: str, variables: dict[str, Any] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Execute a GraphQL query against the Linear API.
|
Execute a GraphQL query against the Linear API.
|
||||||
|
|
||||||
|
|
@ -69,23 +69,21 @@ class LinearConnector:
|
||||||
raise ValueError("Linear token not initialized. Call set_token() first.")
|
raise ValueError("Linear token not initialized. Call set_token() first.")
|
||||||
|
|
||||||
headers = self.get_headers()
|
headers = self.get_headers()
|
||||||
payload = {'query': query}
|
payload = {"query": query}
|
||||||
|
|
||||||
if variables:
|
if variables:
|
||||||
payload['variables'] = variables
|
payload["variables"] = variables
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||||
self.api_url,
|
|
||||||
headers=headers,
|
|
||||||
json=payload
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.json()
|
return response.json()
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Query failed with status code {response.status_code}: {response.text}")
|
raise Exception(
|
||||||
|
f"Query failed with status code {response.status_code}: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_all_issues(self, include_comments: bool = True) -> List[Dict[str, Any]]:
|
def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch all issues from Linear.
|
Fetch all issues from Linear.
|
||||||
|
|
||||||
|
|
@ -151,17 +149,18 @@ class LinearConnector:
|
||||||
result = self.execute_graphql_query(query)
|
result = self.execute_graphql_query(query)
|
||||||
|
|
||||||
# Extract issues from the response
|
# Extract issues from the response
|
||||||
if "data" in result and "issues" in result["data"] and "nodes" in result["data"]["issues"]:
|
if (
|
||||||
|
"data" in result
|
||||||
|
and "issues" in result["data"]
|
||||||
|
and "nodes" in result["data"]["issues"]
|
||||||
|
):
|
||||||
return result["data"]["issues"]["nodes"]
|
return result["data"]["issues"]["nodes"]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_issues_by_date_range(
|
def get_issues_by_date_range(
|
||||||
self,
|
self, start_date: str, end_date: str, include_comments: bool = True
|
||||||
start_date: str,
|
) -> tuple[list[dict[str, Any]], str | None]:
|
||||||
end_date: str,
|
|
||||||
include_comments: bool = True
|
|
||||||
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
|
|
||||||
"""
|
"""
|
||||||
Fetch issues within a date range.
|
Fetch issues within a date range.
|
||||||
|
|
||||||
|
|
@ -263,7 +262,12 @@ class LinearConnector:
|
||||||
|
|
||||||
# Check for errors
|
# Check for errors
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
error_message = "; ".join([error.get("message", "Unknown error") for error in result["errors"]])
|
error_message = "; ".join(
|
||||||
|
[
|
||||||
|
error.get("message", "Unknown error")
|
||||||
|
for error in result["errors"]
|
||||||
|
]
|
||||||
|
)
|
||||||
return [], f"GraphQL errors: {error_message}"
|
return [], f"GraphQL errors: {error_message}"
|
||||||
|
|
||||||
# Extract issues from the response
|
# Extract issues from the response
|
||||||
|
|
@ -278,7 +282,9 @@ class LinearConnector:
|
||||||
if "pageInfo" in issues_page:
|
if "pageInfo" in issues_page:
|
||||||
page_info = issues_page["pageInfo"]
|
page_info = issues_page["pageInfo"]
|
||||||
has_next_page = page_info.get("hasNextPage", False)
|
has_next_page = page_info.get("hasNextPage", False)
|
||||||
cursor = page_info.get("endCursor") if has_next_page else None
|
cursor = (
|
||||||
|
page_info.get("endCursor") if has_next_page else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
has_next_page = False
|
has_next_page = False
|
||||||
else:
|
else:
|
||||||
|
|
@ -290,12 +296,12 @@ class LinearConnector:
|
||||||
return all_issues, None
|
return all_issues, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return [], f"Error fetching issues: {str(e)}"
|
return [], f"Error fetching issues: {e!s}"
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return [], f"Invalid date format: {str(e)}. Please use YYYY-MM-DD."
|
return [], f"Invalid date format: {e!s}. Please use YYYY-MM-DD."
|
||||||
|
|
||||||
def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
|
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Format an issue for easier consumption.
|
Format an issue for easier consumption.
|
||||||
|
|
||||||
|
|
@ -311,21 +317,35 @@ class LinearConnector:
|
||||||
"identifier": issue.get("identifier", ""),
|
"identifier": issue.get("identifier", ""),
|
||||||
"title": issue.get("title", ""),
|
"title": issue.get("title", ""),
|
||||||
"description": issue.get("description", ""),
|
"description": issue.get("description", ""),
|
||||||
"state": issue.get("state", {}).get("name", "Unknown") if issue.get("state") else "Unknown",
|
"state": issue.get("state", {}).get("name", "Unknown")
|
||||||
"state_type": issue.get("state", {}).get("type", "Unknown") if issue.get("state") else "Unknown",
|
if issue.get("state")
|
||||||
|
else "Unknown",
|
||||||
|
"state_type": issue.get("state", {}).get("type", "Unknown")
|
||||||
|
if issue.get("state")
|
||||||
|
else "Unknown",
|
||||||
"created_at": issue.get("createdAt", ""),
|
"created_at": issue.get("createdAt", ""),
|
||||||
"updated_at": issue.get("updatedAt", ""),
|
"updated_at": issue.get("updatedAt", ""),
|
||||||
"creator": {
|
"creator": {
|
||||||
"id": issue.get("creator", {}).get("id", "") if issue.get("creator") else "",
|
"id": issue.get("creator", {}).get("id", "")
|
||||||
"name": issue.get("creator", {}).get("name", "Unknown") if issue.get("creator") else "Unknown",
|
if issue.get("creator")
|
||||||
"email": issue.get("creator", {}).get("email", "") if issue.get("creator") else ""
|
else "",
|
||||||
} if issue.get("creator") else {"id": "", "name": "Unknown", "email": ""},
|
"name": issue.get("creator", {}).get("name", "Unknown")
|
||||||
|
if issue.get("creator")
|
||||||
|
else "Unknown",
|
||||||
|
"email": issue.get("creator", {}).get("email", "")
|
||||||
|
if issue.get("creator")
|
||||||
|
else "",
|
||||||
|
}
|
||||||
|
if issue.get("creator")
|
||||||
|
else {"id": "", "name": "Unknown", "email": ""},
|
||||||
"assignee": {
|
"assignee": {
|
||||||
"id": issue.get("assignee", {}).get("id", ""),
|
"id": issue.get("assignee", {}).get("id", ""),
|
||||||
"name": issue.get("assignee", {}).get("name", "Unknown"),
|
"name": issue.get("assignee", {}).get("name", "Unknown"),
|
||||||
"email": issue.get("assignee", {}).get("email", "")
|
"email": issue.get("assignee", {}).get("email", ""),
|
||||||
} if issue.get("assignee") else None,
|
}
|
||||||
"comments": []
|
if issue.get("assignee")
|
||||||
|
else None,
|
||||||
|
"comments": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract comments if available
|
# Extract comments if available
|
||||||
|
|
@ -337,16 +357,24 @@ class LinearConnector:
|
||||||
"created_at": comment.get("createdAt", ""),
|
"created_at": comment.get("createdAt", ""),
|
||||||
"updated_at": comment.get("updatedAt", ""),
|
"updated_at": comment.get("updatedAt", ""),
|
||||||
"user": {
|
"user": {
|
||||||
"id": comment.get("user", {}).get("id", "") if comment.get("user") else "",
|
"id": comment.get("user", {}).get("id", "")
|
||||||
"name": comment.get("user", {}).get("name", "Unknown") if comment.get("user") else "Unknown",
|
if comment.get("user")
|
||||||
"email": comment.get("user", {}).get("email", "") if comment.get("user") else ""
|
else "",
|
||||||
} if comment.get("user") else {"id": "", "name": "Unknown", "email": ""}
|
"name": comment.get("user", {}).get("name", "Unknown")
|
||||||
|
if comment.get("user")
|
||||||
|
else "Unknown",
|
||||||
|
"email": comment.get("user", {}).get("email", "")
|
||||||
|
if comment.get("user")
|
||||||
|
else "",
|
||||||
|
}
|
||||||
|
if comment.get("user")
|
||||||
|
else {"id": "", "name": "Unknown", "email": ""},
|
||||||
}
|
}
|
||||||
formatted["comments"].append(formatted_comment)
|
formatted["comments"].append(formatted_comment)
|
||||||
|
|
||||||
return formatted
|
return formatted
|
||||||
|
|
||||||
def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
|
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
|
||||||
"""
|
"""
|
||||||
Convert an issue to markdown format.
|
Convert an issue to markdown format.
|
||||||
|
|
||||||
|
|
@ -363,37 +391,37 @@ class LinearConnector:
|
||||||
# Build the markdown content
|
# Build the markdown content
|
||||||
markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n"
|
markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n"
|
||||||
|
|
||||||
if issue.get('state'):
|
if issue.get("state"):
|
||||||
markdown += f"**Status:** {issue['state']}\n\n"
|
markdown += f"**Status:** {issue['state']}\n\n"
|
||||||
|
|
||||||
if issue.get('assignee') and issue['assignee'].get('name'):
|
if issue.get("assignee") and issue["assignee"].get("name"):
|
||||||
markdown += f"**Assignee:** {issue['assignee']['name']}\n"
|
markdown += f"**Assignee:** {issue['assignee']['name']}\n"
|
||||||
|
|
||||||
if issue.get('creator') and issue['creator'].get('name'):
|
if issue.get("creator") and issue["creator"].get("name"):
|
||||||
markdown += f"**Created by:** {issue['creator']['name']}\n"
|
markdown += f"**Created by:** {issue['creator']['name']}\n"
|
||||||
|
|
||||||
if issue.get('created_at'):
|
if issue.get("created_at"):
|
||||||
created_date = self.format_date(issue['created_at'])
|
created_date = self.format_date(issue["created_at"])
|
||||||
markdown += f"**Created:** {created_date}\n"
|
markdown += f"**Created:** {created_date}\n"
|
||||||
|
|
||||||
if issue.get('updated_at'):
|
if issue.get("updated_at"):
|
||||||
updated_date = self.format_date(issue['updated_at'])
|
updated_date = self.format_date(issue["updated_at"])
|
||||||
markdown += f"**Updated:** {updated_date}\n\n"
|
markdown += f"**Updated:** {updated_date}\n\n"
|
||||||
|
|
||||||
if issue.get('description'):
|
if issue.get("description"):
|
||||||
markdown += f"## Description\n\n{issue['description']}\n\n"
|
markdown += f"## Description\n\n{issue['description']}\n\n"
|
||||||
|
|
||||||
if issue.get('comments'):
|
if issue.get("comments"):
|
||||||
markdown += f"## Comments ({len(issue['comments'])})\n\n"
|
markdown += f"## Comments ({len(issue['comments'])})\n\n"
|
||||||
|
|
||||||
for comment in issue['comments']:
|
for comment in issue["comments"]:
|
||||||
user_name = "Unknown"
|
user_name = "Unknown"
|
||||||
if comment.get('user') and comment['user'].get('name'):
|
if comment.get("user") and comment["user"].get("name"):
|
||||||
user_name = comment['user']['name']
|
user_name = comment["user"]["name"]
|
||||||
|
|
||||||
comment_date = "Unknown date"
|
comment_date = "Unknown date"
|
||||||
if comment.get('created_at'):
|
if comment.get("created_at"):
|
||||||
comment_date = self.format_date(comment['created_at'])
|
comment_date = self.format_date(comment["created_at"])
|
||||||
|
|
||||||
markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
|
markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
|
||||||
|
|
||||||
|
|
@ -414,8 +442,8 @@ class LinearConnector:
|
||||||
return "Unknown date"
|
return "Unknown date"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dt = datetime.fromisoformat(iso_date.replace('Z', '+00:00'))
|
dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
|
||||||
return dt.strftime('%Y-%m-%d %H:%M:%S')
|
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return iso_date
|
return iso_date
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from notion_client import Client
|
from notion_client import Client
|
||||||
|
|
||||||
|
|
||||||
class NotionHistoryConnector:
|
class NotionHistoryConnector:
|
||||||
def __init__(self, token):
|
def __init__(self, token):
|
||||||
"""
|
"""
|
||||||
|
|
@ -26,10 +27,7 @@ class NotionHistoryConnector:
|
||||||
search_params = {}
|
search_params = {}
|
||||||
|
|
||||||
# Filter for pages only (not databases)
|
# Filter for pages only (not databases)
|
||||||
search_params["filter"] = {
|
search_params["filter"] = {"value": "page", "property": "object"}
|
||||||
"value": "page",
|
|
||||||
"property": "object"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add date filters if provided
|
# Add date filters if provided
|
||||||
if start_date or end_date:
|
if start_date or end_date:
|
||||||
|
|
@ -45,7 +43,7 @@ class NotionHistoryConnector:
|
||||||
if date_filter:
|
if date_filter:
|
||||||
search_params["sort"] = {
|
search_params["sort"] = {
|
||||||
"direction": "descending",
|
"direction": "descending",
|
||||||
"timestamp": "last_edited_time"
|
"timestamp": "last_edited_time",
|
||||||
}
|
}
|
||||||
|
|
||||||
# First, get a list of all pages the integration has access to
|
# First, get a list of all pages the integration has access to
|
||||||
|
|
@ -60,11 +58,13 @@ class NotionHistoryConnector:
|
||||||
# Get detailed page information
|
# Get detailed page information
|
||||||
page_content = self.get_page_content(page_id)
|
page_content = self.get_page_content(page_id)
|
||||||
|
|
||||||
all_page_data.append({
|
all_page_data.append(
|
||||||
|
{
|
||||||
"page_id": page_id,
|
"page_id": page_id,
|
||||||
"title": self.get_page_title(page),
|
"title": self.get_page_title(page),
|
||||||
"content": page_content
|
"content": page_content,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return all_page_data
|
return all_page_data
|
||||||
|
|
||||||
|
|
@ -81,9 +81,11 @@ class NotionHistoryConnector:
|
||||||
# Title can be in different properties depending on the page type
|
# Title can be in different properties depending on the page type
|
||||||
if "properties" in page:
|
if "properties" in page:
|
||||||
# Try to find a title property
|
# Try to find a title property
|
||||||
for prop_name, prop_data in page["properties"].items():
|
for _prop_name, prop_data in page["properties"].items():
|
||||||
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
|
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
|
||||||
return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
|
return " ".join(
|
||||||
|
[text_obj["plain_text"] for text_obj in prop_data["title"]]
|
||||||
|
)
|
||||||
|
|
||||||
# If no title found, return the page ID as fallback
|
# If no title found, return the page ID as fallback
|
||||||
return f"Untitled page ({page['id']})"
|
return f"Untitled page ({page['id']})"
|
||||||
|
|
@ -105,7 +107,9 @@ class NotionHistoryConnector:
|
||||||
# Paginate through all blocks
|
# Paginate through all blocks
|
||||||
while has_more:
|
while has_more:
|
||||||
if cursor:
|
if cursor:
|
||||||
response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
|
response = self.notion.blocks.children.list(
|
||||||
|
block_id=page_id, start_cursor=cursor
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response = self.notion.blocks.children.list(block_id=page_id)
|
response = self.notion.blocks.children.list(block_id=page_id)
|
||||||
|
|
||||||
|
|
@ -153,7 +157,7 @@ class NotionHistoryConnector:
|
||||||
"id": block_id,
|
"id": block_id,
|
||||||
"type": block_type,
|
"type": block_type,
|
||||||
"content": content,
|
"content": content,
|
||||||
"children": child_blocks
|
"children": child_blocks,
|
||||||
}
|
}
|
||||||
|
|
||||||
def extract_block_content(self, block):
|
def extract_block_content(self, block):
|
||||||
|
|
@ -170,7 +174,9 @@ class NotionHistoryConnector:
|
||||||
|
|
||||||
# Different block types have different structures
|
# Different block types have different structures
|
||||||
if block_type in block and "rich_text" in block[block_type]:
|
if block_type in block and "rich_text" in block[block_type]:
|
||||||
return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
|
return "".join(
|
||||||
|
[text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]
|
||||||
|
)
|
||||||
elif block_type == "image":
|
elif block_type == "image":
|
||||||
# Instead of returning the raw URL which may contain sensitive AWS credentials,
|
# Instead of returning the raw URL which may contain sensitive AWS credentials,
|
||||||
# return a placeholder or reference to the image
|
# return a placeholder or reference to the image
|
||||||
|
|
@ -183,13 +189,16 @@ class NotionHistoryConnector:
|
||||||
# Only return the domain part of external URLs to avoid potential sensitive parameters
|
# Only return the domain part of external URLs to avoid potential sensitive parameters
|
||||||
try:
|
try:
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
parsed_url = urlparse(url)
|
parsed_url = urlparse(url)
|
||||||
return f"[External Image from {parsed_url.netloc}]"
|
return f"[External Image from {parsed_url.netloc}]"
|
||||||
except:
|
except Exception:
|
||||||
return "[External Image]"
|
return "[External Image]"
|
||||||
elif block_type == "code":
|
elif block_type == "code":
|
||||||
language = block["code"]["language"]
|
language = block["code"]["language"]
|
||||||
code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
|
code_text = "".join(
|
||||||
|
[text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]
|
||||||
|
)
|
||||||
return f"```{language}\n{code_text}\n```"
|
return f"```{language}\n{code_text}\n```"
|
||||||
elif block_type == "equation":
|
elif block_type == "equation":
|
||||||
return block["equation"]["expression"]
|
return block["equation"]["expression"]
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,13 @@ A module for retrieving conversation history from Slack channels.
|
||||||
Allows fetching channel lists and message history with date range filtering.
|
Allows fetching channel lists and message history with date range filtering.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time # Added import
|
|
||||||
import logging # Added import
|
import logging # Added import
|
||||||
|
import time # Added import
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from slack_sdk import WebClient
|
from slack_sdk import WebClient
|
||||||
from slack_sdk.errors import SlackApiError
|
from slack_sdk.errors import SlackApiError
|
||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List, Optional, Tuple, Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # Added logger
|
logger = logging.getLogger(__name__) # Added logger
|
||||||
|
|
||||||
|
|
@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) # Added logger
|
||||||
class SlackHistory:
|
class SlackHistory:
|
||||||
"""Class for retrieving conversation history from Slack channels."""
|
"""Class for retrieving conversation history from Slack channels."""
|
||||||
|
|
||||||
def __init__(self, token: str = None):
|
def __init__(self, token: str | None = None):
|
||||||
"""
|
"""
|
||||||
Initialize the SlackHistory class.
|
Initialize the SlackHistory class.
|
||||||
|
|
||||||
|
|
@ -36,7 +37,7 @@ class SlackHistory:
|
||||||
"""
|
"""
|
||||||
self.client = WebClient(token=token)
|
self.client = WebClient(token=token)
|
||||||
|
|
||||||
def get_all_channels(self, include_private: bool = True) -> List[Dict[str, Any]]:
|
def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch all channels that the bot has access to, with rate limit handling.
|
Fetch all channels that the bot has access to, with rate limit handling.
|
||||||
|
|
||||||
|
|
@ -65,14 +66,14 @@ class SlackHistory:
|
||||||
while is_first_request or next_cursor:
|
while is_first_request or next_cursor:
|
||||||
try:
|
try:
|
||||||
if not is_first_request: # Add delay only for paginated requests
|
if not is_first_request: # Add delay only for paginated requests
|
||||||
logger.info(f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}")
|
logger.info(
|
||||||
|
f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}"
|
||||||
|
)
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
current_limit = 1000 # Max limit
|
current_limit = 1000 # Max limit
|
||||||
api_result = self.client.conversations_list(
|
api_result = self.client.conversations_list(
|
||||||
types=types,
|
types=types, cursor=next_cursor, limit=current_limit
|
||||||
cursor=next_cursor,
|
|
||||||
limit=current_limit
|
|
||||||
)
|
)
|
||||||
|
|
||||||
channels_on_page = api_result["channels"]
|
channels_on_page = api_result["channels"]
|
||||||
|
|
@ -86,12 +87,13 @@ class SlackHistory:
|
||||||
# It indicates if the authenticated user (bot) is a member.
|
# It indicates if the authenticated user (bot) is a member.
|
||||||
# For public channels, this might be true or the API might not focus on it
|
# For public channels, this might be true or the API might not focus on it
|
||||||
# if the bot can read it anyway. For private, it's crucial.
|
# if the bot can read it anyway. For private, it's crucial.
|
||||||
"is_member": channel.get("is_member", False)
|
"is_member": channel.get("is_member", False),
|
||||||
}
|
}
|
||||||
channels_list.append(channel_data)
|
channels_list.append(channel_data)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Channel found with missing name or id. Data: {channel}")
|
logger.warning(
|
||||||
|
f"Channel found with missing name or id. Data: {channel}"
|
||||||
|
)
|
||||||
|
|
||||||
next_cursor = api_result.get("response_metadata", {}).get("next_cursor")
|
next_cursor = api_result.get("response_metadata", {}).get("next_cursor")
|
||||||
is_first_request = False # Subsequent requests are not the first
|
is_first_request = False # Subsequent requests are not the first
|
||||||
|
|
@ -101,21 +103,29 @@ class SlackHistory:
|
||||||
|
|
||||||
except SlackApiError as e:
|
except SlackApiError as e:
|
||||||
if e.response is not None and e.response.status_code == 429:
|
if e.response is not None and e.response.status_code == 429:
|
||||||
retry_after_header = e.response.headers.get('Retry-After')
|
retry_after_header = e.response.headers.get("Retry-After")
|
||||||
wait_duration = 60 # Default wait time
|
wait_duration = 60 # Default wait time
|
||||||
if retry_after_header and retry_after_header.isdigit():
|
if retry_after_header and retry_after_header.isdigit():
|
||||||
wait_duration = int(retry_after_header)
|
wait_duration = int(retry_after_header)
|
||||||
|
|
||||||
logger.warning(f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}")
|
logger.warning(
|
||||||
|
f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}"
|
||||||
|
)
|
||||||
time.sleep(wait_duration)
|
time.sleep(wait_duration)
|
||||||
# The loop will continue, retrying with the same cursor
|
# The loop will continue, retrying with the same cursor
|
||||||
else:
|
else:
|
||||||
# Not a 429 error, or no response object, re-raise
|
# Not a 429 error, or no response object, re-raise
|
||||||
raise SlackApiError(f"Error retrieving channels: {e}", e.response)
|
raise SlackApiError(
|
||||||
|
f"Error retrieving channels: {e}", e.response
|
||||||
|
) from e
|
||||||
except Exception as general_error:
|
except Exception as general_error:
|
||||||
# Handle other potential errors like network issues if necessary, or re-raise
|
# Handle other potential errors like network issues if necessary, or re-raise
|
||||||
logger.error(f"An unexpected error occurred during channel fetching: {general_error}")
|
logger.error(
|
||||||
raise RuntimeError(f"An unexpected error occurred during channel fetching: {general_error}")
|
f"An unexpected error occurred during channel fetching: {general_error}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"An unexpected error occurred during channel fetching: {general_error}"
|
||||||
|
) from general_error
|
||||||
|
|
||||||
return channels_list
|
return channels_list
|
||||||
|
|
||||||
|
|
@ -123,9 +133,9 @@ class SlackHistory:
|
||||||
self,
|
self,
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
oldest: Optional[int] = None,
|
oldest: int | None = None,
|
||||||
latest: Optional[int] = None
|
latest: int | None = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch conversation history for a channel.
|
Fetch conversation history for a channel.
|
||||||
|
|
||||||
|
|
@ -170,8 +180,11 @@ class SlackHistory:
|
||||||
result = self.client.conversations_history(**kwargs)
|
result = self.client.conversations_history(**kwargs)
|
||||||
current_api_call_successful = True
|
current_api_call_successful = True
|
||||||
except SlackApiError as e_history:
|
except SlackApiError as e_history:
|
||||||
if e_history.response is not None and e_history.response.status_code == 429:
|
if (
|
||||||
retry_after_str = e_history.response.headers.get('Retry-After')
|
e_history.response is not None
|
||||||
|
and e_history.response.status_code == 429
|
||||||
|
):
|
||||||
|
retry_after_str = e_history.response.headers.get("Retry-After")
|
||||||
wait_time = 60 # Default
|
wait_time = 60 # Default
|
||||||
if retry_after_str and retry_after_str.isdigit():
|
if retry_after_str and retry_after_str.isdigit():
|
||||||
wait_time = int(retry_after_str)
|
wait_time = int(retry_after_str)
|
||||||
|
|
@ -197,26 +210,33 @@ class SlackHistory:
|
||||||
break # Exit pagination loop
|
break # Exit pagination loop
|
||||||
|
|
||||||
except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
|
except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
|
||||||
if (e.response is not None and
|
if (
|
||||||
hasattr(e.response, 'data') and
|
e.response is not None
|
||||||
isinstance(e.response.data, dict) and
|
and hasattr(e.response, "data")
|
||||||
e.response.data.get('error') == 'not_in_channel'):
|
and isinstance(e.response.data, dict)
|
||||||
|
and e.response.data.get("error") == "not_in_channel"
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Bot is not in channel '{channel_id}'. Cannot fetch history. "
|
f"Bot is not in channel '{channel_id}'. Cannot fetch history. "
|
||||||
"Please add the bot to this channel."
|
"Please add the bot to this channel."
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
# For other SlackApiErrors from inner block or this level
|
# For other SlackApiErrors from inner block or this level
|
||||||
raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
|
raise SlackApiError(
|
||||||
|
f"Error retrieving history for channel {channel_id}: {e}",
|
||||||
|
e.response,
|
||||||
|
) from e
|
||||||
except Exception as general_error: # Catch any other unexpected errors
|
except Exception as general_error: # Catch any other unexpected errors
|
||||||
logger.error(f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}")
|
logger.error(
|
||||||
|
f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}"
|
||||||
|
)
|
||||||
# Re-raise the general error to allow higher-level handling or visibility
|
# Re-raise the general error to allow higher-level handling or visibility
|
||||||
raise
|
raise general_error from general_error
|
||||||
|
|
||||||
return messages[:limit]
|
return messages[:limit]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_date_to_timestamp(date_str: str) -> Optional[int]:
|
def convert_date_to_timestamp(date_str: str) -> int | None:
|
||||||
"""
|
"""
|
||||||
Convert a date string in format YYYY-MM-DD to Unix timestamp.
|
Convert a date string in format YYYY-MM-DD to Unix timestamp.
|
||||||
|
|
||||||
|
|
@ -233,12 +253,8 @@ class SlackHistory:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_history_by_date_range(
|
def get_history_by_date_range(
|
||||||
self,
|
self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
|
||||||
channel_id: str,
|
) -> tuple[list[dict[str, Any]], str | None]:
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
limit: int = 1000
|
|
||||||
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
|
|
||||||
"""
|
"""
|
||||||
Fetch conversation history within a date range.
|
Fetch conversation history within a date range.
|
||||||
|
|
||||||
|
|
@ -253,7 +269,10 @@ class SlackHistory:
|
||||||
"""
|
"""
|
||||||
oldest = self.convert_date_to_timestamp(start_date)
|
oldest = self.convert_date_to_timestamp(start_date)
|
||||||
if not oldest:
|
if not oldest:
|
||||||
return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
|
return (
|
||||||
|
[],
|
||||||
|
f"Invalid start date format: {start_date}. Please use YYYY-MM-DD.",
|
||||||
|
)
|
||||||
|
|
||||||
latest = self.convert_date_to_timestamp(end_date)
|
latest = self.convert_date_to_timestamp(end_date)
|
||||||
if not latest:
|
if not latest:
|
||||||
|
|
@ -264,18 +283,15 @@ class SlackHistory:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
messages = self.get_conversation_history(
|
messages = self.get_conversation_history(
|
||||||
channel_id=channel_id,
|
channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
|
||||||
limit=limit,
|
|
||||||
oldest=oldest,
|
|
||||||
latest=latest
|
|
||||||
)
|
)
|
||||||
return messages, None
|
return messages, None
|
||||||
except SlackApiError as e:
|
except SlackApiError as e:
|
||||||
return [], f"Slack API error: {str(e)}"
|
return [], f"Slack API error: {e!s}"
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return [], str(e)
|
return [], str(e)
|
||||||
|
|
||||||
def get_user_info(self, user_id: str) -> Dict[str, Any]:
|
def get_user_info(self, user_id: str) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get information about a user.
|
Get information about a user.
|
||||||
|
|
||||||
|
|
@ -302,22 +318,34 @@ class SlackHistory:
|
||||||
return result["user"] # Success, return and exit loop implicitly
|
return result["user"] # Success, return and exit loop implicitly
|
||||||
|
|
||||||
except SlackApiError as e_user_info:
|
except SlackApiError as e_user_info:
|
||||||
if e_user_info.response is not None and e_user_info.response.status_code == 429:
|
if (
|
||||||
retry_after_str = e_user_info.response.headers.get('Retry-After')
|
e_user_info.response is not None
|
||||||
|
and e_user_info.response.status_code == 429
|
||||||
|
):
|
||||||
|
retry_after_str = e_user_info.response.headers.get("Retry-After")
|
||||||
wait_time = 30 # Default for Tier 4, can be adjusted
|
wait_time = 30 # Default for Tier 4, can be adjusted
|
||||||
if retry_after_str and retry_after_str.isdigit():
|
if retry_after_str and retry_after_str.isdigit():
|
||||||
wait_time = int(retry_after_str)
|
wait_time = int(retry_after_str)
|
||||||
logger.warning(f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds.")
|
logger.warning(
|
||||||
|
f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds."
|
||||||
|
)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
continue # Retry the API call
|
continue # Retry the API call
|
||||||
else:
|
else:
|
||||||
# Not a 429 error, or no response object, re-raise
|
# Not a 429 error, or no response object, re-raise
|
||||||
raise SlackApiError(f"Error retrieving user info for {user_id}: {e_user_info}", e_user_info.response)
|
raise SlackApiError(
|
||||||
|
f"Error retrieving user info for {user_id}: {e_user_info}",
|
||||||
|
e_user_info.response,
|
||||||
|
) from e_user_info
|
||||||
except Exception as general_error: # Catch any other unexpected errors
|
except Exception as general_error: # Catch any other unexpected errors
|
||||||
logger.error(f"Unexpected error in get_user_info for user {user_id}: {general_error}")
|
logger.error(
|
||||||
raise # Re-raise unexpected errors
|
f"Unexpected error in get_user_info for user {user_id}: {general_error}"
|
||||||
|
)
|
||||||
|
raise general_error from general_error # Re-raise unexpected errors
|
||||||
|
|
||||||
def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
|
def format_message(
|
||||||
|
self, msg: dict[str, Any], include_user_info: bool = False
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Format a message for easier consumption.
|
Format a message for easier consumption.
|
||||||
|
|
||||||
|
|
@ -331,7 +359,9 @@ class SlackHistory:
|
||||||
formatted = {
|
formatted = {
|
||||||
"text": msg.get("text", ""),
|
"text": msg.get("text", ""),
|
||||||
"timestamp": msg.get("ts"),
|
"timestamp": msg.get("ts"),
|
||||||
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
|
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime(
|
||||||
|
"%Y-%m-%d %H:%M:%S"
|
||||||
|
),
|
||||||
"user_id": msg.get("user", "UNKNOWN"),
|
"user_id": msg.get("user", "UNKNOWN"),
|
||||||
"has_attachments": bool(msg.get("attachments")),
|
"has_attachments": bool(msg.get("attachments")),
|
||||||
"has_files": bool(msg.get("files")),
|
"has_files": bool(msg.get("files")),
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,17 @@
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from github3.exceptions import ForbiddenError # Import the specific exception
|
||||||
|
|
||||||
# Adjust the import path based on the actual location if test_github_connector.py
|
# Adjust the import path based on the actual location if test_github_connector.py
|
||||||
# is not in the same directory as github_connector.py or if paths are set up differently.
|
# is not in the same directory as github_connector.py or if paths are set up differently.
|
||||||
# Assuming surfsend_backend/app/connectors/test_github_connector.py
|
# Assuming surfsend_backend/app/connectors/test_github_connector.py
|
||||||
from surfsense_backend.app.connectors.github_connector import GitHubConnector
|
from surfsense_backend.app.connectors.github_connector import GitHubConnector
|
||||||
from github3.exceptions import ForbiddenError # Import the specific exception
|
|
||||||
|
|
||||||
class TestGitHubConnector(unittest.TestCase):
|
class TestGitHubConnector(unittest.TestCase):
|
||||||
|
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
|
||||||
def test_get_user_repositories_uses_type_all(self, mock_github_login):
|
def test_get_user_repositories_uses_type_all(self, mock_github_login):
|
||||||
# Mock the GitHub client object and its methods
|
# Mock the GitHub client object and its methods
|
||||||
mock_gh_instance = Mock()
|
mock_gh_instance = Mock()
|
||||||
|
|
@ -27,7 +28,9 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
mock_repo1_data.private = False
|
mock_repo1_data.private = False
|
||||||
mock_repo1_data.html_url = "http://example.com/user/repo1"
|
mock_repo1_data.html_url = "http://example.com/user/repo1"
|
||||||
mock_repo1_data.description = "Test repo 1"
|
mock_repo1_data.description = "Test repo 1"
|
||||||
mock_repo1_data.updated_at = datetime(2023, 1, 1, 10, 30, 0) # Added time component
|
mock_repo1_data.updated_at = datetime(
|
||||||
|
2023, 1, 1, 10, 30, 0
|
||||||
|
) # Added time component
|
||||||
|
|
||||||
mock_repo2_data = Mock()
|
mock_repo2_data = Mock()
|
||||||
mock_repo2_data.id = 2
|
mock_repo2_data.id = 2
|
||||||
|
|
@ -36,7 +39,9 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
mock_repo2_data.private = True
|
mock_repo2_data.private = True
|
||||||
mock_repo2_data.html_url = "http://example.com/org/org-repo"
|
mock_repo2_data.html_url = "http://example.com/org/org-repo"
|
||||||
mock_repo2_data.description = "Org repo"
|
mock_repo2_data.description = "Org repo"
|
||||||
mock_repo2_data.updated_at = datetime(2023, 1, 2, 12, 0, 0) # Added time component
|
mock_repo2_data.updated_at = datetime(
|
||||||
|
2023, 1, 2, 12, 0, 0
|
||||||
|
) # Added time component
|
||||||
|
|
||||||
# Configure the mock for gh.repositories() call
|
# Configure the mock for gh.repositories() call
|
||||||
# This method is an iterator, so it should return an iterable (e.g., a list)
|
# This method is an iterator, so it should return an iterable (e.g., a list)
|
||||||
|
|
@ -46,26 +51,38 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
repositories = connector.get_user_repositories()
|
repositories = connector.get_user_repositories()
|
||||||
|
|
||||||
# Assert that gh.repositories was called correctly
|
# Assert that gh.repositories was called correctly
|
||||||
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
|
mock_gh_instance.repositories.assert_called_once_with(
|
||||||
|
type="all", sort="updated"
|
||||||
|
)
|
||||||
|
|
||||||
# Assert the structure and content of the returned data
|
# Assert the structure and content of the returned data
|
||||||
expected_repositories = [
|
expected_repositories = [
|
||||||
{
|
{
|
||||||
"id": 1, "name": "repo1", "full_name": "user/repo1", "private": False,
|
"id": 1,
|
||||||
"url": "http://example.com/user/repo1", "description": "Test repo 1",
|
"name": "repo1",
|
||||||
"last_updated": datetime(2023, 1, 1, 10, 30, 0)
|
"full_name": "user/repo1",
|
||||||
|
"private": False,
|
||||||
|
"url": "http://example.com/user/repo1",
|
||||||
|
"description": "Test repo 1",
|
||||||
|
"last_updated": datetime(2023, 1, 1, 10, 30, 0),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2, "name": "org-repo", "full_name": "org/org-repo", "private": True,
|
"id": 2,
|
||||||
"url": "http://example.com/org/org-repo", "description": "Org repo",
|
"name": "org-repo",
|
||||||
"last_updated": datetime(2023, 1, 2, 12, 0, 0)
|
"full_name": "org/org-repo",
|
||||||
}
|
"private": True,
|
||||||
|
"url": "http://example.com/org/org-repo",
|
||||||
|
"description": "Org repo",
|
||||||
|
"last_updated": datetime(2023, 1, 2, 12, 0, 0),
|
||||||
|
},
|
||||||
]
|
]
|
||||||
self.assertEqual(repositories, expected_repositories)
|
self.assertEqual(repositories, expected_repositories)
|
||||||
self.assertEqual(len(repositories), 2)
|
self.assertEqual(len(repositories), 2)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||||
def test_get_user_repositories_handles_empty_description_and_none_updated_at(self, mock_github_login):
|
def test_get_user_repositories_handles_empty_description_and_none_updated_at(
|
||||||
|
self, mock_github_login
|
||||||
|
):
|
||||||
# Mock the GitHub client object and its methods
|
# Mock the GitHub client object and its methods
|
||||||
mock_gh_instance = Mock()
|
mock_gh_instance = Mock()
|
||||||
mock_github_login.return_value = mock_gh_instance
|
mock_github_login.return_value = mock_gh_instance
|
||||||
|
|
@ -84,17 +101,23 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
connector = GitHubConnector(token="fake_token")
|
connector = GitHubConnector(token="fake_token")
|
||||||
repositories = connector.get_user_repositories()
|
repositories = connector.get_user_repositories()
|
||||||
|
|
||||||
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
|
mock_gh_instance.repositories.assert_called_once_with(
|
||||||
|
type="all", sort="updated"
|
||||||
|
)
|
||||||
expected_repositories = [
|
expected_repositories = [
|
||||||
{
|
{
|
||||||
"id": 1, "name": "repo_no_desc", "full_name": "user/repo_no_desc", "private": False,
|
"id": 1,
|
||||||
"url": "http://example.com/user/repo_no_desc", "description": "", # Expect empty string
|
"name": "repo_no_desc",
|
||||||
"last_updated": None # Expect None
|
"full_name": "user/repo_no_desc",
|
||||||
|
"private": False,
|
||||||
|
"url": "http://example.com/user/repo_no_desc",
|
||||||
|
"description": "", # Expect empty string
|
||||||
|
"last_updated": None, # Expect None
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
self.assertEqual(repositories, expected_repositories)
|
self.assertEqual(repositories, expected_repositories)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||||
def test_github_connector_initialization_failure_forbidden(self, mock_github_login):
|
def test_github_connector_initialization_failure_forbidden(self, mock_github_login):
|
||||||
# Test that __init__ raises ValueError on auth failure (ForbiddenError)
|
# Test that __init__ raises ValueError on auth failure (ForbiddenError)
|
||||||
mock_gh_instance = Mock()
|
mock_gh_instance = Mock()
|
||||||
|
|
@ -111,10 +134,14 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError) as context:
|
||||||
GitHubConnector(token="invalid_token_forbidden")
|
GitHubConnector(token="invalid_token_forbidden")
|
||||||
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
|
self.assertIn(
|
||||||
|
"Invalid GitHub token or insufficient permissions.", str(context.exception)
|
||||||
|
)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||||
def test_github_connector_initialization_failure_authentication_failed(self, mock_github_login):
|
def test_github_connector_initialization_failure_authentication_failed(
|
||||||
|
self, mock_github_login
|
||||||
|
):
|
||||||
# Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError)
|
# Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError)
|
||||||
# For github3.py, AuthenticationFailed is more specific for token issues.
|
# For github3.py, AuthenticationFailed is more specific for token issues.
|
||||||
from github3.exceptions import AuthenticationFailed
|
from github3.exceptions import AuthenticationFailed
|
||||||
|
|
@ -129,9 +156,11 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError) as context:
|
||||||
GitHubConnector(token="invalid_token_authfailed")
|
GitHubConnector(token="invalid_token_authfailed")
|
||||||
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
|
self.assertIn(
|
||||||
|
"Invalid GitHub token or insufficient permissions.", str(context.exception)
|
||||||
|
)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||||
def test_get_user_repositories_handles_api_exception(self, mock_github_login):
|
def test_get_user_repositories_handles_api_exception(self, mock_github_login):
|
||||||
mock_gh_instance = Mock()
|
mock_gh_instance = Mock()
|
||||||
mock_github_login.return_value = mock_gh_instance
|
mock_github_login.return_value = mock_gh_instance
|
||||||
|
|
@ -142,13 +171,18 @@ class TestGitHubConnector(unittest.TestCase):
|
||||||
|
|
||||||
connector = GitHubConnector(token="fake_token")
|
connector = GitHubConnector(token="fake_token")
|
||||||
# We expect it to log an error and return an empty list
|
# We expect it to log an error and return an empty list
|
||||||
with patch('surfsense_backend.app.connectors.github_connector.logger') as mock_logger:
|
with patch(
|
||||||
|
"surfsense_backend.app.connectors.github_connector.logger"
|
||||||
|
) as mock_logger:
|
||||||
repositories = connector.get_user_repositories()
|
repositories = connector.get_user_repositories()
|
||||||
|
|
||||||
self.assertEqual(repositories, [])
|
self.assertEqual(repositories, [])
|
||||||
mock_logger.error.assert_called_once()
|
mock_logger.error.assert_called_once()
|
||||||
self.assertIn("Failed to fetch GitHub repositories: API Error", mock_logger.error.call_args[0][0])
|
self.assertIn(
|
||||||
|
"Failed to fetch GitHub repositories: API Error",
|
||||||
|
mock_logger.error.call_args[0][0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,39 @@
|
||||||
import unittest
|
import unittest
|
||||||
import time # Imported to be available for patching target module
|
from unittest.mock import Mock, call, patch
|
||||||
from unittest.mock import patch, Mock, call
|
|
||||||
from slack_sdk.errors import SlackApiError
|
from slack_sdk.errors import SlackApiError
|
||||||
|
|
||||||
# Since test_slack_history.py is in the same directory as slack_history.py
|
# Since test_slack_history.py is in the same directory as slack_history.py
|
||||||
from .slack_history import SlackHistory
|
from .slack_history import SlackHistory
|
||||||
|
|
||||||
class TestSlackHistoryGetAllChannels(unittest.TestCase):
|
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
class TestSlackHistoryGetAllChannels(unittest.TestCase):
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
def test_get_all_channels_pagination_with_delay(self, MockWebClient, mock_sleep, mock_logger):
|
@patch("slack_sdk.WebClient")
|
||||||
mock_client_instance = MockWebClient.return_value
|
def test_get_all_channels_pagination_with_delay(
|
||||||
|
self, mock_web_client, mock_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
# Mock API responses now include is_private and is_member
|
# Mock API responses now include is_private and is_member
|
||||||
page1_response = {
|
page1_response = {
|
||||||
"channels": [
|
"channels": [
|
||||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
|
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
|
||||||
{"name": "dev", "id": "C0", "is_private": False, "is_member": True}
|
{"name": "dev", "id": "C0", "is_private": False, "is_member": True},
|
||||||
],
|
],
|
||||||
"response_metadata": {"next_cursor": "cursor123"}
|
"response_metadata": {"next_cursor": "cursor123"},
|
||||||
}
|
}
|
||||||
page2_response = {
|
page2_response = {
|
||||||
"channels": [{"name": "random", "id": "C2", "is_private": True, "is_member": True}],
|
"channels": [
|
||||||
"response_metadata": {"next_cursor": ""}
|
{"name": "random", "id": "C2", "is_private": True, "is_member": True}
|
||||||
|
],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_client_instance.conversations_list.side_effect = [
|
mock_client_instance.conversations_list.side_effect = [
|
||||||
page1_response,
|
page1_response,
|
||||||
page2_response
|
page2_response,
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -38,129 +42,163 @@ class TestSlackHistoryGetAllChannels(unittest.TestCase):
|
||||||
expected_channels_list = [
|
expected_channels_list = [
|
||||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True},
|
{"id": "C1", "name": "general", "is_private": False, "is_member": True},
|
||||||
{"id": "C0", "name": "dev", "is_private": False, "is_member": True},
|
{"id": "C0", "name": "dev", "is_private": False, "is_member": True},
|
||||||
{"id": "C2", "name": "random", "is_private": True, "is_member": True}
|
{"id": "C2", "name": "random", "is_private": True, "is_member": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
self.assertEqual(len(channels_list), 3)
|
self.assertEqual(len(channels_list), 3)
|
||||||
self.assertListEqual(channels_list, expected_channels_list) # Assert list equality
|
self.assertListEqual(
|
||||||
|
channels_list, expected_channels_list
|
||||||
|
) # Assert list equality
|
||||||
|
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||||
call(types="public_channel,private_channel", cursor="cursor123", limit=1000)
|
call(
|
||||||
|
types="public_channel,private_channel", cursor="cursor123", limit=1000
|
||||||
|
),
|
||||||
]
|
]
|
||||||
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
||||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||||
|
|
||||||
mock_sleep.assert_called_once_with(3)
|
mock_sleep.assert_called_once_with(3)
|
||||||
mock_logger.info.assert_called_once_with("Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123")
|
mock_logger.info.assert_called_once_with(
|
||||||
|
"Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123"
|
||||||
|
)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_get_all_channels_rate_limit_with_retry_after(self, MockWebClient, mock_sleep, mock_logger):
|
def test_get_all_channels_rate_limit_with_retry_after(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 429
|
mock_error_response.status_code = 429
|
||||||
mock_error_response.headers = {'Retry-After': '5'}
|
mock_error_response.headers = {"Retry-After": "5"}
|
||||||
|
|
||||||
successful_response = {
|
successful_response = {
|
||||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
"channels": [
|
||||||
"response_metadata": {"next_cursor": ""}
|
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||||
|
],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_client_instance.conversations_list.side_effect = [
|
mock_client_instance.conversations_list.side_effect = [
|
||||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||||
successful_response
|
successful_response,
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
channels_list = slack_history.get_all_channels(include_private=True)
|
channels_list = slack_history.get_all_channels(include_private=True)
|
||||||
|
|
||||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
expected_channels_list = [
|
||||||
|
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||||
|
]
|
||||||
self.assertEqual(len(channels_list), 1)
|
self.assertEqual(len(channels_list), 1)
|
||||||
self.assertListEqual(channels_list, expected_channels_list)
|
self.assertListEqual(channels_list, expected_channels_list)
|
||||||
|
|
||||||
mock_sleep.assert_called_once_with(5)
|
mock_sleep.assert_called_once_with(5)
|
||||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None")
|
mock_logger.warning.assert_called_once_with(
|
||||||
|
"Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None"
|
||||||
|
)
|
||||||
|
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||||
call(types="public_channel,private_channel", cursor=None, limit=1000)
|
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||||
]
|
]
|
||||||
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
||||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_get_all_channels_rate_limit_no_retry_after_valid_header(self, MockWebClient, mock_sleep, mock_logger):
|
def test_get_all_channels_rate_limit_no_retry_after_valid_header(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 429
|
mock_error_response.status_code = 429
|
||||||
mock_error_response.headers = {'Retry-After': 'invalid_value'}
|
mock_error_response.headers = {"Retry-After": "invalid_value"}
|
||||||
|
|
||||||
successful_response = {
|
successful_response = {
|
||||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
"channels": [
|
||||||
"response_metadata": {"next_cursor": ""}
|
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||||
|
],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_client_instance.conversations_list.side_effect = [
|
mock_client_instance.conversations_list.side_effect = [
|
||||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||||
successful_response
|
successful_response,
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
channels_list = slack_history.get_all_channels(include_private=True)
|
channels_list = slack_history.get_all_channels(include_private=True)
|
||||||
|
|
||||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
expected_channels_list = [
|
||||||
|
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||||
|
]
|
||||||
self.assertListEqual(channels_list, expected_channels_list)
|
self.assertListEqual(channels_list, expected_channels_list)
|
||||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
|
mock_logger.warning.assert_called_once_with(
|
||||||
|
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
|
||||||
|
)
|
||||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_get_all_channels_rate_limit_no_retry_after_header(self, MockWebClient, mock_sleep, mock_logger):
|
def test_get_all_channels_rate_limit_no_retry_after_header(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 429
|
mock_error_response.status_code = 429
|
||||||
mock_error_response.headers = {}
|
mock_error_response.headers = {}
|
||||||
|
|
||||||
successful_response = {
|
successful_response = {
|
||||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
"channels": [
|
||||||
"response_metadata": {"next_cursor": ""}
|
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||||
|
],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_client_instance.conversations_list.side_effect = [
|
mock_client_instance.conversations_list.side_effect = [
|
||||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||||
successful_response
|
successful_response,
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
channels_list = slack_history.get_all_channels(include_private=True)
|
channels_list = slack_history.get_all_channels(include_private=True)
|
||||||
|
|
||||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
expected_channels_list = [
|
||||||
|
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||||
|
]
|
||||||
self.assertListEqual(channels_list, expected_channels_list)
|
self.assertListEqual(channels_list, expected_channels_list)
|
||||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
|
mock_logger.warning.assert_called_once_with(
|
||||||
|
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
|
||||||
|
)
|
||||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_get_all_channels_other_slack_api_error(self, MockWebClient, mock_sleep, mock_logger):
|
def test_get_all_channels_other_slack_api_error(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 500
|
mock_error_response.status_code = 500
|
||||||
mock_error_response.headers = {}
|
mock_error_response.headers = {}
|
||||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||||
|
|
||||||
original_error = SlackApiError(message="server error", response=mock_error_response)
|
original_error = SlackApiError(
|
||||||
|
message="server error", response=mock_error_response
|
||||||
|
)
|
||||||
mock_client_instance.conversations_list.side_effect = original_error
|
mock_client_instance.conversations_list.side_effect = original_error
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -176,54 +214,75 @@ class TestSlackHistoryGetAllChannels(unittest.TestCase):
|
||||||
types="public_channel,private_channel", cursor=None, limit=1000
|
types="public_channel,private_channel", cursor=None, limit=1000
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_get_all_channels_handles_missing_name_id_gracefully(self, MockWebClient, mock_sleep, mock_logger):
|
def test_get_all_channels_handles_missing_name_id_gracefully(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
response_with_malformed_data = {
|
response_with_malformed_data = {
|
||||||
"channels": [
|
"channels": [
|
||||||
{"id": "C1_missing_name", "is_private": False, "is_member": True},
|
{"id": "C1_missing_name", "is_private": False, "is_member": True},
|
||||||
{"name": "channel_missing_id", "is_private": False, "is_member": True},
|
{"name": "channel_missing_id", "is_private": False, "is_member": True},
|
||||||
{"name": "general", "id": "C2_valid", "is_private": False, "is_member": True}
|
{
|
||||||
|
"name": "general",
|
||||||
|
"id": "C2_valid",
|
||||||
|
"is_private": False,
|
||||||
|
"is_member": True,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"response_metadata": {"next_cursor": ""}
|
"response_metadata": {"next_cursor": ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_client_instance.conversations_list.return_value = response_with_malformed_data
|
mock_client_instance.conversations_list.return_value = (
|
||||||
|
response_with_malformed_data
|
||||||
|
)
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
channels_list = slack_history.get_all_channels(include_private=True)
|
channels_list = slack_history.get_all_channels(include_private=True)
|
||||||
|
|
||||||
expected_channels_list = [
|
expected_channels_list = [
|
||||||
{"id": "C2_valid", "name": "general", "is_private": False, "is_member": True}
|
{
|
||||||
|
"id": "C2_valid",
|
||||||
|
"name": "general",
|
||||||
|
"is_private": False,
|
||||||
|
"is_member": True,
|
||||||
|
}
|
||||||
]
|
]
|
||||||
self.assertEqual(len(channels_list), 1)
|
self.assertEqual(len(channels_list), 1)
|
||||||
self.assertListEqual(channels_list, expected_channels_list)
|
self.assertListEqual(channels_list, expected_channels_list)
|
||||||
|
|
||||||
self.assertEqual(mock_logger.warning.call_count, 2)
|
self.assertEqual(mock_logger.warning.call_count, 2)
|
||||||
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}")
|
mock_logger.warning.assert_any_call(
|
||||||
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}")
|
"Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}"
|
||||||
|
)
|
||||||
|
mock_logger.warning.assert_any_call(
|
||||||
|
"Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}"
|
||||||
|
)
|
||||||
|
|
||||||
mock_sleep.assert_not_called()
|
mock_sleep.assert_not_called()
|
||||||
mock_client_instance.conversations_list.assert_called_once_with(
|
mock_client_instance.conversations_list.assert_called_once_with(
|
||||||
types="public_channel,private_channel", cursor=None, limit=1000
|
types="public_channel,private_channel", cursor=None, limit=1000
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
||||||
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
def test_proactive_delay_single_page(self, MockWebClient, mock_time_sleep, mock_logger):
|
@patch("slack_sdk.WebClient")
|
||||||
mock_client_instance = MockWebClient.return_value
|
def test_proactive_delay_single_page(
|
||||||
|
self, mock_web_client, mock_time_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
mock_client_instance.conversations_history.return_value = {
|
mock_client_instance.conversations_history.return_value = {
|
||||||
"messages": [{"text": "msg1"}],
|
"messages": [{"text": "msg1"}],
|
||||||
"has_more": False
|
"has_more": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -231,21 +290,20 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
|
|
||||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_proactive_delay_multiple_pages(self, MockWebClient, mock_time_sleep, mock_logger):
|
def test_proactive_delay_multiple_pages(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_time_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
mock_client_instance.conversations_history.side_effect = [
|
mock_client_instance.conversations_history.side_effect = [
|
||||||
{
|
{
|
||||||
"messages": [{"text": "msg1"}],
|
"messages": [{"text": "msg1"}],
|
||||||
"has_more": True,
|
"has_more": True,
|
||||||
"response_metadata": {"next_cursor": "cursor1"}
|
"response_metadata": {"next_cursor": "cursor1"},
|
||||||
},
|
},
|
||||||
{
|
{"messages": [{"text": "msg2"}], "has_more": False},
|
||||||
"messages": [{"text": "msg2"}],
|
|
||||||
"has_more": False
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -255,19 +313,19 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
self.assertEqual(mock_time_sleep.call_count, 2)
|
self.assertEqual(mock_time_sleep.call_count, 2)
|
||||||
mock_time_sleep.assert_has_calls([call(1.2), call(1.2)])
|
mock_time_sleep.assert_has_calls([call(1.2), call(1.2)])
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
|
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||||
mock_client_instance = MockWebClient.return_value
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 429
|
mock_error_response.status_code = 429
|
||||||
mock_error_response.headers = {'Retry-After': '5'}
|
mock_error_response.headers = {"Retry-After": "5"}
|
||||||
|
|
||||||
mock_client_instance.conversations_history.side_effect = [
|
mock_client_instance.conversations_history.side_effect = [
|
||||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||||
{"messages": [{"text": "msg1"}], "has_more": False}
|
{"messages": [{"text": "msg1"}], "has_more": False},
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -277,23 +335,26 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
self.assertEqual(messages[0]["text"], "msg1")
|
self.assertEqual(messages[0]["text"], "msg1")
|
||||||
|
|
||||||
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
|
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
|
||||||
mock_time_sleep.assert_has_calls([call(1.2), call(5), call(1.2)], any_order=False)
|
mock_time_sleep.assert_has_calls(
|
||||||
|
[call(1.2), call(5), call(1.2)], any_order=False
|
||||||
|
)
|
||||||
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
|
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_not_in_channel_error(self, MockWebClient, mock_time_sleep, mock_logger):
|
def test_not_in_channel_error(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||||
mock_client_instance = MockWebClient.return_value
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 403 # Typical for not_in_channel, but data matters more
|
mock_error_response.status_code = (
|
||||||
mock_error_response.data = {'ok': False, 'error': 'not_in_channel'}
|
403 # Typical for not_in_channel, but data matters more
|
||||||
|
)
|
||||||
|
mock_error_response.data = {"ok": False, "error": "not_in_channel"}
|
||||||
|
|
||||||
# This error is now raised by the inner try-except, then caught by the outer one
|
# This error is now raised by the inner try-except, then caught by the outer one
|
||||||
mock_client_instance.conversations_history.side_effect = SlackApiError(
|
mock_client_instance.conversations_history.side_effect = SlackApiError(
|
||||||
message="not_in_channel error",
|
message="not_in_channel error", response=mock_error_response
|
||||||
response=mock_error_response
|
|
||||||
)
|
)
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -303,18 +364,24 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
mock_logger.warning.assert_called_with(
|
mock_logger.warning.assert_called_with(
|
||||||
"Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel."
|
"Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel."
|
||||||
)
|
)
|
||||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay before the API call
|
mock_time_sleep.assert_called_once_with(
|
||||||
|
1.2
|
||||||
|
) # Proactive delay before the API call
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
def test_other_slack_api_error_propagates(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_time_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 500
|
mock_error_response.status_code = 500
|
||||||
mock_error_response.data = {'ok': False, 'error': 'internal_error'}
|
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||||
original_error = SlackApiError(message="server error", response=mock_error_response)
|
original_error = SlackApiError(
|
||||||
|
message="server error", response=mock_error_response
|
||||||
|
)
|
||||||
|
|
||||||
mock_client_instance.conversations_history.side_effect = original_error
|
mock_client_instance.conversations_history.side_effect = original_error
|
||||||
|
|
||||||
|
|
@ -323,15 +390,19 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
with self.assertRaises(SlackApiError) as context:
|
with self.assertRaises(SlackApiError) as context:
|
||||||
slack_history.get_conversation_history(channel_id="C123")
|
slack_history.get_conversation_history(channel_id="C123")
|
||||||
|
|
||||||
self.assertIn("Error retrieving history for channel C123", str(context.exception))
|
self.assertIn(
|
||||||
|
"Error retrieving history for channel C123", str(context.exception)
|
||||||
|
)
|
||||||
self.assertIs(context.exception.response, mock_error_response)
|
self.assertIs(context.exception.response, mock_error_response)
|
||||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
def test_general_exception_propagates(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_time_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
original_error = Exception("Something broke")
|
original_error = Exception("Something broke")
|
||||||
mock_client_instance.conversations_history.side_effect = original_error
|
mock_client_instance.conversations_history.side_effect = original_error
|
||||||
|
|
||||||
|
|
@ -340,27 +411,31 @@ class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||||
with self.assertRaises(Exception) as context: # Check for generic Exception
|
with self.assertRaises(Exception) as context: # Check for generic Exception
|
||||||
slack_history.get_conversation_history(channel_id="C123")
|
slack_history.get_conversation_history(channel_id="C123")
|
||||||
|
|
||||||
self.assertIs(context.exception, original_error) # Should re-raise the original error
|
self.assertIs(
|
||||||
mock_logger.error.assert_called_once_with("Unexpected error in get_conversation_history for channel C123: Something broke")
|
context.exception, original_error
|
||||||
|
) # Should re-raise the original error
|
||||||
|
mock_logger.error.assert_called_once_with(
|
||||||
|
"Unexpected error in get_conversation_history for channel C123: Something broke"
|
||||||
|
)
|
||||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||||
|
|
||||||
class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
|
@patch("slack_sdk.WebClient")
|
||||||
mock_client_instance = MockWebClient.return_value
|
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 429
|
mock_error_response.status_code = 429
|
||||||
mock_error_response.headers = {'Retry-After': '3'} # Using 3 seconds for test
|
mock_error_response.headers = {"Retry-After": "3"} # Using 3 seconds for test
|
||||||
|
|
||||||
successful_user_data = {"id": "U123", "name": "testuser"}
|
successful_user_data = {"id": "U123", "name": "testuser"}
|
||||||
|
|
||||||
mock_client_instance.users_info.side_effect = [
|
mock_client_instance.users_info.side_effect = [
|
||||||
SlackApiError(message="ratelimited_userinfo", response=mock_error_response),
|
SlackApiError(message="ratelimited_userinfo", response=mock_error_response),
|
||||||
{"user": successful_user_data}
|
{"user": successful_user_data},
|
||||||
]
|
]
|
||||||
|
|
||||||
slack_history = SlackHistory(token="fake_token")
|
slack_history = SlackHistory(token="fake_token")
|
||||||
|
|
@ -375,18 +450,26 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
||||||
)
|
)
|
||||||
# Assert users_info was called twice (original + retry)
|
# Assert users_info was called twice (original + retry)
|
||||||
self.assertEqual(mock_client_instance.users_info.call_count, 2)
|
self.assertEqual(mock_client_instance.users_info.call_count, 2)
|
||||||
mock_client_instance.users_info.assert_has_calls([call(user="U123"), call(user="U123")])
|
mock_client_instance.users_info.assert_has_calls(
|
||||||
|
[call(user="U123"), call(user="U123")]
|
||||||
|
)
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep') # time.sleep might be called by other logic, but not expected here
|
@patch(
|
||||||
@patch('slack_sdk.WebClient')
|
"surfsense_backend.app.connectors.slack_history.time.sleep"
|
||||||
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
) # time.sleep might be called by other logic, but not expected here
|
||||||
mock_client_instance = MockWebClient.return_value
|
@patch("slack_sdk.WebClient")
|
||||||
|
def test_other_slack_api_error_propagates(
|
||||||
|
self, mock_web_client, mock_time_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
|
|
||||||
mock_error_response = Mock()
|
mock_error_response = Mock()
|
||||||
mock_error_response.status_code = 500 # Some other error
|
mock_error_response.status_code = 500 # Some other error
|
||||||
mock_error_response.data = {'ok': False, 'error': 'internal_server_error'}
|
mock_error_response.data = {"ok": False, "error": "internal_server_error"}
|
||||||
original_error = SlackApiError(message="internal server error", response=mock_error_response)
|
original_error = SlackApiError(
|
||||||
|
message="internal server error", response=mock_error_response
|
||||||
|
)
|
||||||
|
|
||||||
mock_client_instance.users_info.side_effect = original_error
|
mock_client_instance.users_info.side_effect = original_error
|
||||||
|
|
||||||
|
|
@ -400,11 +483,13 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
||||||
self.assertIs(context.exception.response, mock_error_response)
|
self.assertIs(context.exception.response, mock_error_response)
|
||||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||||
|
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||||
@patch('slack_sdk.WebClient')
|
@patch("slack_sdk.WebClient")
|
||||||
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
def test_general_exception_propagates(
|
||||||
mock_client_instance = MockWebClient.return_value
|
self, mock_web_client, mock_time_sleep, mock_logger
|
||||||
|
):
|
||||||
|
mock_client_instance = mock_web_client.return_value
|
||||||
original_error = Exception("A very generic problem")
|
original_error = Exception("A very generic problem")
|
||||||
mock_client_instance.users_info.side_effect = original_error
|
mock_client_instance.users_info.side_effect = original_error
|
||||||
|
|
||||||
|
|
@ -413,7 +498,9 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
||||||
with self.assertRaises(Exception) as context:
|
with self.assertRaises(Exception) as context:
|
||||||
slack_history.get_user_info(user_id="U123")
|
slack_history.get_user_info(user_id="U123")
|
||||||
|
|
||||||
self.assertIs(context.exception, original_error) # Check it's the exact same exception
|
self.assertIs(
|
||||||
|
context.exception, original_error
|
||||||
|
) # Check it's the exact same exception
|
||||||
mock_logger.error.assert_called_once_with(
|
mock_logger.error.assert_called_once_with(
|
||||||
"Unexpected error in get_user_info for user U123: A very generic problem"
|
"Unexpected error in get_user_info for user U123: A very generic problem"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime, timezone
|
from datetime import UTC, datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from app.config import config
|
|
||||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
|
||||||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
ARRAY,
|
ARRAY,
|
||||||
|
|
@ -13,9 +11,7 @@ from sqlalchemy import (
|
||||||
TIMESTAMP,
|
TIMESTAMP,
|
||||||
Boolean,
|
Boolean,
|
||||||
Column,
|
Column,
|
||||||
)
|
Enum as SQLAlchemyEnum,
|
||||||
from sqlalchemy import Enum as SQLAlchemyEnum
|
|
||||||
from sqlalchemy import (
|
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
String,
|
String,
|
||||||
|
|
@ -26,17 +22,12 @@ from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
|
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
from fastapi_users.db import (
|
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||||||
SQLAlchemyBaseOAuthAccountTableUUID,
|
|
||||||
SQLAlchemyBaseUserTableUUID,
|
|
||||||
SQLAlchemyUserDatabase,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from fastapi_users.db import (
|
|
||||||
SQLAlchemyBaseUserTableUUID,
|
|
||||||
SQLAlchemyUserDatabase,
|
|
||||||
)
|
|
||||||
|
|
||||||
DATABASE_URL = config.DATABASE_URL
|
DATABASE_URL = config.DATABASE_URL
|
||||||
|
|
||||||
|
|
@ -118,11 +109,11 @@ class Base(DeclarativeBase):
|
||||||
|
|
||||||
class TimestampMixin:
|
class TimestampMixin:
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def created_at(cls):
|
def created_at(cls): # noqa: N805
|
||||||
return Column(
|
return Column(
|
||||||
TIMESTAMP(timezone=True),
|
TIMESTAMP(timezone=True),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=lambda: datetime.now(timezone.utc),
|
default=lambda: datetime.now(UTC),
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
DATE_TODAY = "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n"
|
||||||
|
|
||||||
SUMMARY_PROMPT = DATE_TODAY + """
|
SUMMARY_PROMPT = (
|
||||||
|
DATE_TODAY
|
||||||
|
+ """
|
||||||
<INSTRUCTIONS>
|
<INSTRUCTIONS>
|
||||||
<context>
|
<context>
|
||||||
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
|
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
|
||||||
|
|
@ -96,8 +99,8 @@ SUMMARY_PROMPT = DATE_TODAY + """
|
||||||
</document_to_summarize>
|
</document_to_summarize>
|
||||||
</INSTRUCTIONS>
|
</INSTRUCTIONS>
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
|
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
|
||||||
input_variables=["document"],
|
input_variables=["document"], template=SUMMARY_PROMPT
|
||||||
template=SUMMARY_PROMPT
|
|
||||||
)
|
)
|
||||||
|
|
@ -8,7 +8,13 @@ class ChucksHybridSearchRetriever:
|
||||||
"""
|
"""
|
||||||
self.db_session = db_session
|
self.db_session = db_session
|
||||||
|
|
||||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
async def vector_search(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Perform vector similarity search on chunks.
|
Perform vector similarity search on chunks.
|
||||||
|
|
||||||
|
|
@ -21,10 +27,11 @@ class ChucksHybridSearchRetriever:
|
||||||
Returns:
|
Returns:
|
||||||
List of chunks sorted by vector similarity
|
List of chunks sorted by vector similarity
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from app.db import Chunk, Document, SearchSpace
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.db import Chunk, Document, SearchSpace
|
||||||
|
|
||||||
# Get embedding for the query
|
# Get embedding for the query
|
||||||
embedding_model = config.embedding_model_instance
|
embedding_model = config.embedding_model_instance
|
||||||
|
|
@ -44,11 +51,7 @@ class ChucksHybridSearchRetriever:
|
||||||
query = query.where(Document.search_space_id == search_space_id)
|
query = query.where(Document.search_space_id == search_space_id)
|
||||||
|
|
||||||
# Add vector similarity ordering
|
# Add vector similarity ordering
|
||||||
query = (
|
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
|
||||||
query
|
|
||||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
|
||||||
.limit(top_k)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the query
|
# Execute the query
|
||||||
result = await self.db_session.execute(query)
|
result = await self.db_session.execute(query)
|
||||||
|
|
@ -56,7 +59,13 @@ class ChucksHybridSearchRetriever:
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
async def full_text_search(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Perform full-text keyword search on chunks.
|
Perform full-text keyword search on chunks.
|
||||||
|
|
||||||
|
|
@ -69,13 +78,14 @@ class ChucksHybridSearchRetriever:
|
||||||
Returns:
|
Returns:
|
||||||
List of chunks sorted by text relevance
|
List of chunks sorted by text relevance
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, func, text
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.db import Chunk, Document, SearchSpace
|
from app.db import Chunk, Document, SearchSpace
|
||||||
|
|
||||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||||
tsvector = func.to_tsvector('english', Chunk.content)
|
tsvector = func.to_tsvector("english", Chunk.content)
|
||||||
tsquery = func.plainto_tsquery('english', query_text)
|
tsquery = func.plainto_tsquery("english", query_text)
|
||||||
|
|
||||||
# Build the base query with user ownership check
|
# Build the base query with user ownership check
|
||||||
query = (
|
query = (
|
||||||
|
|
@ -84,7 +94,9 @@ class ChucksHybridSearchRetriever:
|
||||||
.join(Document, Chunk.document_id == Document.id)
|
.join(Document, Chunk.document_id == Document.id)
|
||||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||||
.where(SearchSpace.user_id == user_id)
|
.where(SearchSpace.user_id == user_id)
|
||||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
.where(
|
||||||
|
tsvector.op("@@")(tsquery)
|
||||||
|
) # Only include results that match the query
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add search space filter if provided
|
# Add search space filter if provided
|
||||||
|
|
@ -92,11 +104,7 @@ class ChucksHybridSearchRetriever:
|
||||||
query = query.where(Document.search_space_id == search_space_id)
|
query = query.where(Document.search_space_id == search_space_id)
|
||||||
|
|
||||||
# Add text search ranking
|
# Add text search ranking
|
||||||
query = (
|
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
||||||
query
|
|
||||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
|
||||||
.limit(top_k)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the query
|
# Execute the query
|
||||||
result = await self.db_session.execute(query)
|
result = await self.db_session.execute(query)
|
||||||
|
|
@ -104,7 +112,14 @@ class ChucksHybridSearchRetriever:
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
async def hybrid_search(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
document_type: str | None = None,
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||||
|
|
||||||
|
|
@ -118,10 +133,11 @@ class ChucksHybridSearchRetriever:
|
||||||
Returns:
|
Returns:
|
||||||
List of dictionaries containing chunk data and relevance scores
|
List of dictionaries containing chunk data and relevance scores
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, func, text
|
from sqlalchemy import func, select, text
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from app.db import Chunk, Document, SearchSpace, DocumentType
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.db import Chunk, Document, DocumentType, SearchSpace
|
||||||
|
|
||||||
# Get embedding for the query
|
# Get embedding for the query
|
||||||
embedding_model = config.embedding_model_instance
|
embedding_model = config.embedding_model_instance
|
||||||
|
|
@ -132,8 +148,8 @@ class ChucksHybridSearchRetriever:
|
||||||
n_results = top_k * 2 # Get more results for better fusion
|
n_results = top_k * 2 # Get more results for better fusion
|
||||||
|
|
||||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||||
tsvector = func.to_tsvector('english', Chunk.content)
|
tsvector = func.to_tsvector("english", Chunk.content)
|
||||||
tsquery = func.plainto_tsquery('english', query_text)
|
tsquery = func.plainto_tsquery("english", query_text)
|
||||||
|
|
||||||
# Base conditions for document filtering
|
# Base conditions for document filtering
|
||||||
base_conditions = [SearchSpace.user_id == user_id]
|
base_conditions = [SearchSpace.user_id == user_id]
|
||||||
|
|
@ -159,7 +175,9 @@ class ChucksHybridSearchRetriever:
|
||||||
semantic_search_cte = (
|
semantic_search_cte = (
|
||||||
select(
|
select(
|
||||||
Chunk.id,
|
Chunk.id,
|
||||||
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
|
func.rank()
|
||||||
|
.over(order_by=Chunk.embedding.op("<=>")(query_embedding))
|
||||||
|
.label("rank"),
|
||||||
)
|
)
|
||||||
.join(Document, Chunk.document_id == Document.id)
|
.join(Document, Chunk.document_id == Document.id)
|
||||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||||
|
|
@ -167,8 +185,7 @@ class ChucksHybridSearchRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
semantic_search_cte = (
|
semantic_search_cte = (
|
||||||
semantic_search_cte
|
semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
|
||||||
.limit(n_results)
|
.limit(n_results)
|
||||||
.cte("semantic_search")
|
.cte("semantic_search")
|
||||||
)
|
)
|
||||||
|
|
@ -177,7 +194,9 @@ class ChucksHybridSearchRetriever:
|
||||||
keyword_search_cte = (
|
keyword_search_cte = (
|
||||||
select(
|
select(
|
||||||
Chunk.id,
|
Chunk.id,
|
||||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
func.rank()
|
||||||
|
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
||||||
|
.label("rank"),
|
||||||
)
|
)
|
||||||
.join(Document, Chunk.document_id == Document.id)
|
.join(Document, Chunk.document_id == Document.id)
|
||||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||||
|
|
@ -186,8 +205,7 @@ class ChucksHybridSearchRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
keyword_search_cte = (
|
keyword_search_cte = (
|
||||||
keyword_search_cte
|
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
|
||||||
.limit(n_results)
|
.limit(n_results)
|
||||||
.cte("keyword_search")
|
.cte("keyword_search")
|
||||||
)
|
)
|
||||||
|
|
@ -197,20 +215,21 @@ class ChucksHybridSearchRetriever:
|
||||||
select(
|
select(
|
||||||
Chunk,
|
Chunk,
|
||||||
(
|
(
|
||||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
||||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||||
).label("score")
|
).label("score"),
|
||||||
)
|
)
|
||||||
.select_from(
|
.select_from(
|
||||||
semantic_search_cte.outerjoin(
|
semantic_search_cte.outerjoin(
|
||||||
keyword_search_cte,
|
keyword_search_cte,
|
||||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||||
full=True
|
full=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.join(
|
.join(
|
||||||
Chunk,
|
Chunk,
|
||||||
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
Chunk.id
|
||||||
|
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
||||||
)
|
)
|
||||||
.options(joinedload(Chunk.document))
|
.options(joinedload(Chunk.document))
|
||||||
.order_by(text("score DESC"))
|
.order_by(text("score DESC"))
|
||||||
|
|
@ -228,16 +247,20 @@ class ChucksHybridSearchRetriever:
|
||||||
# Convert to serializable dictionaries if no reranker is available or if reranking failed
|
# Convert to serializable dictionaries if no reranker is available or if reranking failed
|
||||||
serialized_results = []
|
serialized_results = []
|
||||||
for chunk, score in chunks_with_scores:
|
for chunk, score in chunks_with_scores:
|
||||||
serialized_results.append({
|
serialized_results.append(
|
||||||
|
{
|
||||||
"chunk_id": chunk.id,
|
"chunk_id": chunk.id,
|
||||||
"content": chunk.content,
|
"content": chunk.content,
|
||||||
"score": float(score), # Ensure score is a Python float
|
"score": float(score), # Ensure score is a Python float
|
||||||
"document": {
|
"document": {
|
||||||
"id": chunk.document.id,
|
"id": chunk.document.id,
|
||||||
"title": chunk.document.title,
|
"title": chunk.document.title,
|
||||||
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
|
"document_type": chunk.document.document_type.value
|
||||||
"metadata": chunk.document.document_metadata
|
if hasattr(chunk.document, "document_type")
|
||||||
|
else None,
|
||||||
|
"metadata": chunk.document.document_metadata,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
return serialized_results
|
return serialized_results
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,13 @@ class DocumentHybridSearchRetriever:
|
||||||
"""
|
"""
|
||||||
self.db_session = db_session
|
self.db_session = db_session
|
||||||
|
|
||||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
async def vector_search(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Perform vector similarity search on documents.
|
Perform vector similarity search on documents.
|
||||||
|
|
||||||
|
|
@ -21,10 +27,11 @@ class DocumentHybridSearchRetriever:
|
||||||
Returns:
|
Returns:
|
||||||
List of documents sorted by vector similarity
|
List of documents sorted by vector similarity
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from app.db import Document, SearchSpace
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.db import Document, SearchSpace
|
||||||
|
|
||||||
# Get embedding for the query
|
# Get embedding for the query
|
||||||
embedding_model = config.embedding_model_instance
|
embedding_model = config.embedding_model_instance
|
||||||
|
|
@ -43,10 +50,8 @@ class DocumentHybridSearchRetriever:
|
||||||
query = query.where(Document.search_space_id == search_space_id)
|
query = query.where(Document.search_space_id == search_space_id)
|
||||||
|
|
||||||
# Add vector similarity ordering
|
# Add vector similarity ordering
|
||||||
query = (
|
query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit(
|
||||||
query
|
top_k
|
||||||
.order_by(Document.embedding.op("<=>")(query_embedding))
|
|
||||||
.limit(top_k)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the query
|
# Execute the query
|
||||||
|
|
@ -55,7 +60,13 @@ class DocumentHybridSearchRetriever:
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
async def full_text_search(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Perform full-text keyword search on documents.
|
Perform full-text keyword search on documents.
|
||||||
|
|
||||||
|
|
@ -68,13 +79,14 @@ class DocumentHybridSearchRetriever:
|
||||||
Returns:
|
Returns:
|
||||||
List of documents sorted by text relevance
|
List of documents sorted by text relevance
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, func, text
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.db import Document, SearchSpace
|
from app.db import Document, SearchSpace
|
||||||
|
|
||||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||||
tsvector = func.to_tsvector('english', Document.content)
|
tsvector = func.to_tsvector("english", Document.content)
|
||||||
tsquery = func.plainto_tsquery('english', query_text)
|
tsquery = func.plainto_tsquery("english", query_text)
|
||||||
|
|
||||||
# Build the base query with user ownership check
|
# Build the base query with user ownership check
|
||||||
query = (
|
query = (
|
||||||
|
|
@ -82,7 +94,9 @@ class DocumentHybridSearchRetriever:
|
||||||
.options(joinedload(Document.search_space))
|
.options(joinedload(Document.search_space))
|
||||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||||
.where(SearchSpace.user_id == user_id)
|
.where(SearchSpace.user_id == user_id)
|
||||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
.where(
|
||||||
|
tsvector.op("@@")(tsquery)
|
||||||
|
) # Only include results that match the query
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add search space filter if provided
|
# Add search space filter if provided
|
||||||
|
|
@ -90,11 +104,7 @@ class DocumentHybridSearchRetriever:
|
||||||
query = query.where(Document.search_space_id == search_space_id)
|
query = query.where(Document.search_space_id == search_space_id)
|
||||||
|
|
||||||
# Add text search ranking
|
# Add text search ranking
|
||||||
query = (
|
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
||||||
query
|
|
||||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
|
||||||
.limit(top_k)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the query
|
# Execute the query
|
||||||
result = await self.db_session.execute(query)
|
result = await self.db_session.execute(query)
|
||||||
|
|
@ -102,7 +112,14 @@ class DocumentHybridSearchRetriever:
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
async def hybrid_search(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int,
|
||||||
|
user_id: str,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
document_type: str | None = None,
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||||
|
|
||||||
|
|
@ -114,10 +131,11 @@ class DocumentHybridSearchRetriever:
|
||||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select, func, text
|
from sqlalchemy import func, select, text
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from app.db import Document, SearchSpace, DocumentType
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.db import Document, DocumentType, SearchSpace
|
||||||
|
|
||||||
# Get embedding for the query
|
# Get embedding for the query
|
||||||
embedding_model = config.embedding_model_instance
|
embedding_model = config.embedding_model_instance
|
||||||
|
|
@ -128,8 +146,8 @@ class DocumentHybridSearchRetriever:
|
||||||
n_results = top_k * 2 # Get more results for better fusion
|
n_results = top_k * 2 # Get more results for better fusion
|
||||||
|
|
||||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||||
tsvector = func.to_tsvector('english', Document.content)
|
tsvector = func.to_tsvector("english", Document.content)
|
||||||
tsquery = func.plainto_tsquery('english', query_text)
|
tsquery = func.plainto_tsquery("english", query_text)
|
||||||
|
|
||||||
# Base conditions for document filtering
|
# Base conditions for document filtering
|
||||||
base_conditions = [SearchSpace.user_id == user_id]
|
base_conditions = [SearchSpace.user_id == user_id]
|
||||||
|
|
@ -155,15 +173,16 @@ class DocumentHybridSearchRetriever:
|
||||||
semantic_search_cte = (
|
semantic_search_cte = (
|
||||||
select(
|
select(
|
||||||
Document.id,
|
Document.id,
|
||||||
func.rank().over(order_by=Document.embedding.op("<=>")(query_embedding)).label("rank")
|
func.rank()
|
||||||
|
.over(order_by=Document.embedding.op("<=>")(query_embedding))
|
||||||
|
.label("rank"),
|
||||||
)
|
)
|
||||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||||
.where(*base_conditions)
|
.where(*base_conditions)
|
||||||
)
|
)
|
||||||
|
|
||||||
semantic_search_cte = (
|
semantic_search_cte = (
|
||||||
semantic_search_cte
|
semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||||
.order_by(Document.embedding.op("<=>")(query_embedding))
|
|
||||||
.limit(n_results)
|
.limit(n_results)
|
||||||
.cte("semantic_search")
|
.cte("semantic_search")
|
||||||
)
|
)
|
||||||
|
|
@ -172,7 +191,9 @@ class DocumentHybridSearchRetriever:
|
||||||
keyword_search_cte = (
|
keyword_search_cte = (
|
||||||
select(
|
select(
|
||||||
Document.id,
|
Document.id,
|
||||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
func.rank()
|
||||||
|
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
||||||
|
.label("rank"),
|
||||||
)
|
)
|
||||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||||
.where(*base_conditions)
|
.where(*base_conditions)
|
||||||
|
|
@ -180,8 +201,7 @@ class DocumentHybridSearchRetriever:
|
||||||
)
|
)
|
||||||
|
|
||||||
keyword_search_cte = (
|
keyword_search_cte = (
|
||||||
keyword_search_cte
|
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
|
||||||
.limit(n_results)
|
.limit(n_results)
|
||||||
.cte("keyword_search")
|
.cte("keyword_search")
|
||||||
)
|
)
|
||||||
|
|
@ -191,20 +211,21 @@ class DocumentHybridSearchRetriever:
|
||||||
select(
|
select(
|
||||||
Document,
|
Document,
|
||||||
(
|
(
|
||||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
||||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||||
).label("score")
|
).label("score"),
|
||||||
)
|
)
|
||||||
.select_from(
|
.select_from(
|
||||||
semantic_search_cte.outerjoin(
|
semantic_search_cte.outerjoin(
|
||||||
keyword_search_cte,
|
keyword_search_cte,
|
||||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||||
full=True
|
full=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.join(
|
.join(
|
||||||
Document,
|
Document,
|
||||||
Document.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
Document.id
|
||||||
|
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
||||||
)
|
)
|
||||||
.options(joinedload(Document.search_space))
|
.options(joinedload(Document.search_space))
|
||||||
.order_by(text("score DESC"))
|
.order_by(text("score DESC"))
|
||||||
|
|
@ -224,24 +245,35 @@ class DocumentHybridSearchRetriever:
|
||||||
for document, score in documents_with_scores:
|
for document, score in documents_with_scores:
|
||||||
# Fetch associated chunks for this document
|
# Fetch associated chunks for this document
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.db import Chunk
|
from app.db import Chunk
|
||||||
|
|
||||||
chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
chunks_query = (
|
||||||
|
select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||||
|
)
|
||||||
chunks_result = await self.db_session.execute(chunks_query)
|
chunks_result = await self.db_session.execute(chunks_query)
|
||||||
chunks = chunks_result.scalars().all()
|
chunks = chunks_result.scalars().all()
|
||||||
|
|
||||||
# Concatenate chunks content
|
# Concatenate chunks content
|
||||||
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
|
concatenated_chunks_content = (
|
||||||
|
" ".join([chunk.content for chunk in chunks])
|
||||||
|
if chunks
|
||||||
|
else document.content
|
||||||
|
)
|
||||||
|
|
||||||
serialized_results.append({
|
serialized_results.append(
|
||||||
|
{
|
||||||
"document_id": document.id,
|
"document_id": document.id,
|
||||||
"title": document.title,
|
"title": document.title,
|
||||||
"content": document.content,
|
"content": document.content,
|
||||||
"chunks_content": concatenated_chunks_content,
|
"chunks_content": concatenated_chunks_content,
|
||||||
"document_type": document.document_type.value if hasattr(document, 'document_type') else None,
|
"document_type": document.document_type.value
|
||||||
|
if hasattr(document, "document_type")
|
||||||
|
else None,
|
||||||
"metadata": document.document_metadata,
|
"metadata": document.document_metadata,
|
||||||
"score": float(score), # Ensure score is a Python float
|
"score": float(score), # Ensure score is a Python float
|
||||||
"search_space_id": document.search_space_id
|
"search_space_id": document.search_space_id,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return serialized_results
|
return serialized_results
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from .search_spaces_routes import router as search_spaces_router
|
|
||||||
from .documents_routes import router as documents_router
|
|
||||||
from .podcasts_routes import router as podcasts_router
|
|
||||||
from .chats_routes import router as chats_router
|
from .chats_routes import router as chats_router
|
||||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
from .documents_routes import router as documents_router
|
||||||
from .llm_config_routes import router as llm_config_router
|
from .llm_config_routes import router as llm_config_router
|
||||||
from .logs_routes import router as logs_router
|
from .logs_routes import router as logs_router
|
||||||
|
from .podcasts_routes import router as podcasts_router
|
||||||
|
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||||
|
from .search_spaces_routes import router as search_spaces_router
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,40 @@
|
||||||
from typing import List
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from langchain.schema import AIMessage, HumanMessage
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.db import Chat, SearchSpace, User, get_async_session
|
from app.db import Chat, SearchSpace, User, get_async_session
|
||||||
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
|
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
|
||||||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.check_ownership import check_ownership
|
from app.utils.check_ownership import check_ownership
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from langchain.schema import HumanMessage, AIMessage
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
@router.post("/chat")
|
||||||
async def handle_chat_data(
|
async def handle_chat_data(
|
||||||
request: AISDKChatRequest,
|
request: AISDKChatRequest,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
if messages[-1]['role'] != "user":
|
if messages[-1]["role"] != "user":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Last message must be a user message")
|
status_code=400, detail="Last message must be a user message"
|
||||||
|
)
|
||||||
|
|
||||||
user_query = messages[-1]['content']
|
user_query = messages[-1]["content"]
|
||||||
search_space_id = request.data.get('search_space_id')
|
search_space_id = request.data.get("search_space_id")
|
||||||
research_mode: str = request.data.get('research_mode')
|
research_mode: str = request.data.get("research_mode")
|
||||||
selected_connectors: List[str] = request.data.get('selected_connectors')
|
selected_connectors: list[str] = request.data.get("selected_connectors")
|
||||||
document_ids_to_add_in_context: List[int] = request.data.get('document_ids_to_add_in_context')
|
document_ids_to_add_in_context: list[int] = request.data.get(
|
||||||
|
"document_ids_to_add_in_context"
|
||||||
|
)
|
||||||
|
|
||||||
search_mode_str = request.data.get('search_mode', "CHUNKS")
|
search_mode_str = request.data.get("search_mode", "CHUNKS")
|
||||||
|
|
||||||
# Convert search_space_id to integer if it's a string
|
# Convert search_space_id to integer if it's a string
|
||||||
if search_space_id and isinstance(search_space_id, str):
|
if search_space_id and isinstance(search_space_id, str):
|
||||||
|
|
@ -40,21 +42,23 @@ async def handle_chat_data(
|
||||||
search_space_id = int(search_space_id)
|
search_space_id = int(search_space_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Invalid search_space_id format")
|
status_code=400, detail="Invalid search_space_id format"
|
||||||
|
) from None
|
||||||
|
|
||||||
# Check if the search space belongs to the current user
|
# Check if the search space belongs to the current user
|
||||||
try:
|
try:
|
||||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403, detail="You don't have access to this search space")
|
status_code=403, detail="You don't have access to this search space"
|
||||||
|
) from None
|
||||||
|
|
||||||
langchain_chat_history = []
|
langchain_chat_history = []
|
||||||
for message in messages[:-1]:
|
for message in messages[:-1]:
|
||||||
if message['role'] == "user":
|
if message["role"] == "user":
|
||||||
langchain_chat_history.append(HumanMessage(content=message['content']))
|
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
||||||
elif message['role'] == "assistant":
|
elif message["role"] == "assistant":
|
||||||
langchain_chat_history.append(AIMessage(content=message['content']))
|
langchain_chat_history.append(AIMessage(content=message["content"]))
|
||||||
|
|
||||||
response = StreamingResponse(
|
response = StreamingResponse(
|
||||||
stream_connector_search_results(
|
stream_connector_search_results(
|
||||||
|
|
@ -78,7 +82,7 @@ async def handle_chat_data(
|
||||||
async def create_chat(
|
async def create_chat(
|
||||||
chat: ChatCreate,
|
chat: ChatCreate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
await check_ownership(session, SearchSpace, chat.search_space_id, user)
|
await check_ownership(session, SearchSpace, chat.search_space_id, user)
|
||||||
|
|
@ -89,27 +93,32 @@ async def create_chat(
|
||||||
return db_chat
|
return db_chat
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except IntegrityError as e:
|
except IntegrityError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
status_code=400,
|
||||||
except OperationalError as e:
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
|
except OperationalError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503, detail="Database operation failed. Please try again later.")
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
except Exception as e:
|
) from None
|
||||||
|
except Exception:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail="An unexpected error occurred while creating the chat.")
|
status_code=500,
|
||||||
|
detail="An unexpected error occurred while creating the chat.",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chats/", response_model=List[ChatRead])
|
@router.get("/chats/", response_model=list[ChatRead])
|
||||||
async def read_chats(
|
async def read_chats(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
search_space_id: int = None,
|
search_space_id: int | None = None,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||||
|
|
@ -118,23 +127,23 @@ async def read_chats(
|
||||||
if search_space_id is not None:
|
if search_space_id is not None:
|
||||||
query = query.filter(Chat.search_space_id == search_space_id)
|
query = query.filter(Chat.search_space_id == search_space_id)
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(query.offset(skip).limit(limit))
|
||||||
query.offset(skip).limit(limit)
|
|
||||||
)
|
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503, detail="Database operation failed. Please try again later.")
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail="An unexpected error occurred while fetching chats.")
|
status_code=500, detail="An unexpected error occurred while fetching chats."
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
||||||
async def read_chat(
|
async def read_chat(
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -145,14 +154,19 @@ async def read_chat(
|
||||||
chat = result.scalars().first()
|
chat = result.scalars().first()
|
||||||
if not chat:
|
if not chat:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail="Chat not found or you don't have permission to access it")
|
status_code=404,
|
||||||
|
detail="Chat not found or you don't have permission to access it",
|
||||||
|
)
|
||||||
return chat
|
return chat
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503, detail="Database operation failed. Please try again later.")
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail="An unexpected error occurred while fetching the chat.")
|
status_code=500,
|
||||||
|
detail="An unexpected error occurred while fetching the chat.",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
||||||
|
|
@ -160,7 +174,7 @@ async def update_chat(
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
chat_update: ChatUpdate,
|
chat_update: ChatUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_chat = await read_chat(chat_id, session, user)
|
db_chat = await read_chat(chat_id, session, user)
|
||||||
|
|
@ -175,22 +189,27 @@ async def update_chat(
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
status_code=400,
|
||||||
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503, detail="Database operation failed. Please try again later.")
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
except Exception:
|
except Exception:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail="An unexpected error occurred while updating the chat.")
|
status_code=500,
|
||||||
|
detail="An unexpected error occurred while updating the chat.",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/chats/{chat_id}", response_model=dict)
|
@router.delete("/chats/{chat_id}", response_model=dict)
|
||||||
async def delete_chat(
|
async def delete_chat(
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_chat = await read_chat(chat_id, session, user)
|
db_chat = await read_chat(chat_id, session, user)
|
||||||
|
|
@ -202,81 +221,16 @@ async def delete_chat(
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Cannot delete chat due to existing dependencies.")
|
status_code=400, detail="Cannot delete chat due to existing dependencies."
|
||||||
|
) from None
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503, detail="Database operation failed. Please try again later.")
|
status_code=503, detail="Database operation failed. Please try again later."
|
||||||
|
) from None
|
||||||
except Exception:
|
except Exception:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail="An unexpected error occurred while deleting the chat.")
|
status_code=500,
|
||||||
|
detail="An unexpected error occurred while deleting the chat.",
|
||||||
|
) from None
|
||||||
# test_data = [
|
|
||||||
# {
|
|
||||||
# "type": "TERMINAL_INFO",
|
|
||||||
# "content": [
|
|
||||||
# {
|
|
||||||
# "id": 1,
|
|
||||||
# "text": "Starting to search for crawled URLs...",
|
|
||||||
# "type": "info"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "id": 2,
|
|
||||||
# "text": "Found 2 relevant crawled URLs",
|
|
||||||
# "type": "success"
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "type": "SOURCES",
|
|
||||||
# "content": [
|
|
||||||
# {
|
|
||||||
# "id": 1,
|
|
||||||
# "name": "Crawled URLs",
|
|
||||||
# "type": "CRAWLED_URL",
|
|
||||||
# "sources": [
|
|
||||||
# {
|
|
||||||
# "id": 1,
|
|
||||||
# "title": "Webpage Title",
|
|
||||||
# "description": "Webpage Dec",
|
|
||||||
# "url": "https://jsoneditoronline.org/"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "id": 2,
|
|
||||||
# "title": "Webpage Title",
|
|
||||||
# "description": "Webpage Dec",
|
|
||||||
# "url": "https://www.google.com/"
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "id": 2,
|
|
||||||
# "name": "Files",
|
|
||||||
# "type": "FILE",
|
|
||||||
# "sources": [
|
|
||||||
# {
|
|
||||||
# "id": 3,
|
|
||||||
# "title": "Webpage Title",
|
|
||||||
# "description": "Webpage Dec",
|
|
||||||
# "url": "https://jsoneditoronline.org/"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "id": 4,
|
|
||||||
# "title": "Webpage Title",
|
|
||||||
# "description": "Webpage Dec",
|
|
||||||
# "url": "https://www.google.com/"
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "type": "ANSWER",
|
|
||||||
# "content": [
|
|
||||||
# "## SurfSense Introduction",
|
|
||||||
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
|
|
||||||
# ]
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,35 @@
|
||||||
from litellm import atranscription
|
|
||||||
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from typing import List
|
|
||||||
from app.db import Log, get_async_session, User, SearchSpace, Document, DocumentType
|
|
||||||
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
|
|
||||||
from app.users import current_active_user
|
|
||||||
from app.utils.check_ownership import check_ownership
|
|
||||||
from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document_using_unstructured, add_crawled_url_document, add_youtube_video_document, add_received_file_document_using_llamacloud, add_received_file_document_using_docling
|
|
||||||
from app.config import config as app_config
|
|
||||||
# Force asyncio to use standard event loop before unstructured imports
|
# Force asyncio to use standard event loop before unstructured imports
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, UploadFile
|
||||||
|
from litellm import atranscription
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.config import config as app_config
|
||||||
|
from app.db import Document, DocumentType, Log, SearchSpace, User, get_async_session
|
||||||
|
from app.schemas import DocumentRead, DocumentsCreate, DocumentUpdate
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
|
from app.tasks.background_tasks import (
|
||||||
|
add_crawled_url_document,
|
||||||
|
add_extension_received_document,
|
||||||
|
add_received_file_document_using_docling,
|
||||||
|
add_received_file_document_using_llamacloud,
|
||||||
|
add_received_file_document_using_unstructured,
|
||||||
|
add_received_markdown_file_document,
|
||||||
|
add_youtube_video_document,
|
||||||
|
)
|
||||||
|
from app.users import current_active_user
|
||||||
|
from app.utils.check_ownership import check_ownership
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
||||||
except RuntimeError:
|
except RuntimeError as e:
|
||||||
|
print("Error setting event loop policy", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -29,7 +41,7 @@ async def create_documents(
|
||||||
request: DocumentsCreate,
|
request: DocumentsCreate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Check if the user owns the search space
|
# Check if the user owns the search space
|
||||||
|
|
@ -41,7 +53,7 @@ async def create_documents(
|
||||||
process_extension_document_with_new_session,
|
process_extension_document_with_new_session,
|
||||||
individual_document,
|
individual_document,
|
||||||
request.search_space_id,
|
request.search_space_id,
|
||||||
str(user.id)
|
str(user.id),
|
||||||
)
|
)
|
||||||
elif request.document_type == DocumentType.CRAWLED_URL:
|
elif request.document_type == DocumentType.CRAWLED_URL:
|
||||||
for url in request.content:
|
for url in request.content:
|
||||||
|
|
@ -49,7 +61,7 @@ async def create_documents(
|
||||||
process_crawled_url_with_new_session,
|
process_crawled_url_with_new_session,
|
||||||
url,
|
url,
|
||||||
request.search_space_id,
|
request.search_space_id,
|
||||||
str(user.id)
|
str(user.id),
|
||||||
)
|
)
|
||||||
elif request.document_type == DocumentType.YOUTUBE_VIDEO:
|
elif request.document_type == DocumentType.YOUTUBE_VIDEO:
|
||||||
for url in request.content:
|
for url in request.content:
|
||||||
|
|
@ -57,13 +69,10 @@ async def create_documents(
|
||||||
process_youtube_video_with_new_session,
|
process_youtube_video_with_new_session,
|
||||||
url,
|
url,
|
||||||
request.search_space_id,
|
request.search_space_id,
|
||||||
str(user.id)
|
str(user.id),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail="Invalid document type")
|
||||||
status_code=400,
|
|
||||||
detail="Invalid document type"
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return {"message": "Documents processed successfully"}
|
return {"message": "Documents processed successfully"}
|
||||||
|
|
@ -72,18 +81,17 @@ async def create_documents(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to process documents: {e!s}"
|
||||||
detail=f"Failed to process documents: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/documents/fileupload")
|
@router.post("/documents/fileupload")
|
||||||
async def create_documents(
|
async def create_documents_file_upload(
|
||||||
files: list[UploadFile],
|
files: list[UploadFile],
|
||||||
search_space_id: int = Form(...),
|
search_space_id: int = Form(...),
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||||
|
|
@ -94,12 +102,13 @@ async def create_documents(
|
||||||
for file in files:
|
for file in files:
|
||||||
try:
|
try:
|
||||||
# Save file to a temporary location to avoid stream issues
|
# Save file to a temporary location to avoid stream issues
|
||||||
import tempfile
|
|
||||||
import aiofiles
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
# Create temp file
|
# Create temp file
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
|
with tempfile.NamedTemporaryFile(
|
||||||
|
delete=False, suffix=os.path.splitext(file.filename)[1]
|
||||||
|
) as temp_file:
|
||||||
temp_path = temp_file.name
|
temp_path = temp_file.name
|
||||||
|
|
||||||
# Write uploaded file to temp file
|
# Write uploaded file to temp file
|
||||||
|
|
@ -112,13 +121,13 @@ async def create_documents(
|
||||||
temp_path,
|
temp_path,
|
||||||
file.filename,
|
file.filename,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
str(user.id)
|
str(user.id),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=422,
|
status_code=422,
|
||||||
detail=f"Failed to process file {file.filename}: {str(e)}"
|
detail=f"Failed to process file {file.filename}: {e!s}",
|
||||||
)
|
) from e
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return {"message": "Files uploaded for processing"}
|
return {"message": "Files uploaded for processing"}
|
||||||
|
|
@ -127,9 +136,8 @@ async def create_documents(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to upload files: {e!s}"
|
||||||
detail=f"Failed to upload files: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_file_in_background(
|
async def process_file_in_background(
|
||||||
|
|
@ -139,62 +147,69 @@ async def process_file_in_background(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
task_logger: TaskLoggingService,
|
task_logger: TaskLoggingService,
|
||||||
log_entry: Log
|
log_entry: Log,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Check if the file is a markdown or text file
|
# Check if the file is a markdown or text file
|
||||||
if filename.lower().endswith(('.md', '.markdown', '.txt')):
|
if filename.lower().endswith((".md", ".markdown", ".txt")):
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing markdown/text file: {filename}",
|
f"Processing markdown/text file: {filename}",
|
||||||
{"file_type": "markdown", "processing_stage": "reading_file"}
|
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# For markdown files, read the content directly
|
# For markdown files, read the content directly
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
markdown_content = f.read()
|
markdown_content = f.read()
|
||||||
|
|
||||||
# Clean up the temp file
|
# Clean up the temp file
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print("Error deleting temp file", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Creating document from markdown content: {filename}",
|
f"Creating document from markdown content: {filename}",
|
||||||
{"processing_stage": "creating_document", "content_length": len(markdown_content)}
|
{
|
||||||
|
"processing_stage": "creating_document",
|
||||||
|
"content_length": len(markdown_content),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process markdown directly through specialized function
|
# Process markdown directly through specialized function
|
||||||
result = await add_received_markdown_file_document(
|
result = await add_received_markdown_file_document(
|
||||||
session,
|
session, filename, markdown_content, search_space_id, user_id
|
||||||
filename,
|
|
||||||
markdown_content,
|
|
||||||
search_space_id,
|
|
||||||
user_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully processed markdown file: {filename}",
|
f"Successfully processed markdown file: {filename}",
|
||||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "markdown"}
|
{
|
||||||
|
"document_id": result.id,
|
||||||
|
"content_hash": result.content_hash,
|
||||||
|
"file_type": "markdown",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Markdown file already exists (duplicate): {filename}",
|
f"Markdown file already exists (duplicate): {filename}",
|
||||||
{"duplicate_detected": True, "file_type": "markdown"}
|
{"duplicate_detected": True, "file_type": "markdown"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the file is an audio file
|
# Check if the file is an audio file
|
||||||
elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
|
elif filename.lower().endswith(
|
||||||
|
(".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
|
||||||
|
):
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing audio file for transcription: {filename}",
|
f"Processing audio file for transcription: {filename}",
|
||||||
{"file_type": "audio", "processing_stage": "starting_transcription"}
|
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Open the audio file for transcription
|
# Open the audio file for transcription
|
||||||
|
|
@ -205,53 +220,60 @@ async def process_file_in_background(
|
||||||
model=app_config.STT_SERVICE,
|
model=app_config.STT_SERVICE,
|
||||||
file=audio_file,
|
file=audio_file,
|
||||||
api_base=app_config.STT_SERVICE_API_BASE,
|
api_base=app_config.STT_SERVICE_API_BASE,
|
||||||
api_key=app_config.STT_SERVICE_API_KEY
|
api_key=app_config.STT_SERVICE_API_KEY,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
transcription_response = await atranscription(
|
transcription_response = await atranscription(
|
||||||
model=app_config.STT_SERVICE,
|
model=app_config.STT_SERVICE,
|
||||||
api_key=app_config.STT_SERVICE_API_KEY,
|
api_key=app_config.STT_SERVICE_API_KEY,
|
||||||
file=audio_file
|
file=audio_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract the transcribed text
|
# Extract the transcribed text
|
||||||
transcribed_text = transcription_response.get("text", "")
|
transcribed_text = transcription_response.get("text", "")
|
||||||
|
|
||||||
# Add metadata about the transcription
|
# Add metadata about the transcription
|
||||||
transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}"
|
transcribed_text = (
|
||||||
|
f"# Transcription of {filename}\n\n{transcribed_text}"
|
||||||
|
)
|
||||||
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Transcription completed, creating document: {filename}",
|
f"Transcription completed, creating document: {filename}",
|
||||||
{"processing_stage": "transcription_complete", "transcript_length": len(transcribed_text)}
|
{
|
||||||
|
"processing_stage": "transcription_complete",
|
||||||
|
"transcript_length": len(transcribed_text),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up the temp file
|
# Clean up the temp file
|
||||||
try:
|
try:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print("Error deleting temp file", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Process transcription as markdown document
|
# Process transcription as markdown document
|
||||||
result = await add_received_markdown_file_document(
|
result = await add_received_markdown_file_document(
|
||||||
session,
|
session, filename, transcribed_text, search_space_id, user_id
|
||||||
filename,
|
|
||||||
transcribed_text,
|
|
||||||
search_space_id,
|
|
||||||
user_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully transcribed and processed audio file: {filename}",
|
f"Successfully transcribed and processed audio file: {filename}",
|
||||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "audio", "transcript_length": len(transcribed_text)}
|
{
|
||||||
|
"document_id": result.id,
|
||||||
|
"content_hash": result.content_hash,
|
||||||
|
"file_type": "audio",
|
||||||
|
"transcript_length": len(transcribed_text),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Audio file transcript already exists (duplicate): {filename}",
|
f"Audio file transcript already exists (duplicate): {filename}",
|
||||||
{"duplicate_detected": True, "file_type": "audio"}
|
{"duplicate_detected": True, "file_type": "audio"},
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
@ -259,7 +281,11 @@ async def process_file_in_background(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing file with Unstructured ETL: {filename}",
|
f"Processing file with Unstructured ETL: {filename}",
|
||||||
{"file_type": "document", "etl_service": "UNSTRUCTURED", "processing_stage": "loading"}
|
{
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "UNSTRUCTURED",
|
||||||
|
"processing_stage": "loading",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_unstructured import UnstructuredLoader
|
from langchain_unstructured import UnstructuredLoader
|
||||||
|
|
@ -280,56 +306,66 @@ async def process_file_in_background(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Unstructured ETL completed, creating document: {filename}",
|
f"Unstructured ETL completed, creating document: {filename}",
|
||||||
{"processing_stage": "etl_complete", "elements_count": len(docs)}
|
{"processing_stage": "etl_complete", "elements_count": len(docs)},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up the temp file
|
# Clean up the temp file
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print("Error deleting temp file", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Pass the documents to the existing background task
|
# Pass the documents to the existing background task
|
||||||
result = await add_received_file_document_using_unstructured(
|
result = await add_received_file_document_using_unstructured(
|
||||||
session,
|
session, filename, docs, search_space_id, user_id
|
||||||
filename,
|
|
||||||
docs,
|
|
||||||
search_space_id,
|
|
||||||
user_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully processed file with Unstructured: {filename}",
|
f"Successfully processed file with Unstructured: {filename}",
|
||||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "document", "etl_service": "UNSTRUCTURED"}
|
{
|
||||||
|
"document_id": result.id,
|
||||||
|
"content_hash": result.content_hash,
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "UNSTRUCTURED",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Document already exists (duplicate): {filename}",
|
f"Document already exists (duplicate): {filename}",
|
||||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "UNSTRUCTURED"}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "UNSTRUCTURED",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
elif app_config.ETL_SERVICE == "LLAMACLOUD":
|
elif app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing file with LlamaCloud ETL: {filename}",
|
f"Processing file with LlamaCloud ETL: {filename}",
|
||||||
{"file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing"}
|
{
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "LLAMACLOUD",
|
||||||
|
"processing_stage": "parsing",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_cloud_services import LlamaParse
|
from llama_cloud_services import LlamaParse
|
||||||
from llama_cloud_services.parse.utils import ResultType
|
from llama_cloud_services.parse.utils import ResultType
|
||||||
|
|
||||||
|
|
||||||
# Create LlamaParse parser instance
|
# Create LlamaParse parser instance
|
||||||
parser = LlamaParse(
|
parser = LlamaParse(
|
||||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||||
num_workers=1, # Use single worker for file processing
|
num_workers=1, # Use single worker for file processing
|
||||||
verbose=True,
|
verbose=True,
|
||||||
language="en",
|
language="en",
|
||||||
result_type=ResultType.MD
|
result_type=ResultType.MD,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse the file asynchronously
|
# Parse the file asynchronously
|
||||||
|
|
@ -337,18 +373,25 @@ async def process_file_in_background(
|
||||||
|
|
||||||
# Clean up the temp file
|
# Clean up the temp file
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print("Error deleting temp file", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Get markdown documents from the result
|
# Get markdown documents from the result
|
||||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
markdown_documents = await result.aget_markdown_documents(
|
||||||
|
split_by_page=False
|
||||||
|
)
|
||||||
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"LlamaCloud parsing completed, creating documents: {filename}",
|
f"LlamaCloud parsing completed, creating documents: {filename}",
|
||||||
{"processing_stage": "parsing_complete", "documents_count": len(markdown_documents)}
|
{
|
||||||
|
"processing_stage": "parsing_complete",
|
||||||
|
"documents_count": len(markdown_documents),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
for doc in markdown_documents:
|
for doc in markdown_documents:
|
||||||
|
|
@ -361,27 +404,40 @@ async def process_file_in_background(
|
||||||
filename,
|
filename,
|
||||||
llamacloud_markdown_document=markdown_content,
|
llamacloud_markdown_document=markdown_content,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if doc_result:
|
if doc_result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully processed file with LlamaCloud: {filename}",
|
f"Successfully processed file with LlamaCloud: {filename}",
|
||||||
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "LLAMACLOUD"}
|
{
|
||||||
|
"document_id": doc_result.id,
|
||||||
|
"content_hash": doc_result.content_hash,
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "LLAMACLOUD",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Document already exists (duplicate): {filename}",
|
f"Document already exists (duplicate): {filename}",
|
||||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "LLAMACLOUD"}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "LLAMACLOUD",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
elif app_config.ETL_SERVICE == "DOCLING":
|
elif app_config.ETL_SERVICE == "DOCLING":
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing file with Docling ETL: {filename}",
|
f"Processing file with Docling ETL: {filename}",
|
||||||
{"file_type": "document", "etl_service": "DOCLING", "processing_stage": "parsing"}
|
{
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "DOCLING",
|
||||||
|
"processing_stage": "parsing",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use Docling service for document processing
|
# Use Docling service for document processing
|
||||||
|
|
@ -395,97 +451,112 @@ async def process_file_in_background(
|
||||||
|
|
||||||
# Clean up the temp file
|
# Clean up the temp file
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print("Error deleting temp file", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Docling parsing completed, creating document: {filename}",
|
f"Docling parsing completed, creating document: {filename}",
|
||||||
{"processing_stage": "parsing_complete", "content_length": len(result['content'])}
|
{
|
||||||
|
"processing_stage": "parsing_complete",
|
||||||
|
"content_length": len(result["content"]),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process the document using our Docling background task
|
# Process the document using our Docling background task
|
||||||
doc_result = await add_received_file_document_using_docling(
|
doc_result = await add_received_file_document_using_docling(
|
||||||
session,
|
session,
|
||||||
filename,
|
filename,
|
||||||
docling_markdown_document=result['content'],
|
docling_markdown_document=result["content"],
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if doc_result:
|
if doc_result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully processed file with Docling: {filename}",
|
f"Successfully processed file with Docling: {filename}",
|
||||||
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "DOCLING"}
|
{
|
||||||
|
"document_id": doc_result.id,
|
||||||
|
"content_hash": doc_result.content_hash,
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "DOCLING",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Document already exists (duplicate): {filename}",
|
f"Document already exists (duplicate): {filename}",
|
||||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "DOCLING"}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"file_type": "document",
|
||||||
|
"etl_service": "DOCLING",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process file: {filename}",
|
f"Failed to process file: {filename}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__, "filename": filename}
|
{"error_type": type(e).__name__, "filename": filename},
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
logging.error(f"Error processing file in background: {str(e)}")
|
|
||||||
|
logging.error(f"Error processing file in background: {e!s}")
|
||||||
raise # Re-raise so the wrapper can also handle it
|
raise # Re-raise so the wrapper can also handle it
|
||||||
|
|
||||||
|
|
||||||
@router.get("/documents/", response_model=List[DocumentRead])
|
@router.get("/documents/", response_model=list[DocumentRead])
|
||||||
async def read_documents(
|
async def read_documents(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 300,
|
limit: int = 300,
|
||||||
search_space_id: int = None,
|
search_space_id: int | None = None,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
query = select(Document).join(SearchSpace).filter(
|
query = (
|
||||||
SearchSpace.user_id == user.id)
|
select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||||
|
)
|
||||||
|
|
||||||
# Filter by search_space_id if provided
|
# Filter by search_space_id if provided
|
||||||
if search_space_id is not None:
|
if search_space_id is not None:
|
||||||
query = query.filter(Document.search_space_id == search_space_id)
|
query = query.filter(Document.search_space_id == search_space_id)
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(query.offset(skip).limit(limit))
|
||||||
query.offset(skip).limit(limit)
|
|
||||||
)
|
|
||||||
db_documents = result.scalars().all()
|
db_documents = result.scalars().all()
|
||||||
|
|
||||||
# Convert database objects to API-friendly format
|
# Convert database objects to API-friendly format
|
||||||
api_documents = []
|
api_documents = []
|
||||||
for doc in db_documents:
|
for doc in db_documents:
|
||||||
api_documents.append(DocumentRead(
|
api_documents.append(
|
||||||
|
DocumentRead(
|
||||||
id=doc.id,
|
id=doc.id,
|
||||||
title=doc.title,
|
title=doc.title,
|
||||||
document_type=doc.document_type,
|
document_type=doc.document_type,
|
||||||
document_metadata=doc.document_metadata,
|
document_metadata=doc.document_metadata,
|
||||||
content=doc.content,
|
content=doc.content,
|
||||||
created_at=doc.created_at,
|
created_at=doc.created_at,
|
||||||
search_space_id=doc.search_space_id
|
search_space_id=doc.search_space_id,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return api_documents
|
return api_documents
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch documents: {e!s}"
|
||||||
detail=f"Failed to fetch documents: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/documents/{document_id}", response_model=DocumentRead)
|
@router.get("/documents/{document_id}", response_model=DocumentRead)
|
||||||
async def read_document(
|
async def read_document(
|
||||||
document_id: int,
|
document_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -497,8 +568,7 @@ async def read_document(
|
||||||
|
|
||||||
if not document:
|
if not document:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Document with id {document_id} not found"
|
||||||
detail=f"Document with id {document_id} not found"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert database object to API-friendly format
|
# Convert database object to API-friendly format
|
||||||
|
|
@ -509,13 +579,12 @@ async def read_document(
|
||||||
document_metadata=document.document_metadata,
|
document_metadata=document.document_metadata,
|
||||||
content=document.content,
|
content=document.content,
|
||||||
created_at=document.created_at,
|
created_at=document.created_at,
|
||||||
search_space_id=document.search_space_id
|
search_space_id=document.search_space_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch document: {e!s}"
|
||||||
detail=f"Failed to fetch document: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/documents/{document_id}", response_model=DocumentRead)
|
@router.put("/documents/{document_id}", response_model=DocumentRead)
|
||||||
|
|
@ -523,7 +592,7 @@ async def update_document(
|
||||||
document_id: int,
|
document_id: int,
|
||||||
document_update: DocumentUpdate,
|
document_update: DocumentUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Query the document directly instead of using read_document function
|
# Query the document directly instead of using read_document function
|
||||||
|
|
@ -536,8 +605,7 @@ async def update_document(
|
||||||
|
|
||||||
if not db_document:
|
if not db_document:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Document with id {document_id} not found"
|
||||||
detail=f"Document with id {document_id} not found"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
update_data = document_update.model_dump(exclude_unset=True)
|
update_data = document_update.model_dump(exclude_unset=True)
|
||||||
|
|
@ -554,23 +622,22 @@ async def update_document(
|
||||||
document_metadata=db_document.document_metadata,
|
document_metadata=db_document.document_metadata,
|
||||||
content=db_document.content,
|
content=db_document.content,
|
||||||
created_at=db_document.created_at,
|
created_at=db_document.created_at,
|
||||||
search_space_id=db_document.search_space_id
|
search_space_id=db_document.search_space_id,
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to update document: {e!s}"
|
||||||
detail=f"Failed to update document: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/documents/{document_id}", response_model=dict)
|
@router.delete("/documents/{document_id}", response_model=dict)
|
||||||
async def delete_document(
|
async def delete_document(
|
||||||
document_id: int,
|
document_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Query the document directly instead of using read_document function
|
# Query the document directly instead of using read_document function
|
||||||
|
|
@ -583,8 +650,7 @@ async def delete_document(
|
||||||
|
|
||||||
if not document:
|
if not document:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Document with id {document_id} not found"
|
||||||
detail=f"Document with id {document_id} not found"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await session.delete(document)
|
await session.delete(document)
|
||||||
|
|
@ -595,15 +661,12 @@ async def delete_document(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to delete document: {e!s}"
|
||||||
detail=f"Failed to delete document: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_extension_document_with_new_session(
|
async def process_extension_document_with_new_session(
|
||||||
individual_document,
|
individual_document, search_space_id: int, user_id: str
|
||||||
search_space_id: int,
|
|
||||||
user_id: str
|
|
||||||
):
|
):
|
||||||
"""Create a new session and process extension document."""
|
"""Create a new session and process extension document."""
|
||||||
from app.db import async_session_maker
|
from app.db import async_session_maker
|
||||||
|
|
@ -622,40 +685,41 @@ async def process_extension_document_with_new_session(
|
||||||
"document_type": "EXTENSION",
|
"document_type": "EXTENSION",
|
||||||
"url": individual_document.metadata.VisitedWebPageURL,
|
"url": individual_document.metadata.VisitedWebPageURL,
|
||||||
"title": individual_document.metadata.VisitedWebPageTitle,
|
"title": individual_document.metadata.VisitedWebPageTitle,
|
||||||
"user_id": user_id
|
"user_id": user_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await add_extension_received_document(session, individual_document, search_space_id, user_id)
|
result = await add_extension_received_document(
|
||||||
|
session, individual_document, search_space_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||||
{"document_id": result.id, "content_hash": result.content_hash}
|
{"document_id": result.id, "content_hash": result.content_hash},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
|
f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
|
||||||
{"duplicate_detected": True}
|
{"duplicate_detected": True},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
logging.error(f"Error processing extension document: {str(e)}")
|
|
||||||
|
logging.error(f"Error processing extension document: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
async def process_crawled_url_with_new_session(
|
async def process_crawled_url_with_new_session(
|
||||||
url: str,
|
url: str, search_space_id: int, user_id: str
|
||||||
search_space_id: int,
|
|
||||||
user_id: str
|
|
||||||
):
|
):
|
||||||
"""Create a new session and process crawled URL."""
|
"""Create a new session and process crawled URL."""
|
||||||
from app.db import async_session_maker
|
from app.db import async_session_maker
|
||||||
|
|
@ -670,44 +734,44 @@ async def process_crawled_url_with_new_session(
|
||||||
task_name="process_crawled_url",
|
task_name="process_crawled_url",
|
||||||
source="document_processor",
|
source="document_processor",
|
||||||
message=f"Starting URL crawling and processing for: {url}",
|
message=f"Starting URL crawling and processing for: {url}",
|
||||||
metadata={
|
metadata={"document_type": "CRAWLED_URL", "url": url, "user_id": user_id},
|
||||||
"document_type": "CRAWLED_URL",
|
|
||||||
"url": url,
|
|
||||||
"user_id": user_id
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await add_crawled_url_document(session, url, search_space_id, user_id)
|
result = await add_crawled_url_document(
|
||||||
|
session, url, search_space_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully crawled and processed URL: {url}",
|
f"Successfully crawled and processed URL: {url}",
|
||||||
{"document_id": result.id, "title": result.title, "content_hash": result.content_hash}
|
{
|
||||||
|
"document_id": result.id,
|
||||||
|
"title": result.title,
|
||||||
|
"content_hash": result.content_hash,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"URL document already exists (duplicate): {url}",
|
f"URL document already exists (duplicate): {url}",
|
||||||
{"duplicate_detected": True}
|
{"duplicate_detected": True},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to crawl URL: {url}",
|
f"Failed to crawl URL: {url}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
logging.error(f"Error processing crawled URL: {str(e)}")
|
|
||||||
|
logging.error(f"Error processing crawled URL: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
async def process_file_in_background_with_new_session(
|
async def process_file_in_background_with_new_session(
|
||||||
file_path: str,
|
file_path: str, filename: str, search_space_id: int, user_id: str
|
||||||
filename: str,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str
|
|
||||||
):
|
):
|
||||||
"""Create a new session and process file."""
|
"""Create a new session and process file."""
|
||||||
from app.db import async_session_maker
|
from app.db import async_session_maker
|
||||||
|
|
@ -726,12 +790,20 @@ async def process_file_in_background_with_new_session(
|
||||||
"document_type": "FILE",
|
"document_type": "FILE",
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"user_id": user_id
|
"user_id": user_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await process_file_in_background(file_path, filename, search_space_id, user_id, session, task_logger, log_entry)
|
await process_file_in_background(
|
||||||
|
file_path,
|
||||||
|
filename,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
session,
|
||||||
|
task_logger,
|
||||||
|
log_entry,
|
||||||
|
)
|
||||||
|
|
||||||
# Note: success/failure logging is handled within process_file_in_background
|
# Note: success/failure logging is handled within process_file_in_background
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -739,16 +811,15 @@ async def process_file_in_background_with_new_session(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process file: {filename}",
|
f"Failed to process file: {filename}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
logging.error(f"Error processing file: {str(e)}")
|
|
||||||
|
logging.error(f"Error processing file: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
async def process_youtube_video_with_new_session(
|
async def process_youtube_video_with_new_session(
|
||||||
url: str,
|
url: str, search_space_id: int, user_id: str
|
||||||
search_space_id: int,
|
|
||||||
user_id: str
|
|
||||||
):
|
):
|
||||||
"""Create a new session and process YouTube video."""
|
"""Create a new session and process YouTube video."""
|
||||||
from app.db import async_session_maker
|
from app.db import async_session_maker
|
||||||
|
|
@ -763,36 +834,37 @@ async def process_youtube_video_with_new_session(
|
||||||
task_name="process_youtube_video",
|
task_name="process_youtube_video",
|
||||||
source="document_processor",
|
source="document_processor",
|
||||||
message=f"Starting YouTube video processing for: {url}",
|
message=f"Starting YouTube video processing for: {url}",
|
||||||
metadata={
|
metadata={"document_type": "YOUTUBE_VIDEO", "url": url, "user_id": user_id},
|
||||||
"document_type": "YOUTUBE_VIDEO",
|
|
||||||
"url": url,
|
|
||||||
"user_id": user_id
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await add_youtube_video_document(session, url, search_space_id, user_id)
|
result = await add_youtube_video_document(
|
||||||
|
session, url, search_space_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully processed YouTube video: {result.title}",
|
f"Successfully processed YouTube video: {result.title}",
|
||||||
{"document_id": result.id, "video_id": result.document_metadata.get("video_id"), "content_hash": result.content_hash}
|
{
|
||||||
|
"document_id": result.id,
|
||||||
|
"video_id": result.document_metadata.get("video_id"),
|
||||||
|
"content_hash": result.content_hash,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"YouTube video document already exists (duplicate): {url}",
|
f"YouTube video document already exists (duplicate): {url}",
|
||||||
{"duplicate_detected": True}
|
{"duplicate_detected": True},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process YouTube video: {url}",
|
f"Failed to process YouTube video: {url}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
logging.error(f"Error processing YouTube video: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
|
logging.error(f"Error processing YouTube video: {e!s}")
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,40 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from typing import List, Optional
|
|
||||||
from pydantic import BaseModel
|
from app.db import LLMConfig, User, get_async_session
|
||||||
from app.db import get_async_session, User, LLMConfig
|
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||||
from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
|
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.check_ownership import check_ownership
|
from app.utils.check_ownership import check_ownership
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class LLMPreferencesUpdate(BaseModel):
|
class LLMPreferencesUpdate(BaseModel):
|
||||||
"""Schema for updating user LLM preferences"""
|
"""Schema for updating user LLM preferences"""
|
||||||
long_context_llm_id: Optional[int] = None
|
|
||||||
fast_llm_id: Optional[int] = None
|
long_context_llm_id: int | None = None
|
||||||
strategic_llm_id: Optional[int] = None
|
fast_llm_id: int | None = None
|
||||||
|
strategic_llm_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMPreferencesRead(BaseModel):
|
class LLMPreferencesRead(BaseModel):
|
||||||
"""Schema for reading user LLM preferences"""
|
"""Schema for reading user LLM preferences"""
|
||||||
long_context_llm_id: Optional[int] = None
|
|
||||||
fast_llm_id: Optional[int] = None
|
long_context_llm_id: int | None = None
|
||||||
strategic_llm_id: Optional[int] = None
|
fast_llm_id: int | None = None
|
||||||
long_context_llm: Optional[LLMConfigRead] = None
|
strategic_llm_id: int | None = None
|
||||||
fast_llm: Optional[LLMConfigRead] = None
|
long_context_llm: LLMConfigRead | None = None
|
||||||
strategic_llm: Optional[LLMConfigRead] = None
|
fast_llm: LLMConfigRead | None = None
|
||||||
|
strategic_llm: LLMConfigRead | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/llm-configs/", response_model=LLMConfigRead)
|
@router.post("/llm-configs/", response_model=LLMConfigRead)
|
||||||
async def create_llm_config(
|
async def create_llm_config(
|
||||||
llm_config: LLMConfigCreate,
|
llm_config: LLMConfigCreate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Create a new LLM configuration for the authenticated user"""
|
"""Create a new LLM configuration for the authenticated user"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -43,16 +48,16 @@ async def create_llm_config(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
|
||||||
detail=f"Failed to create LLM configuration: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/llm-configs/", response_model=List[LLMConfigRead])
|
|
||||||
|
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
|
||||||
async def read_llm_configs(
|
async def read_llm_configs(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 200,
|
limit: int = 200,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Get all LLM configurations for the authenticated user"""
|
"""Get all LLM configurations for the authenticated user"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -65,15 +70,15 @@ async def read_llm_configs(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
|
||||||
detail=f"Failed to fetch LLM configurations: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
||||||
async def read_llm_config(
|
async def read_llm_config(
|
||||||
llm_config_id: int,
|
llm_config_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Get a specific LLM configuration by ID"""
|
"""Get a specific LLM configuration by ID"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -83,16 +88,16 @@ async def read_llm_config(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
|
||||||
detail=f"Failed to fetch LLM configuration: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
||||||
async def update_llm_config(
|
async def update_llm_config(
|
||||||
llm_config_id: int,
|
llm_config_id: int,
|
||||||
llm_config_update: LLMConfigUpdate,
|
llm_config_update: LLMConfigUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Update an existing LLM configuration"""
|
"""Update an existing LLM configuration"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -110,15 +115,15 @@ async def update_llm_config(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
|
||||||
detail=f"Failed to update LLM configuration: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
|
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
|
||||||
async def delete_llm_config(
|
async def delete_llm_config(
|
||||||
llm_config_id: int,
|
llm_config_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Delete an LLM configuration"""
|
"""Delete an LLM configuration"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -131,16 +136,17 @@ async def delete_llm_config(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
|
||||||
detail=f"Failed to delete LLM configuration: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
# User LLM Preferences endpoints
|
# User LLM Preferences endpoints
|
||||||
|
|
||||||
|
|
||||||
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||||
async def get_user_llm_preferences(
|
async def get_user_llm_preferences(
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Get the current user's LLM preferences"""
|
"""Get the current user's LLM preferences"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -161,7 +167,7 @@ async def get_user_llm_preferences(
|
||||||
long_context_llm = await session.execute(
|
long_context_llm = await session.execute(
|
||||||
select(LLMConfig).filter(
|
select(LLMConfig).filter(
|
||||||
LLMConfig.id == user.long_context_llm_id,
|
LLMConfig.id == user.long_context_llm_id,
|
||||||
LLMConfig.user_id == user.id
|
LLMConfig.user_id == user.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
llm_config = long_context_llm.scalars().first()
|
llm_config = long_context_llm.scalars().first()
|
||||||
|
|
@ -171,8 +177,7 @@ async def get_user_llm_preferences(
|
||||||
if user.fast_llm_id:
|
if user.fast_llm_id:
|
||||||
fast_llm = await session.execute(
|
fast_llm = await session.execute(
|
||||||
select(LLMConfig).filter(
|
select(LLMConfig).filter(
|
||||||
LLMConfig.id == user.fast_llm_id,
|
LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
|
||||||
LLMConfig.user_id == user.id
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
llm_config = fast_llm.scalars().first()
|
llm_config = fast_llm.scalars().first()
|
||||||
|
|
@ -182,8 +187,7 @@ async def get_user_llm_preferences(
|
||||||
if user.strategic_llm_id:
|
if user.strategic_llm_id:
|
||||||
strategic_llm = await session.execute(
|
strategic_llm = await session.execute(
|
||||||
select(LLMConfig).filter(
|
select(LLMConfig).filter(
|
||||||
LLMConfig.id == user.strategic_llm_id,
|
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
|
||||||
LLMConfig.user_id == user.id
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
llm_config = strategic_llm.scalars().first()
|
llm_config = strategic_llm.scalars().first()
|
||||||
|
|
@ -193,35 +197,34 @@ async def get_user_llm_preferences(
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
|
||||||
detail=f"Failed to fetch LLM preferences: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||||
async def update_user_llm_preferences(
|
async def update_user_llm_preferences(
|
||||||
preferences: LLMPreferencesUpdate,
|
preferences: LLMPreferencesUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Update the current user's LLM preferences"""
|
"""Update the current user's LLM preferences"""
|
||||||
try:
|
try:
|
||||||
# Validate that all provided LLM config IDs belong to the user
|
# Validate that all provided LLM config IDs belong to the user
|
||||||
update_data = preferences.model_dump(exclude_unset=True)
|
update_data = preferences.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
for key, llm_config_id in update_data.items():
|
for _key, llm_config_id in update_data.items():
|
||||||
if llm_config_id is not None:
|
if llm_config_id is not None:
|
||||||
# Verify ownership of the LLM config
|
# Verify ownership of the LLM config
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(LLMConfig).filter(
|
select(LLMConfig).filter(
|
||||||
LLMConfig.id == llm_config_id,
|
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
|
||||||
LLMConfig.user_id == user.id
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
llm_config = result.scalars().first()
|
llm_config = result.scalars().first()
|
||||||
if not llm_config:
|
if not llm_config:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it"
|
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update user preferences
|
# Update user preferences
|
||||||
|
|
@ -238,6 +241,5 @@ async def update_user_llm_preferences(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
|
||||||
detail=f"Failed to update LLM preferences: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,23 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from sqlalchemy import and_, desc
|
|
||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from app.db import get_async_session, User, SearchSpace, Log, LogLevel, LogStatus
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from app.schemas import LogCreate, LogUpdate, LogRead, LogFilter
|
from sqlalchemy import and_, desc
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session
|
||||||
|
from app.schemas import LogCreate, LogRead, LogUpdate
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.check_ownership import check_ownership
|
from app.utils.check_ownership import check_ownership
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logs/", response_model=LogRead)
|
@router.post("/logs/", response_model=LogRead)
|
||||||
async def create_log(
|
async def create_log(
|
||||||
log: LogCreate,
|
log: LogCreate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Create a new log entry."""
|
"""Create a new log entry."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -33,22 +34,22 @@ async def create_log(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to create log: {e!s}"
|
||||||
detail=f"Failed to create log: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/logs/", response_model=List[LogRead])
|
|
||||||
|
@router.get("/logs/", response_model=list[LogRead])
|
||||||
async def read_logs(
|
async def read_logs(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
search_space_id: Optional[int] = None,
|
search_space_id: int | None = None,
|
||||||
level: Optional[LogLevel] = None,
|
level: LogLevel | None = None,
|
||||||
status: Optional[LogStatus] = None,
|
status: LogStatus | None = None,
|
||||||
source: Optional[str] = None,
|
source: str | None = None,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: datetime | None = None,
|
||||||
end_date: Optional[datetime] = None,
|
end_date: datetime | None = None,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Get logs with optional filtering."""
|
"""Get logs with optional filtering."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -93,15 +94,15 @@ async def read_logs(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch logs: {e!s}"
|
||||||
detail=f"Failed to fetch logs: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/logs/{log_id}", response_model=LogRead)
|
@router.get("/logs/{log_id}", response_model=LogRead)
|
||||||
async def read_log(
|
async def read_log(
|
||||||
log_id: int,
|
log_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Get a specific log by ID."""
|
"""Get a specific log by ID."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -121,16 +122,16 @@ async def read_log(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch log: {e!s}"
|
||||||
detail=f"Failed to fetch log: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/logs/{log_id}", response_model=LogRead)
|
@router.put("/logs/{log_id}", response_model=LogRead)
|
||||||
async def update_log(
|
async def update_log(
|
||||||
log_id: int,
|
log_id: int,
|
||||||
log_update: LogUpdate,
|
log_update: LogUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Update a log entry."""
|
"""Update a log entry."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -158,15 +159,15 @@ async def update_log(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to update log: {e!s}"
|
||||||
detail=f"Failed to update log: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.delete("/logs/{log_id}")
|
@router.delete("/logs/{log_id}")
|
||||||
async def delete_log(
|
async def delete_log(
|
||||||
log_id: int,
|
log_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Delete a log entry."""
|
"""Delete a log entry."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -189,16 +190,16 @@ async def delete_log(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to delete log: {e!s}"
|
||||||
detail=f"Failed to delete log: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/logs/search-space/{search_space_id}/summary")
|
@router.get("/logs/search-space/{search_space_id}/summary")
|
||||||
async def get_logs_summary(
|
async def get_logs_summary(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
hours: int = 24,
|
hours: int = 24,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Get a summary of logs for a search space in the last X hours."""
|
"""Get a summary of logs for a search space in the last X hours."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -212,10 +213,7 @@ async def get_logs_summary(
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Log)
|
select(Log)
|
||||||
.filter(
|
.filter(
|
||||||
and_(
|
and_(Log.search_space_id == search_space_id, Log.created_at >= since)
|
||||||
Log.search_space_id == search_space_id,
|
|
||||||
Log.created_at >= since
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
.order_by(desc(Log.created_at))
|
.order_by(desc(Log.created_at))
|
||||||
)
|
)
|
||||||
|
|
@ -229,14 +227,16 @@ async def get_logs_summary(
|
||||||
"by_level": {},
|
"by_level": {},
|
||||||
"by_source": {},
|
"by_source": {},
|
||||||
"active_tasks": [],
|
"active_tasks": [],
|
||||||
"recent_failures": []
|
"recent_failures": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Count by status and level
|
# Count by status and level
|
||||||
for log in logs:
|
for log in logs:
|
||||||
# Status counts
|
# Status counts
|
||||||
status_str = log.status.value
|
status_str = log.status.value
|
||||||
summary["by_status"][status_str] = summary["by_status"].get(status_str, 0) + 1
|
summary["by_status"][status_str] = (
|
||||||
|
summary["by_status"].get(status_str, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
# Level counts
|
# Level counts
|
||||||
level_str = log.level.value
|
level_str = log.level.value
|
||||||
|
|
@ -244,30 +244,46 @@ async def get_logs_summary(
|
||||||
|
|
||||||
# Source counts
|
# Source counts
|
||||||
if log.source:
|
if log.source:
|
||||||
summary["by_source"][log.source] = summary["by_source"].get(log.source, 0) + 1
|
summary["by_source"][log.source] = (
|
||||||
|
summary["by_source"].get(log.source, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
# Active tasks (IN_PROGRESS)
|
# Active tasks (IN_PROGRESS)
|
||||||
if log.status == LogStatus.IN_PROGRESS:
|
if log.status == LogStatus.IN_PROGRESS:
|
||||||
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
|
task_name = (
|
||||||
summary["active_tasks"].append({
|
log.log_metadata.get("task_name", "Unknown")
|
||||||
|
if log.log_metadata
|
||||||
|
else "Unknown"
|
||||||
|
)
|
||||||
|
summary["active_tasks"].append(
|
||||||
|
{
|
||||||
"id": log.id,
|
"id": log.id,
|
||||||
"task_name": task_name,
|
"task_name": task_name,
|
||||||
"message": log.message,
|
"message": log.message,
|
||||||
"started_at": log.created_at,
|
"started_at": log.created_at,
|
||||||
"source": log.source
|
"source": log.source,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Recent failures
|
# Recent failures
|
||||||
if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10:
|
if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10:
|
||||||
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
|
task_name = (
|
||||||
summary["recent_failures"].append({
|
log.log_metadata.get("task_name", "Unknown")
|
||||||
|
if log.log_metadata
|
||||||
|
else "Unknown"
|
||||||
|
)
|
||||||
|
summary["recent_failures"].append(
|
||||||
|
{
|
||||||
"id": log.id,
|
"id": log.id,
|
||||||
"task_name": task_name,
|
"task_name": task_name,
|
||||||
"message": log.message,
|
"message": log.message,
|
||||||
"failed_at": log.created_at,
|
"failed_at": log.created_at,
|
||||||
"source": log.source,
|
"source": log.source,
|
||||||
"error_details": log.log_metadata.get("error_details") if log.log_metadata else None
|
"error_details": log.log_metadata.get("error_details")
|
||||||
})
|
if log.log_metadata
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
@ -275,6 +291,5 @@ async def get_logs_summary(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to generate logs summary: {e!s}"
|
||||||
detail=f"Failed to generate logs summary: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,31 @@
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
|
||||||
from typing import List
|
|
||||||
from app.db import get_async_session, User, SearchSpace, Podcast, Chat
|
|
||||||
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
|
|
||||||
from app.users import current_active_user
|
|
||||||
from app.utils.check_ownership import check_ownership
|
|
||||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import Chat, Podcast, SearchSpace, User, get_async_session
|
||||||
|
from app.schemas import (
|
||||||
|
PodcastCreate,
|
||||||
|
PodcastGenerateRequest,
|
||||||
|
PodcastRead,
|
||||||
|
PodcastUpdate,
|
||||||
|
)
|
||||||
|
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||||
|
from app.users import current_active_user
|
||||||
|
from app.utils.check_ownership import check_ownership
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/podcasts/", response_model=PodcastRead)
|
@router.post("/podcasts/", response_model=PodcastRead)
|
||||||
async def create_podcast(
|
async def create_podcast(
|
||||||
podcast: PodcastCreate,
|
podcast: PodcastCreate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
|
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
|
||||||
|
|
@ -29,22 +36,30 @@ async def create_podcast(
|
||||||
return db_podcast
|
return db_podcast
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
raise he
|
raise he
|
||||||
except IntegrityError as e:
|
except IntegrityError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
|
raise HTTPException(
|
||||||
except SQLAlchemyError as e:
|
status_code=400,
|
||||||
|
detail="Podcast creation failed due to constraint violation",
|
||||||
|
) from None
|
||||||
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
|
raise HTTPException(
|
||||||
except Exception as e:
|
status_code=500, detail="Database error occurred while creating podcast"
|
||||||
|
) from None
|
||||||
|
except Exception:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="An unexpected error occurred"
|
||||||
|
) from None
|
||||||
|
|
||||||
@router.get("/podcasts/", response_model=List[PodcastRead])
|
|
||||||
|
@router.get("/podcasts/", response_model=list[PodcastRead])
|
||||||
async def read_podcasts(
|
async def read_podcasts(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
if skip < 0 or limit < 1:
|
if skip < 0 or limit < 1:
|
||||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||||
|
|
@ -58,13 +73,16 @@ async def read_podcasts(
|
||||||
)
|
)
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Database error occurred while fetching podcasts"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||||
async def read_podcast(
|
async def read_podcast(
|
||||||
podcast_id: int,
|
podcast_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -76,20 +94,23 @@ async def read_podcast(
|
||||||
if not podcast:
|
if not podcast:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail="Podcast not found or you don't have permission to access it"
|
detail="Podcast not found or you don't have permission to access it",
|
||||||
)
|
)
|
||||||
return podcast
|
return podcast
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
raise he
|
raise he
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Database error occurred while fetching podcast"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||||
async def update_podcast(
|
async def update_podcast(
|
||||||
podcast_id: int,
|
podcast_id: int,
|
||||||
podcast_update: PodcastUpdate,
|
podcast_update: PodcastUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_podcast = await read_podcast(podcast_id, session, user)
|
db_podcast = await read_podcast(podcast_id, session, user)
|
||||||
|
|
@ -103,16 +124,21 @@ async def update_podcast(
|
||||||
raise he
|
raise he
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Update failed due to constraint violation"
|
||||||
|
) from None
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Database error occurred while updating podcast"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||||
async def delete_podcast(
|
async def delete_podcast(
|
||||||
podcast_id: int,
|
podcast_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_podcast = await read_podcast(podcast_id, session, user)
|
db_podcast = await read_podcast(podcast_id, session, user)
|
||||||
|
|
@ -123,30 +149,34 @@ async def delete_podcast(
|
||||||
raise he
|
raise he
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Database error occurred while deleting podcast"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
async def generate_chat_podcast_with_new_session(
|
async def generate_chat_podcast_with_new_session(
|
||||||
chat_id: int,
|
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
|
||||||
search_space_id: int,
|
|
||||||
podcast_title: str,
|
|
||||||
user_id: int
|
|
||||||
):
|
):
|
||||||
"""Create a new session and process chat podcast generation."""
|
"""Create a new session and process chat podcast generation."""
|
||||||
from app.db import async_session_maker
|
from app.db import async_session_maker
|
||||||
|
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
try:
|
try:
|
||||||
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id)
|
await generate_chat_podcast(
|
||||||
|
session, chat_id, search_space_id, podcast_title, user_id
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
logging.error(f"Error generating podcast from chat: {str(e)}")
|
|
||||||
|
logging.error(f"Error generating podcast from chat: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/podcasts/generate/")
|
@router.post("/podcasts/generate/")
|
||||||
async def generate_podcast(
|
async def generate_podcast(
|
||||||
request: PodcastGenerateRequest,
|
request: PodcastGenerateRequest,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Check if the user owns the search space
|
# Check if the user owns the search space
|
||||||
|
|
@ -154,10 +184,15 @@ async def generate_podcast(
|
||||||
|
|
||||||
if request.type == "CHAT":
|
if request.type == "CHAT":
|
||||||
# Verify that all chat IDs belong to this user and search space
|
# Verify that all chat IDs belong to this user and search space
|
||||||
query = select(Chat).filter(
|
query = (
|
||||||
|
select(Chat)
|
||||||
|
.filter(
|
||||||
Chat.id.in_(request.ids),
|
Chat.id.in_(request.ids),
|
||||||
Chat.search_space_id == request.search_space_id
|
Chat.search_space_id == request.search_space_id,
|
||||||
).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
)
|
||||||
|
.join(SearchSpace)
|
||||||
|
.filter(SearchSpace.user_id == user.id)
|
||||||
|
)
|
||||||
|
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
valid_chats = result.scalars().all()
|
valid_chats = result.scalars().all()
|
||||||
|
|
@ -167,7 +202,7 @@ async def generate_podcast(
|
||||||
if len(valid_chat_ids) != len(request.ids):
|
if len(valid_chat_ids) != len(request.ids):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail="One or more chat IDs do not belong to this user or search space"
|
detail="One or more chat IDs do not belong to this user or search space",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only add a single task with the first chat ID
|
# Only add a single task with the first chat ID
|
||||||
|
|
@ -177,7 +212,7 @@ async def generate_podcast(
|
||||||
chat_id,
|
chat_id,
|
||||||
request.search_space_id,
|
request.search_space_id,
|
||||||
request.podcast_title,
|
request.podcast_title,
|
||||||
user.id
|
user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -185,21 +220,29 @@ async def generate_podcast(
|
||||||
}
|
}
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
raise he
|
raise he
|
||||||
except IntegrityError as e:
|
except IntegrityError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation")
|
raise HTTPException(
|
||||||
except SQLAlchemyError as e:
|
status_code=400,
|
||||||
|
detail="Podcast generation failed due to constraint violation",
|
||||||
|
) from None
|
||||||
|
except SQLAlchemyError:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail="Database error occurred while generating podcast")
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Database error occurred while generating podcast"
|
||||||
|
) from None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"An unexpected error occurred: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/podcasts/{podcast_id}/stream")
|
@router.get("/podcasts/{podcast_id}/stream")
|
||||||
async def stream_podcast(
|
async def stream_podcast(
|
||||||
podcast_id: int,
|
podcast_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Stream a podcast audio file."""
|
"""Stream a podcast audio file."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -214,7 +257,7 @@ async def stream_podcast(
|
||||||
if not podcast:
|
if not podcast:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail="Podcast not found or you don't have permission to access it"
|
detail="Podcast not found or you don't have permission to access it",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the file path
|
# Get the file path
|
||||||
|
|
@ -235,11 +278,13 @@ async def stream_podcast(
|
||||||
media_type="audio/mpeg",
|
media_type="audio/mpeg",
|
||||||
headers={
|
headers={
|
||||||
"Accept-Ranges": "bytes",
|
"Accept-Ranges": "bytes",
|
||||||
"Content-Disposition": f"inline; filename={Path(file_path).name}"
|
"Content-Disposition": f"inline; filename={Path(file_path).name}",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
raise he
|
raise he
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Error streaming podcast: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,13 @@ Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.connectors.github_connector import GitHubConnector
|
from app.connectors.github_connector import GitHubConnector
|
||||||
from app.db import (
|
from app.db import (
|
||||||
|
|
@ -39,11 +45,6 @@ from app.tasks.connectors_indexing_tasks import (
|
||||||
)
|
)
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.check_ownership import check_ownership
|
from app.utils.check_ownership import check_ownership
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -57,7 +58,7 @@ class GitHubPATRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# --- New Endpoint to list GitHub Repositories ---
|
# --- New Endpoint to list GitHub Repositories ---
|
||||||
@router.post("/github/repositories/", response_model=List[Dict[str, Any]])
|
@router.post("/github/repositories/", response_model=list[dict[str, Any]])
|
||||||
async def list_github_repositories(
|
async def list_github_repositories(
|
||||||
pat_request: GitHubPATRequest,
|
pat_request: GitHubPATRequest,
|
||||||
user: User = Depends(current_active_user), # Ensure the user is logged in
|
user: User = Depends(current_active_user), # Ensure the user is logged in
|
||||||
|
|
@ -74,15 +75,13 @@ async def list_github_repositories(
|
||||||
return repositories
|
return repositories
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# Handle invalid token error specifically
|
# Handle invalid token error specifically
|
||||||
logger.error(f"GitHub PAT validation failed for user {user.id}: {str(e)}")
|
logger.error(f"GitHub PAT validation failed for user {user.id}: {e!s}")
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {e!s}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {e!s}")
|
||||||
f"Failed to fetch GitHub repositories for user {user.id}: {str(e)}"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail="Failed to fetch GitHub repositories."
|
status_code=500, detail="Failed to fetch GitHub repositories."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
|
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
|
||||||
|
|
@ -118,32 +117,32 @@ async def create_search_source_connector(
|
||||||
return db_connector
|
return db_connector
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(status_code=422, detail=f"Validation error: {str(e)}")
|
raise HTTPException(status_code=422, detail=f"Validation error: {e!s}") from e
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
detail=f"Integrity error: A connector with this type already exists. {str(e)}",
|
detail=f"Integrity error: A connector with this type already exists. {e!s}",
|
||||||
)
|
) from e
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create search source connector: {str(e)}")
|
logger.error(f"Failed to create search source connector: {e!s}")
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"Failed to create search source connector: {str(e)}",
|
detail=f"Failed to create search source connector: {e!s}",
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/search-source-connectors/", response_model=List[SearchSourceConnectorRead]
|
"/search-source-connectors/", response_model=list[SearchSourceConnectorRead]
|
||||||
)
|
)
|
||||||
async def read_search_source_connectors(
|
async def read_search_source_connectors(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
search_space_id: int = None,
|
search_space_id: int | None = None,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
|
|
@ -160,8 +159,8 @@ async def read_search_source_connectors(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"Failed to fetch search source connectors: {str(e)}",
|
detail=f"Failed to fetch search source connectors: {e!s}",
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|
@ -179,8 +178,8 @@ async def read_search_source_connector(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"Failed to fetch search source connector: {str(e)}"
|
status_code=500, detail=f"Failed to fetch search source connector: {e!s}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
|
|
@ -238,8 +237,8 @@ async def update_search_source_connector(
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
# Raise specific validation error for the merged config
|
# Raise specific validation error for the merged config
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=422, detail=f"Validation error for merged config: {str(e)}"
|
status_code=422, detail=f"Validation error for merged config: {e!s}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
# If validation passes, update the main update_data dict with the merged config
|
# If validation passes, update the main update_data dict with the merged config
|
||||||
update_data["config"] = merged_config
|
update_data["config"] = merged_config
|
||||||
|
|
@ -272,8 +271,8 @@ async def update_search_source_connector(
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
# This might occur if connector_type constraint is violated somehow after the check
|
# This might occur if connector_type constraint is violated somehow after the check
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409, detail=f"Database integrity error during update: {str(e)}"
|
status_code=409, detail=f"Database integrity error during update: {e!s}"
|
||||||
)
|
) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -282,8 +281,8 @@ async def update_search_source_connector(
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"Failed to update search source connector: {str(e)}",
|
detail=f"Failed to update search source connector: {e!s}",
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
|
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
|
||||||
|
|
@ -306,12 +305,12 @@ async def delete_search_source_connector(
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"Failed to delete search source connector: {str(e)}",
|
detail=f"Failed to delete search source connector: {e!s}",
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any]
|
"/search-source-connectors/{connector_id}/index", response_model=dict[str, Any]
|
||||||
)
|
)
|
||||||
async def index_connector_content(
|
async def index_connector_content(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
|
|
@ -356,7 +355,7 @@ async def index_connector_content(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the search space belongs to the user
|
# Check if the search space belongs to the user
|
||||||
search_space = await check_ownership(
|
_search_space = await check_ownership(
|
||||||
session, SearchSpace, search_space_id, user
|
session, SearchSpace, search_space_id, user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -381,10 +380,7 @@ async def index_connector_content(
|
||||||
else:
|
else:
|
||||||
indexing_from = start_date
|
indexing_from = start_date
|
||||||
|
|
||||||
if end_date is None:
|
indexing_to = end_date if end_date else today_str
|
||||||
indexing_to = today_str
|
|
||||||
else:
|
|
||||||
indexing_to = end_date
|
|
||||||
|
|
||||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||||
# Run indexing in background
|
# Run indexing in background
|
||||||
|
|
@ -497,8 +493,8 @@ async def index_connector_content(
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"Failed to initiate indexing: {str(e)}"
|
status_code=500, detail=f"Failed to initiate indexing: {e!s}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
|
|
||||||
async def update_connector_last_indexed(session: AsyncSession, connector_id: int):
|
async def update_connector_last_indexed(session: AsyncSession, connector_id: int):
|
||||||
|
|
@ -523,7 +519,7 @@ async def update_connector_last_indexed(session: AsyncSession, connector_id: int
|
||||||
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}"
|
f"Failed to update last_indexed_at for connector {connector_id}: {e!s}"
|
||||||
)
|
)
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
|
||||||
|
|
@ -587,7 +583,7 @@ async def run_slack_indexing(
|
||||||
f"Slack indexing failed or no documents processed: {error_or_warning}"
|
f"Slack indexing failed or no documents processed: {error_or_warning}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in background Slack indexing task: {str(e)}")
|
logger.error(f"Error in background Slack indexing task: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
async def run_notion_indexing_with_new_session(
|
async def run_notion_indexing_with_new_session(
|
||||||
|
|
@ -649,7 +645,7 @@ async def run_notion_indexing(
|
||||||
f"Notion indexing failed or no documents processed: {error_or_warning}"
|
f"Notion indexing failed or no documents processed: {error_or_warning}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in background Notion indexing task: {str(e)}")
|
logger.error(f"Error in background Notion indexing task: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for GitHub indexing
|
# Add new helper functions for GitHub indexing
|
||||||
|
|
@ -829,7 +825,7 @@ async def run_discord_indexing(
|
||||||
f"Discord indexing failed or no documents processed: {error_or_warning}"
|
f"Discord indexing failed or no documents processed: {error_or_warning}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in background Discord indexing task: {str(e)}")
|
logger.error(f"Error in background Discord indexing task: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
# Add new helper functions for Jira indexing
|
# Add new helper functions for Jira indexing
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from typing import List
|
|
||||||
from app.db import get_async_session, User, SearchSpace
|
from app.db import SearchSpace, User, get_async_session
|
||||||
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.check_ownership import check_ownership
|
from app.utils.check_ownership import check_ownership
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/searchspaces/", response_model=SearchSpaceRead)
|
@router.post("/searchspaces/", response_model=SearchSpaceRead)
|
||||||
async def create_search_space(
|
async def create_search_space(
|
||||||
search_space: SearchSpaceCreate,
|
search_space: SearchSpaceCreate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
|
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
|
||||||
|
|
@ -27,16 +27,16 @@ async def create_search_space(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to create search space: {e!s}"
|
||||||
detail=f"Failed to create search space: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
|
|
||||||
|
@router.get("/searchspaces/", response_model=list[SearchSpaceRead])
|
||||||
async def read_search_spaces(
|
async def read_search_spaces(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 200,
|
limit: int = 200,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -48,37 +48,41 @@ async def read_search_spaces(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch search spaces: {e!s}"
|
||||||
detail=f"Failed to fetch search spaces: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||||
async def read_search_space(
|
async def read_search_space(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
search_space = await check_ownership(
|
||||||
|
session, SearchSpace, search_space_id, user
|
||||||
|
)
|
||||||
return search_space
|
return search_space
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to fetch search space: {e!s}"
|
||||||
detail=f"Failed to fetch search space: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||||
async def update_search_space(
|
async def update_search_space(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
search_space_update: SearchSpaceUpdate,
|
search_space_update: SearchSpaceUpdate,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
db_search_space = await check_ownership(
|
||||||
|
session, SearchSpace, search_space_id, user
|
||||||
|
)
|
||||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(db_search_space, key, value)
|
setattr(db_search_space, key, value)
|
||||||
|
|
@ -90,18 +94,20 @@ async def update_search_space(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to update search space: {e!s}"
|
||||||
detail=f"Failed to update search space: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
||||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||||
async def delete_search_space(
|
async def delete_search_space(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
user: User = Depends(current_active_user)
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
db_search_space = await check_ownership(
|
||||||
|
session, SearchSpace, search_space_id, user
|
||||||
|
)
|
||||||
await session.delete(db_search_space)
|
await session.delete(db_search_space)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return {"message": "Search space deleted successfully"}
|
return {"message": "Search space deleted successfully"}
|
||||||
|
|
@ -110,6 +116,5 @@ async def delete_search_space(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=f"Failed to delete search space: {e!s}"
|
||||||
detail=f"Failed to delete search space: {str(e)}"
|
) from e
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,62 +1,78 @@
|
||||||
from .base import TimestampModel, IDModel
|
from .base import IDModel, TimestampModel
|
||||||
from .users import UserRead, UserCreate, UserUpdate
|
from .chats import AISDKChatRequest, ChatBase, ChatCreate, ChatRead, ChatUpdate
|
||||||
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||||
from .documents import (
|
from .documents import (
|
||||||
ExtensionDocumentMetadata,
|
|
||||||
ExtensionDocumentContent,
|
|
||||||
DocumentBase,
|
DocumentBase,
|
||||||
|
DocumentRead,
|
||||||
DocumentsCreate,
|
DocumentsCreate,
|
||||||
DocumentUpdate,
|
DocumentUpdate,
|
||||||
DocumentRead,
|
ExtensionDocumentContent,
|
||||||
|
ExtensionDocumentMetadata,
|
||||||
)
|
)
|
||||||
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
|
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||||
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
|
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||||
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
from .podcasts import (
|
||||||
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
PodcastBase,
|
||||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
|
PodcastCreate,
|
||||||
from .logs import LogBase, LogCreate, LogUpdate, LogRead, LogFilter
|
PodcastGenerateRequest,
|
||||||
|
PodcastRead,
|
||||||
|
PodcastUpdate,
|
||||||
|
)
|
||||||
|
from .search_source_connector import (
|
||||||
|
SearchSourceConnectorBase,
|
||||||
|
SearchSourceConnectorCreate,
|
||||||
|
SearchSourceConnectorRead,
|
||||||
|
SearchSourceConnectorUpdate,
|
||||||
|
)
|
||||||
|
from .search_space import (
|
||||||
|
SearchSpaceBase,
|
||||||
|
SearchSpaceCreate,
|
||||||
|
SearchSpaceRead,
|
||||||
|
SearchSpaceUpdate,
|
||||||
|
)
|
||||||
|
from .users import UserCreate, UserRead, UserUpdate
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AISDKChatRequest",
|
"AISDKChatRequest",
|
||||||
"TimestampModel",
|
|
||||||
"IDModel",
|
|
||||||
"UserRead",
|
|
||||||
"UserCreate",
|
|
||||||
"UserUpdate",
|
|
||||||
"SearchSpaceBase",
|
|
||||||
"SearchSpaceCreate",
|
|
||||||
"SearchSpaceUpdate",
|
|
||||||
"SearchSpaceRead",
|
|
||||||
"ExtensionDocumentMetadata",
|
|
||||||
"ExtensionDocumentContent",
|
|
||||||
"DocumentBase",
|
|
||||||
"DocumentsCreate",
|
|
||||||
"DocumentUpdate",
|
|
||||||
"DocumentRead",
|
|
||||||
"ChunkBase",
|
|
||||||
"ChunkCreate",
|
|
||||||
"ChunkUpdate",
|
|
||||||
"ChunkRead",
|
|
||||||
"PodcastBase",
|
|
||||||
"PodcastCreate",
|
|
||||||
"PodcastUpdate",
|
|
||||||
"PodcastRead",
|
|
||||||
"PodcastGenerateRequest",
|
|
||||||
"ChatBase",
|
"ChatBase",
|
||||||
"ChatCreate",
|
"ChatCreate",
|
||||||
"ChatUpdate",
|
|
||||||
"ChatRead",
|
"ChatRead",
|
||||||
"SearchSourceConnectorBase",
|
"ChatUpdate",
|
||||||
"SearchSourceConnectorCreate",
|
"ChunkBase",
|
||||||
"SearchSourceConnectorUpdate",
|
"ChunkCreate",
|
||||||
"SearchSourceConnectorRead",
|
"ChunkRead",
|
||||||
|
"ChunkUpdate",
|
||||||
|
"DocumentBase",
|
||||||
|
"DocumentRead",
|
||||||
|
"DocumentUpdate",
|
||||||
|
"DocumentsCreate",
|
||||||
|
"ExtensionDocumentContent",
|
||||||
|
"ExtensionDocumentMetadata",
|
||||||
|
"IDModel",
|
||||||
"LLMConfigBase",
|
"LLMConfigBase",
|
||||||
"LLMConfigCreate",
|
"LLMConfigCreate",
|
||||||
"LLMConfigUpdate",
|
|
||||||
"LLMConfigRead",
|
"LLMConfigRead",
|
||||||
|
"LLMConfigUpdate",
|
||||||
"LogBase",
|
"LogBase",
|
||||||
"LogCreate",
|
"LogCreate",
|
||||||
"LogUpdate",
|
|
||||||
"LogRead",
|
|
||||||
"LogFilter",
|
"LogFilter",
|
||||||
|
"LogRead",
|
||||||
|
"LogUpdate",
|
||||||
|
"PodcastBase",
|
||||||
|
"PodcastCreate",
|
||||||
|
"PodcastGenerateRequest",
|
||||||
|
"PodcastRead",
|
||||||
|
"PodcastUpdate",
|
||||||
|
"SearchSourceConnectorBase",
|
||||||
|
"SearchSourceConnectorCreate",
|
||||||
|
"SearchSourceConnectorRead",
|
||||||
|
"SearchSourceConnectorUpdate",
|
||||||
|
"SearchSpaceBase",
|
||||||
|
"SearchSpaceCreate",
|
||||||
|
"SearchSpaceRead",
|
||||||
|
"SearchSpaceUpdate",
|
||||||
|
"TimestampModel",
|
||||||
|
"UserCreate",
|
||||||
|
"UserRead",
|
||||||
|
"UserUpdate",
|
||||||
]
|
]
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
class TimestampModel(BaseModel):
|
class TimestampModel(BaseModel):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class IDModel(BaseModel):
|
class IDModel(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from app.db import ChatType
|
from app.db import ChatType
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
@ -9,20 +10,20 @@ from .base import IDModel, TimestampModel
|
||||||
class ChatBase(BaseModel):
|
class ChatBase(BaseModel):
|
||||||
type: ChatType
|
type: ChatType
|
||||||
title: str
|
title: str
|
||||||
initial_connectors: Optional[List[str]] = None
|
initial_connectors: list[str] | None = None
|
||||||
messages: List[Any]
|
messages: list[Any]
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
class ClientAttachment(BaseModel):
|
class ClientAttachment(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
contentType: str
|
content_type: str
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
class ToolInvocation(BaseModel):
|
class ToolInvocation(BaseModel):
|
||||||
toolCallId: str
|
tool_call_id: str
|
||||||
toolName: str
|
tool_name: str
|
||||||
args: dict
|
args: dict
|
||||||
result: dict
|
result: dict
|
||||||
|
|
||||||
|
|
@ -33,15 +34,19 @@ class ToolInvocation(BaseModel):
|
||||||
# experimental_attachments: Optional[List[ClientAttachment]] = None
|
# experimental_attachments: Optional[List[ClientAttachment]] = None
|
||||||
# toolInvocations: Optional[List[ToolInvocation]] = None
|
# toolInvocations: Optional[List[ToolInvocation]] = None
|
||||||
|
|
||||||
|
|
||||||
class AISDKChatRequest(BaseModel):
|
class AISDKChatRequest(BaseModel):
|
||||||
messages: List[Any]
|
messages: list[Any]
|
||||||
data: Optional[Dict[str, Any]] = None
|
data: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ChatCreate(ChatBase):
|
class ChatCreate(ChatBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ChatUpdate(ChatBase):
|
class ChatUpdate(ChatBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ChatRead(ChatBase, IDModel, TimestampModel):
|
class ChatRead(ChatBase, IDModel, TimestampModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
@ -1,15 +1,20 @@
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
||||||
class ChunkBase(BaseModel):
|
class ChunkBase(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
document_id: int
|
document_id: int
|
||||||
|
|
||||||
|
|
||||||
class ChunkCreate(ChunkBase):
|
class ChunkCreate(ChunkBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ChunkUpdate(ChunkBase):
|
class ChunkUpdate(ChunkBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ChunkRead(ChunkBase, IDModel, TimestampModel):
|
class ChunkRead(ChunkBase, IDModel, TimestampModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from typing import List
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
from app.db import DocumentType
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from app.db import DocumentType
|
||||||
|
|
||||||
|
|
||||||
class ExtensionDocumentMetadata(BaseModel):
|
class ExtensionDocumentMetadata(BaseModel):
|
||||||
BrowsingSessionId: str
|
BrowsingSessionId: str
|
||||||
VisitedWebPageURL: str
|
VisitedWebPageURL: str
|
||||||
|
|
@ -11,21 +13,28 @@ class ExtensionDocumentMetadata(BaseModel):
|
||||||
VisitedWebPageReffererURL: str
|
VisitedWebPageReffererURL: str
|
||||||
VisitedWebPageVisitDurationInMilliseconds: str
|
VisitedWebPageVisitDurationInMilliseconds: str
|
||||||
|
|
||||||
|
|
||||||
class ExtensionDocumentContent(BaseModel):
|
class ExtensionDocumentContent(BaseModel):
|
||||||
metadata: ExtensionDocumentMetadata
|
metadata: ExtensionDocumentMetadata
|
||||||
pageContent: str
|
pageContent: str # noqa: N815
|
||||||
|
|
||||||
|
|
||||||
class DocumentBase(BaseModel):
|
class DocumentBase(BaseModel):
|
||||||
document_type: DocumentType
|
document_type: DocumentType
|
||||||
content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content
|
content: (
|
||||||
|
list[ExtensionDocumentContent] | list[str] | str
|
||||||
|
) # Updated to allow string content
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
class DocumentsCreate(DocumentBase):
|
class DocumentsCreate(DocumentBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DocumentUpdate(DocumentBase):
|
class DocumentUpdate(DocumentBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DocumentRead(BaseModel):
|
class DocumentRead(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
title: str
|
title: str
|
||||||
|
|
@ -36,4 +45,3 @@ class DocumentRead(BaseModel):
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,57 @@
|
||||||
from datetime import datetime
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Dict, Any
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from .base import IDModel, TimestampModel
|
|
||||||
from app.db import LiteLLMProvider
|
from app.db import LiteLLMProvider
|
||||||
|
|
||||||
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigBase(BaseModel):
|
class LLMConfigBase(BaseModel):
|
||||||
name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration")
|
name: str = Field(
|
||||||
|
..., max_length=100, description="User-friendly name for the LLM configuration"
|
||||||
|
)
|
||||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||||
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
|
custom_provider: str | None = Field(
|
||||||
model_name: str = Field(..., max_length=100, description="Model name without provider prefix")
|
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||||
|
)
|
||||||
|
model_name: str = Field(
|
||||||
|
..., max_length=100, description="Model name without provider prefix"
|
||||||
|
)
|
||||||
api_key: str = Field(..., description="API key for the provider")
|
api_key: str = Field(..., description="API key for the provider")
|
||||||
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
|
api_base: str | None = Field(
|
||||||
litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters")
|
None, max_length=500, description="Optional API base URL"
|
||||||
|
)
|
||||||
|
litellm_params: dict[str, Any] | None = Field(
|
||||||
|
default=None, description="Additional LiteLLM parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigCreate(LLMConfigBase):
|
class LLMConfigCreate(LLMConfigBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigUpdate(BaseModel):
|
class LLMConfigUpdate(BaseModel):
|
||||||
name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration")
|
name: str | None = Field(
|
||||||
provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type")
|
None, max_length=100, description="User-friendly name for the LLM configuration"
|
||||||
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
|
)
|
||||||
model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix")
|
provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type")
|
||||||
api_key: Optional[str] = Field(None, description="API key for the provider")
|
custom_provider: str | None = Field(
|
||||||
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
|
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||||
litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters")
|
)
|
||||||
|
model_name: str | None = Field(
|
||||||
|
None, max_length=100, description="Model name without provider prefix"
|
||||||
|
)
|
||||||
|
api_key: str | None = Field(None, description="API key for the provider")
|
||||||
|
api_base: str | None = Field(
|
||||||
|
None, max_length=500, description="Optional API base URL"
|
||||||
|
)
|
||||||
|
litellm_params: dict[str, Any] | None = Field(
|
||||||
|
None, description="Additional LiteLLM parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
|
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
|
||||||
id: int
|
id: int
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,37 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from .base import IDModel, TimestampModel
|
|
||||||
from app.db import LogLevel, LogStatus
|
from app.db import LogLevel, LogStatus
|
||||||
|
|
||||||
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
||||||
class LogBase(BaseModel):
|
class LogBase(BaseModel):
|
||||||
level: LogLevel
|
level: LogLevel
|
||||||
status: LogStatus
|
status: LogStatus
|
||||||
message: str
|
message: str
|
||||||
source: Optional[str] = None
|
source: str | None = None
|
||||||
log_metadata: Optional[Dict[str, Any]] = None
|
log_metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class LogCreate(BaseModel):
|
class LogCreate(BaseModel):
|
||||||
level: LogLevel
|
level: LogLevel
|
||||||
status: LogStatus
|
status: LogStatus
|
||||||
message: str
|
message: str
|
||||||
source: Optional[str] = None
|
source: str | None = None
|
||||||
log_metadata: Optional[Dict[str, Any]] = None
|
log_metadata: dict[str, Any] | None = None
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
class LogUpdate(BaseModel):
|
class LogUpdate(BaseModel):
|
||||||
level: Optional[LogLevel] = None
|
level: LogLevel | None = None
|
||||||
status: Optional[LogStatus] = None
|
status: LogStatus | None = None
|
||||||
message: Optional[str] = None
|
message: str | None = None
|
||||||
source: Optional[str] = None
|
source: str | None = None
|
||||||
log_metadata: Optional[Dict[str, Any]] = None
|
log_metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class LogRead(LogBase, IDModel, TimestampModel):
|
class LogRead(LogBase, IDModel, TimestampModel):
|
||||||
id: int
|
id: int
|
||||||
|
|
@ -33,12 +40,13 @@ class LogRead(LogBase, IDModel, TimestampModel):
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class LogFilter(BaseModel):
|
class LogFilter(BaseModel):
|
||||||
level: Optional[LogLevel] = None
|
level: LogLevel | None = None
|
||||||
status: Optional[LogStatus] = None
|
status: LogStatus | None = None
|
||||||
source: Optional[str] = None
|
source: str | None = None
|
||||||
search_space_id: Optional[int] = None
|
search_space_id: int | None = None
|
||||||
start_date: Optional[datetime] = None
|
start_date: datetime | None = None
|
||||||
end_date: Optional[datetime] = None
|
end_date: datetime | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
@ -1,24 +1,31 @@
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing import Any, List, Literal
|
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
||||||
class PodcastBase(BaseModel):
|
class PodcastBase(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
podcast_transcript: List[Any]
|
podcast_transcript: list[Any]
|
||||||
file_location: str = ""
|
file_location: str = ""
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
||||||
|
|
||||||
class PodcastCreate(PodcastBase):
|
class PodcastCreate(PodcastBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PodcastUpdate(PodcastBase):
|
class PodcastUpdate(PodcastBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class PodcastGenerateRequest(BaseModel):
|
class PodcastGenerateRequest(BaseModel):
|
||||||
type: Literal["DOCUMENT", "CHAT"]
|
type: Literal["DOCUMENT", "CHAT"]
|
||||||
ids: List[int]
|
ids: list[int]
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
podcast_title: str = "SurfSense Podcast"
|
podcast_title: str = "SurfSense Podcast"
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
from app.db import SearchSourceConnectorType
|
from app.db import SearchSourceConnectorType
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
|
||||||
|
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
@ -12,14 +13,14 @@ class SearchSourceConnectorBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
connector_type: SearchSourceConnectorType
|
connector_type: SearchSourceConnectorType
|
||||||
is_indexable: bool
|
is_indexable: bool
|
||||||
last_indexed_at: Optional[datetime] = None
|
last_indexed_at: datetime | None = None
|
||||||
config: Dict[str, Any]
|
config: dict[str, Any]
|
||||||
|
|
||||||
@field_validator("config")
|
@field_validator("config")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config_for_connector_type(
|
def validate_config_for_connector_type(
|
||||||
cls, config: Dict[str, Any], values: Dict[str, Any]
|
cls, config: dict[str, Any], values: dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
connector_type = values.data.get("connector_type")
|
connector_type = values.data.get("connector_type")
|
||||||
|
|
||||||
if connector_type == SearchSourceConnectorType.SERPER_API:
|
if connector_type == SearchSourceConnectorType.SERPER_API:
|
||||||
|
|
@ -150,11 +151,11 @@ class SearchSourceConnectorCreate(SearchSourceConnectorBase):
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnectorUpdate(BaseModel):
|
class SearchSourceConnectorUpdate(BaseModel):
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
connector_type: Optional[SearchSourceConnectorType] = None
|
connector_type: SearchSourceConnectorType | None = None
|
||||||
is_indexable: Optional[bool] = None
|
is_indexable: bool | None = None
|
||||||
last_indexed_at: Optional[datetime] = None
|
last_indexed_at: datetime | None = None
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
|
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,24 @@
|
||||||
from datetime import datetime
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from .base import IDModel, TimestampModel
|
from .base import IDModel, TimestampModel
|
||||||
|
|
||||||
|
|
||||||
class SearchSpaceBase(BaseModel):
|
class SearchSpaceBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class SearchSpaceCreate(SearchSpaceBase):
|
class SearchSpaceCreate(SearchSpaceBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SearchSpaceUpdate(SearchSpaceBase):
|
class SearchSpaceUpdate(SearchSpaceBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||||
id: int
|
id: int
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi_users import schemas
|
from fastapi_users import schemas
|
||||||
|
|
||||||
|
|
||||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(schemas.BaseUserCreate):
|
class UserCreate(schemas.BaseUserCreate):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(schemas.BaseUserUpdate):
|
class UserUpdate(schemas.BaseUserUpdate):
|
||||||
pass
|
pass
|
||||||
|
|
@ -1,5 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
from linkup import LinkupClient
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from tavily import TavilyClient
|
||||||
|
|
||||||
from app.agents.researcher.configuration import SearchMode
|
from app.agents.researcher.configuration import SearchMode
|
||||||
from app.db import (
|
from app.db import (
|
||||||
|
|
@ -11,15 +17,10 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||||
from linkup import LinkupClient
|
|
||||||
from sqlalchemy import func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from tavily import TavilyClient
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectorService:
|
class ConnectorService:
|
||||||
def __init__(self, session: AsyncSession, user_id: str = None):
|
def __init__(self, session: AsyncSession, user_id: str | None = None):
|
||||||
self.session = session
|
self.session = session
|
||||||
self.chunk_retriever = ChucksHybridSearchRetriever(session)
|
self.chunk_retriever = ChucksHybridSearchRetriever(session)
|
||||||
self.document_retriever = DocumentHybridSearchRetriever(session)
|
self.document_retriever = DocumentHybridSearchRetriever(session)
|
||||||
|
|
@ -52,7 +53,7 @@ class ConnectorService:
|
||||||
f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}"
|
f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error initializing source_id_counter: {str(e)}")
|
print(f"Error initializing source_id_counter: {e!s}")
|
||||||
# Fallback to default value
|
# Fallback to default value
|
||||||
self.source_id_counter = 1
|
self.source_id_counter = 1
|
||||||
|
|
||||||
|
|
@ -204,7 +205,9 @@ class ConnectorService:
|
||||||
|
|
||||||
return result_object, files_chunks
|
return result_object, files_chunks
|
||||||
|
|
||||||
def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]:
|
def _transform_document_results(
|
||||||
|
self, document_results: list[dict[str, Any]]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Transform results from document_retriever.hybrid_search() to match the format
|
Transform results from document_retriever.hybrid_search() to match the format
|
||||||
expected by the processing code.
|
expected by the processing code.
|
||||||
|
|
@ -233,7 +236,7 @@ class ConnectorService:
|
||||||
|
|
||||||
async def get_connector_by_type(
|
async def get_connector_by_type(
|
||||||
self, user_id: str, connector_type: SearchSourceConnectorType
|
self, user_id: str, connector_type: SearchSourceConnectorType
|
||||||
) -> Optional[SearchSourceConnector]:
|
) -> SearchSourceConnector | None:
|
||||||
"""
|
"""
|
||||||
Get a connector by type for a specific user
|
Get a connector by type for a specific user
|
||||||
|
|
||||||
|
|
@ -350,7 +353,7 @@ class ConnectorService:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the error and return empty results
|
# Log the error and return empty results
|
||||||
print(f"Error searching with Tavily: {str(e)}")
|
print(f"Error searching with Tavily: {e!s}")
|
||||||
return {
|
return {
|
||||||
"id": 3,
|
"id": 3,
|
||||||
"name": "Tavily Search",
|
"name": "Tavily Search",
|
||||||
|
|
@ -596,7 +599,7 @@ class ConnectorService:
|
||||||
# Process each chunk and create sources directly without deduplication
|
# Process each chunk and create sources directly without deduplication
|
||||||
sources_list = []
|
sources_list = []
|
||||||
async with self.counter_lock:
|
async with self.counter_lock:
|
||||||
for i, chunk in enumerate(extension_chunks):
|
for _, chunk in enumerate(extension_chunks):
|
||||||
# Extract document metadata
|
# Extract document metadata
|
||||||
document = chunk.get("document", {})
|
document = chunk.get("document", {})
|
||||||
metadata = document.get("metadata", {})
|
metadata = document.get("metadata", {})
|
||||||
|
|
@ -608,7 +611,7 @@ class ConnectorService:
|
||||||
visit_duration = metadata.get(
|
visit_duration = metadata.get(
|
||||||
"VisitedWebPageVisitDurationInMilliseconds", ""
|
"VisitedWebPageVisitDurationInMilliseconds", ""
|
||||||
)
|
)
|
||||||
browsing_session_id = metadata.get("BrowsingSessionId", "")
|
_browsing_session_id = metadata.get("BrowsingSessionId", "")
|
||||||
|
|
||||||
# Create a more descriptive title for extension data
|
# Create a more descriptive title for extension data
|
||||||
title = webpage_title
|
title = webpage_title
|
||||||
|
|
@ -622,7 +625,7 @@ class ConnectorService:
|
||||||
else visit_date
|
else visit_date
|
||||||
)
|
)
|
||||||
title += f" (visited: {formatted_date})"
|
title += f" (visited: {formatted_date})"
|
||||||
except:
|
except Exception:
|
||||||
# Fallback if date parsing fails
|
# Fallback if date parsing fails
|
||||||
title += f" (visited: {visit_date})"
|
title += f" (visited: {visit_date})"
|
||||||
|
|
||||||
|
|
@ -642,7 +645,7 @@ class ConnectorService:
|
||||||
|
|
||||||
if description:
|
if description:
|
||||||
description += f" | Duration: {duration_text}"
|
description += f" | Duration: {duration_text}"
|
||||||
except:
|
except Exception:
|
||||||
# Fallback if duration parsing fails
|
# Fallback if duration parsing fails
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -1180,7 +1183,7 @@ class ConnectorService:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the error and return empty results
|
# Log the error and return empty results
|
||||||
print(f"Error searching with Linkup: {str(e)}")
|
print(f"Error searching with Linkup: {e!s}")
|
||||||
return {
|
return {
|
||||||
"id": 10,
|
"id": 10,
|
||||||
"name": "Linkup Search",
|
"name": "Linkup Search",
|
||||||
|
|
@ -1239,7 +1242,7 @@ class ConnectorService:
|
||||||
# Process each chunk and create sources directly without deduplication
|
# Process each chunk and create sources directly without deduplication
|
||||||
sources_list = []
|
sources_list = []
|
||||||
async with self.counter_lock:
|
async with self.counter_lock:
|
||||||
for i, chunk in enumerate(discord_chunks):
|
for _, chunk in enumerate(discord_chunks):
|
||||||
# Extract document metadata
|
# Extract document metadata
|
||||||
document = chunk.get("document", {})
|
document = chunk.get("document", {})
|
||||||
metadata = document.get("metadata", {})
|
metadata = document.get("metadata", {})
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,13 @@ SSL-safe implementation with pre-downloaded models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ssl
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Any
|
import ssl
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DoclingService:
|
class DoclingService:
|
||||||
"""Docling service for enhanced document processing with SSL fixes."""
|
"""Docling service for enhanced document processing with SSL fixes."""
|
||||||
|
|
||||||
|
|
@ -29,11 +30,12 @@ class DoclingService:
|
||||||
ssl._create_default_https_context = ssl._create_unverified_context
|
ssl._create_default_https_context = ssl._create_unverified_context
|
||||||
|
|
||||||
# Set SSL environment variables if not already set
|
# Set SSL environment variables if not already set
|
||||||
if not os.environ.get('SSL_CERT_FILE'):
|
if not os.environ.get("SSL_CERT_FILE"):
|
||||||
try:
|
try:
|
||||||
import certifi
|
import certifi
|
||||||
os.environ['SSL_CERT_FILE'] = certifi.where()
|
|
||||||
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
|
os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||||
|
os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -45,6 +47,7 @@ class DoclingService:
|
||||||
"""Check and configure GPU support for WSL2 environment."""
|
"""Check and configure GPU support for WSL2 environment."""
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
gpu_count = torch.cuda.device_count()
|
gpu_count = torch.cuda.device_count()
|
||||||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
|
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
|
||||||
|
|
@ -64,10 +67,10 @@ class DoclingService:
|
||||||
def _initialize_docling(self):
|
def _initialize_docling(self):
|
||||||
"""Initialize Docling with version-safe configuration."""
|
"""Initialize Docling with version-safe configuration."""
|
||||||
try:
|
try:
|
||||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
||||||
from docling.datamodel.base_models import InputFormat
|
from docling.datamodel.base_models import InputFormat
|
||||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||||
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||||
|
|
||||||
logger.info("🔧 Initializing Docling with version-safe configuration...")
|
logger.info("🔧 Initializing Docling with version-safe configuration...")
|
||||||
|
|
||||||
|
|
@ -75,19 +78,19 @@ class DoclingService:
|
||||||
pipeline_options = PdfPipelineOptions()
|
pipeline_options = PdfPipelineOptions()
|
||||||
|
|
||||||
# Disable OCR (user request)
|
# Disable OCR (user request)
|
||||||
if hasattr(pipeline_options, 'do_ocr'):
|
if hasattr(pipeline_options, "do_ocr"):
|
||||||
pipeline_options.do_ocr = False
|
pipeline_options.do_ocr = False
|
||||||
logger.info("⚠️ OCR disabled by user request")
|
logger.info("⚠️ OCR disabled by user request")
|
||||||
else:
|
else:
|
||||||
logger.warning("⚠️ OCR attribute not available in this Docling version")
|
logger.warning("⚠️ OCR attribute not available in this Docling version")
|
||||||
|
|
||||||
# Enable table structure if available
|
# Enable table structure if available
|
||||||
if hasattr(pipeline_options, 'do_table_structure'):
|
if hasattr(pipeline_options, "do_table_structure"):
|
||||||
pipeline_options.do_table_structure = True
|
pipeline_options.do_table_structure = True
|
||||||
logger.info("✅ Table structure detection enabled")
|
logger.info("✅ Table structure detection enabled")
|
||||||
|
|
||||||
# Configure GPU acceleration for WSL2 if available
|
# Configure GPU acceleration for WSL2 if available
|
||||||
if hasattr(pipeline_options, 'accelerator_device'):
|
if hasattr(pipeline_options, "accelerator_device"):
|
||||||
if self.use_gpu:
|
if self.use_gpu:
|
||||||
try:
|
try:
|
||||||
pipeline_options.accelerator_device = "cuda"
|
pipeline_options.accelerator_device = "cuda"
|
||||||
|
|
@ -99,98 +102,112 @@ class DoclingService:
|
||||||
pipeline_options.accelerator_device = "cpu"
|
pipeline_options.accelerator_device = "cpu"
|
||||||
logger.info("🖥️ Using CPU acceleration")
|
logger.info("🖥️ Using CPU acceleration")
|
||||||
else:
|
else:
|
||||||
logger.info("ℹ️ Accelerator device attribute not available in this Docling version")
|
logger.info(
|
||||||
|
"⚠️ Accelerator device attribute not available in this Docling version"
|
||||||
|
)
|
||||||
|
|
||||||
# Create PDF format option with backend
|
# Create PDF format option with backend
|
||||||
pdf_format_option = PdfFormatOption(
|
pdf_format_option = PdfFormatOption(
|
||||||
pipeline_options=pipeline_options,
|
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
|
||||||
backend=PyPdfiumDocumentBackend
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize DocumentConverter
|
# Initialize DocumentConverter
|
||||||
self.converter = DocumentConverter(
|
self.converter = DocumentConverter(
|
||||||
format_options={
|
format_options={InputFormat.PDF: pdf_format_option}
|
||||||
InputFormat.PDF: pdf_format_option
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
|
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
|
||||||
logger.info(f"✅ Docling initialized successfully with {acceleration_type} acceleration")
|
logger.info(
|
||||||
|
f"✅ Docling initialized successfully with {acceleration_type} acceleration"
|
||||||
|
)
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(f"❌ Docling not installed: {e}")
|
logger.error(f"❌ Docling not installed: {e}")
|
||||||
raise RuntimeError(f"Docling not available: {e}")
|
raise RuntimeError(f"Docling not available: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ Docling initialization failed: {e}")
|
logger.error(f"❌ Docling initialization failed: {e}")
|
||||||
raise RuntimeError(f"Docling initialization failed: {e}")
|
raise RuntimeError(f"Docling initialization failed: {e}") from e
|
||||||
|
|
||||||
def _configure_easyocr_local_models(self):
|
def _configure_easyocr_local_models(self):
|
||||||
"""Configure EasyOCR to use pre-downloaded local models."""
|
"""Configure EasyOCR to use pre-downloaded local models."""
|
||||||
try:
|
try:
|
||||||
import easyocr
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import easyocr
|
||||||
|
|
||||||
# Set SSL environment for EasyOCR downloads
|
# Set SSL environment for EasyOCR downloads
|
||||||
os.environ['CURL_CA_BUNDLE'] = ''
|
os.environ["CURL_CA_BUNDLE"] = ""
|
||||||
os.environ['REQUESTS_CA_BUNDLE'] = ''
|
os.environ["REQUESTS_CA_BUNDLE"] = ""
|
||||||
|
|
||||||
# Try to use local models first, fallback to download if needed
|
# Try to use local models first, fallback to download if needed
|
||||||
try:
|
try:
|
||||||
reader = easyocr.Reader(['en'],
|
reader = easyocr.Reader(
|
||||||
|
["en"],
|
||||||
download_enabled=False,
|
download_enabled=False,
|
||||||
model_storage_directory="/root/.EasyOCR/model")
|
model_storage_directory="/root/.EasyOCR/model",
|
||||||
|
)
|
||||||
logger.info("✅ EasyOCR configured for local models")
|
logger.info("✅ EasyOCR configured for local models")
|
||||||
return reader
|
return reader
|
||||||
except:
|
except Exception:
|
||||||
# If local models fail, allow download with SSL bypass
|
# If local models fail, allow download with SSL bypass
|
||||||
logger.info("🔄 Local models failed, attempting download with SSL bypass...")
|
logger.info(
|
||||||
reader = easyocr.Reader(['en'],
|
"🔄 Local models failed, attempting download with SSL bypass..."
|
||||||
|
)
|
||||||
|
reader = easyocr.Reader(
|
||||||
|
["en"],
|
||||||
download_enabled=True,
|
download_enabled=True,
|
||||||
model_storage_directory="/root/.EasyOCR/model")
|
model_storage_directory="/root/.EasyOCR/model",
|
||||||
|
)
|
||||||
logger.info("✅ EasyOCR configured with downloaded models")
|
logger.info("✅ EasyOCR configured with downloaded models")
|
||||||
return reader
|
return reader
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"⚠️ EasyOCR configuration failed: {e}")
|
logger.warning(f"⚠️ EasyOCR configuration failed: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def process_document(self, file_path: str, filename: str = None) -> Dict[str, Any]:
|
async def process_document(
|
||||||
|
self, file_path: str, filename: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Process document with Docling using pre-downloaded models."""
|
"""Process document with Docling using pre-downloaded models."""
|
||||||
|
|
||||||
if self.converter is None:
|
if self.converter is None:
|
||||||
raise RuntimeError("Docling converter not initialized")
|
raise RuntimeError("Docling converter not initialized")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"🔄 Processing {filename} with Docling (using local models)...")
|
logger.info(
|
||||||
|
f"🔄 Processing {filename} with Docling (using local models)..."
|
||||||
|
)
|
||||||
|
|
||||||
# Process document with local models
|
# Process document with local models
|
||||||
result = self.converter.convert(file_path)
|
result = self.converter.convert(file_path)
|
||||||
|
|
||||||
# Extract content using version-safe methods
|
# Extract content using version-safe methods
|
||||||
content = None
|
content = None
|
||||||
if hasattr(result, 'document') and result.document:
|
if hasattr(result, "document") and result.document:
|
||||||
# Try different export methods (version compatibility)
|
# Try different export methods (version compatibility)
|
||||||
if hasattr(result.document, 'export_to_markdown'):
|
if hasattr(result.document, "export_to_markdown"):
|
||||||
content = result.document.export_to_markdown()
|
content = result.document.export_to_markdown()
|
||||||
logger.info("📄 Used export_to_markdown method")
|
logger.info("📄 Used export_to_markdown method")
|
||||||
elif hasattr(result.document, 'to_markdown'):
|
elif hasattr(result.document, "to_markdown"):
|
||||||
content = result.document.to_markdown()
|
content = result.document.to_markdown()
|
||||||
logger.info("📄 Used to_markdown method")
|
logger.info("📄 Used to_markdown method")
|
||||||
elif hasattr(result.document, 'text'):
|
elif hasattr(result.document, "text"):
|
||||||
content = result.document.text
|
content = result.document.text
|
||||||
logger.info("📄 Used text property")
|
logger.info("📄 Used text property")
|
||||||
elif hasattr(result.document, '__str__'):
|
elif hasattr(result.document, "__str__"):
|
||||||
content = str(result.document)
|
content = str(result.document)
|
||||||
logger.info("📄 Used string conversion")
|
logger.info("📄 Used string conversion")
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
logger.info(f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)")
|
logger.info(
|
||||||
|
f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'content': content,
|
"content": content,
|
||||||
'full_text': content,
|
"full_text": content,
|
||||||
'service_used': 'docling',
|
"service_used": "docling",
|
||||||
'status': 'success',
|
"status": "success",
|
||||||
'processing_notes': 'Processed with Docling using pre-downloaded models'
|
"processing_notes": "Processed with Docling using pre-downloaded models",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError("No content could be extracted from document")
|
raise ValueError("No content could be extracted from document")
|
||||||
|
|
@ -201,14 +218,12 @@ class DoclingService:
|
||||||
logger.error(f"❌ Docling processing failed for {filename}: {e}")
|
logger.error(f"❌ Docling processing failed for {filename}: {e}")
|
||||||
# Log the full error for debugging
|
# Log the full error for debugging
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||||
raise RuntimeError(f"Docling processing failed: {e}")
|
raise RuntimeError(f"Docling processing failed: {e}") from e
|
||||||
|
|
||||||
async def process_large_document_summary(
|
async def process_large_document_summary(
|
||||||
self,
|
self, content: str, llm, document_title: str = "Document"
|
||||||
content: str,
|
|
||||||
llm,
|
|
||||||
document_title: str = "Document"
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process large documents using chunked LLM summarization.
|
Process large documents using chunked LLM summarization.
|
||||||
|
|
@ -222,24 +237,28 @@ class DoclingService:
|
||||||
Final summary of the document
|
Final summary of the document
|
||||||
"""
|
"""
|
||||||
# Large document threshold (100K characters ≈ 25K tokens)
|
# Large document threshold (100K characters ≈ 25K tokens)
|
||||||
LARGE_DOCUMENT_THRESHOLD = 100_000
|
large_document_threshold = 100_000
|
||||||
|
|
||||||
if len(content) <= LARGE_DOCUMENT_THRESHOLD:
|
if len(content) <= large_document_threshold:
|
||||||
# For smaller documents, use direct processing
|
# For smaller documents, use direct processing
|
||||||
logger.info(f"📄 Document size: {len(content)} chars - using direct processing")
|
logger.info(
|
||||||
|
f"📄 Document size: {len(content)} chars - using direct processing"
|
||||||
|
)
|
||||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||||
|
|
||||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
||||||
result = await summary_chain.ainvoke({"document": content})
|
result = await summary_chain.ainvoke({"document": content})
|
||||||
return result.content
|
return result.content
|
||||||
|
|
||||||
logger.info(f"📚 Large document detected: {len(content)} chars - using chunked processing")
|
logger.info(
|
||||||
|
f"📚 Large document detected: {len(content)} chars - using chunked processing"
|
||||||
|
)
|
||||||
|
|
||||||
# Import chunker from config
|
# Import chunker from config
|
||||||
from app.config import config
|
# Create LLM-optimized chunks (8K tokens max for safety)
|
||||||
|
from chonkie import OverlapRefinery, RecursiveChunker
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
# Create LLM-optimized chunks (8K tokens max for safety)
|
|
||||||
from chonkie import RecursiveChunker, OverlapRefinery
|
|
||||||
llm_chunker = RecursiveChunker(
|
llm_chunker = RecursiveChunker(
|
||||||
chunk_size=8000 # Conservative for most LLMs
|
chunk_size=8000 # Conservative for most LLMs
|
||||||
)
|
)
|
||||||
|
|
@ -247,7 +266,7 @@ class DoclingService:
|
||||||
# Apply overlap refinery for context preservation (10% overlap = 800 tokens)
|
# Apply overlap refinery for context preservation (10% overlap = 800 tokens)
|
||||||
overlap_refinery = OverlapRefinery(
|
overlap_refinery = OverlapRefinery(
|
||||||
context_size=0.1, # 10% overlap for context preservation
|
context_size=0.1, # 10% overlap for context preservation
|
||||||
method="suffix" # Add next chunk context to current chunk
|
method="suffix", # Add next chunk context to current chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# First chunk the content, then apply overlap refinery
|
# First chunk the content, then apply overlap refinery
|
||||||
|
|
@ -274,21 +293,25 @@ Chunk {chunk_number}/{total_chunks}:
|
||||||
<document_chunk>
|
<document_chunk>
|
||||||
{chunk}
|
{chunk}
|
||||||
</document_chunk>
|
</document_chunk>
|
||||||
</INSTRUCTIONS>"""
|
</INSTRUCTIONS>""",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process each chunk individually
|
# Process each chunk individually
|
||||||
chunk_summaries = []
|
chunk_summaries = []
|
||||||
for i, chunk in enumerate(chunks, 1):
|
for i, chunk in enumerate(chunks, 1):
|
||||||
try:
|
try:
|
||||||
logger.info(f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)")
|
logger.info(
|
||||||
|
f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)"
|
||||||
|
)
|
||||||
|
|
||||||
chunk_chain = chunk_template | llm
|
chunk_chain = chunk_template | llm
|
||||||
chunk_result = await chunk_chain.ainvoke({
|
chunk_result = await chunk_chain.ainvoke(
|
||||||
|
{
|
||||||
"chunk": chunk.text,
|
"chunk": chunk.text,
|
||||||
"chunk_number": i,
|
"chunk_number": i,
|
||||||
"total_chunks": total_chunks
|
"total_chunks": total_chunks,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
chunk_summary = chunk_result.content
|
chunk_summary = chunk_result.content
|
||||||
chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}")
|
chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}")
|
||||||
|
|
@ -318,19 +341,20 @@ Ensure:
|
||||||
<section_summaries>
|
<section_summaries>
|
||||||
{summaries}
|
{summaries}
|
||||||
</section_summaries>
|
</section_summaries>
|
||||||
</INSTRUCTIONS>"""
|
</INSTRUCTIONS>""",
|
||||||
)
|
)
|
||||||
|
|
||||||
combined_summaries = "\n\n".join(chunk_summaries)
|
combined_summaries = "\n\n".join(chunk_summaries)
|
||||||
combine_chain = combine_template | llm
|
combine_chain = combine_template | llm
|
||||||
|
|
||||||
final_result = await combine_chain.ainvoke({
|
final_result = await combine_chain.ainvoke(
|
||||||
"summaries": combined_summaries,
|
{"summaries": combined_summaries, "document_title": document_title}
|
||||||
"document_title": document_title
|
)
|
||||||
})
|
|
||||||
|
|
||||||
final_summary = final_result.content
|
final_summary = final_result.content
|
||||||
logger.info(f"✅ Large document processing complete: {len(final_summary)} chars summary")
|
logger.info(
|
||||||
|
f"✅ Large document processing complete: {len(final_summary)} chars summary"
|
||||||
|
)
|
||||||
|
|
||||||
return final_summary
|
return final_summary
|
||||||
|
|
||||||
|
|
@ -341,6 +365,7 @@ Ensure:
|
||||||
logger.warning("⚠️ Using fallback combined summary")
|
logger.warning("⚠️ Using fallback combined summary")
|
||||||
return fallback_summary
|
return fallback_summary
|
||||||
|
|
||||||
|
|
||||||
def create_docling_service() -> DoclingService:
|
def create_docling_service() -> DoclingService:
|
||||||
"""Create a Docling service instance."""
|
"""Create a Docling service instance."""
|
||||||
return DoclingService()
|
return DoclingService()
|
||||||
|
|
@ -1,23 +1,23 @@
|
||||||
from typing import Optional
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
from langchain_community.chat_models import ChatLiteLLM
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from app.db import User, LLMConfig
|
from langchain_community.chat_models import ChatLiteLLM
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import LLMConfig, User
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMRole:
|
class LLMRole:
|
||||||
LONG_CONTEXT = "long_context"
|
LONG_CONTEXT = "long_context"
|
||||||
FAST = "fast"
|
FAST = "fast"
|
||||||
STRATEGIC = "strategic"
|
STRATEGIC = "strategic"
|
||||||
|
|
||||||
|
|
||||||
async def get_user_llm_instance(
|
async def get_user_llm_instance(
|
||||||
session: AsyncSession,
|
session: AsyncSession, user_id: str, role: str
|
||||||
user_id: str,
|
) -> ChatLiteLLM | None:
|
||||||
role: str
|
|
||||||
) -> Optional[ChatLiteLLM]:
|
|
||||||
"""
|
"""
|
||||||
Get a ChatLiteLLM instance for a specific user and role.
|
Get a ChatLiteLLM instance for a specific user and role.
|
||||||
|
|
||||||
|
|
@ -31,9 +31,7 @@ async def get_user_llm_instance(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get user with their LLM preferences
|
# Get user with their LLM preferences
|
||||||
result = await session.execute(
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
select(User).where(User.id == user_id)
|
|
||||||
)
|
|
||||||
user = result.scalars().first()
|
user = result.scalars().first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
|
|
@ -59,8 +57,7 @@ async def get_user_llm_instance(
|
||||||
# Get the LLM configuration
|
# Get the LLM configuration
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(LLMConfig).where(
|
select(LLMConfig).where(
|
||||||
LLMConfig.id == llm_config_id,
|
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
|
||||||
LLMConfig.user_id == user_id
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
llm_config = result.scalars().first()
|
llm_config = result.scalars().first()
|
||||||
|
|
@ -84,7 +81,9 @@ async def get_user_llm_instance(
|
||||||
"MISTRAL": "mistral",
|
"MISTRAL": "mistral",
|
||||||
# Add more mappings as needed
|
# Add more mappings as needed
|
||||||
}
|
}
|
||||||
provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
|
provider_prefix = provider_map.get(
|
||||||
|
llm_config.provider.value, llm_config.provider.value.lower()
|
||||||
|
)
|
||||||
model_string = f"{provider_prefix}/{llm_config.model_name}"
|
model_string = f"{provider_prefix}/{llm_config.model_name}"
|
||||||
|
|
||||||
# Create ChatLiteLLM instance
|
# Create ChatLiteLLM instance
|
||||||
|
|
@ -104,17 +103,26 @@ async def get_user_llm_instance(
|
||||||
return ChatLiteLLM(**litellm_kwargs)
|
return ChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
|
logger.error(
|
||||||
|
f"Error getting LLM instance for user {user_id}, role {role}: {e!s}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
|
||||||
|
async def get_user_long_context_llm(
|
||||||
|
session: AsyncSession, user_id: str
|
||||||
|
) -> ChatLiteLLM | None:
|
||||||
"""Get user's long context LLM instance."""
|
"""Get user's long context LLM instance."""
|
||||||
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
|
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
|
||||||
|
|
||||||
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
|
||||||
|
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
|
||||||
"""Get user's fast LLM instance."""
|
"""Get user's fast LLM instance."""
|
||||||
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
|
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
|
||||||
|
|
||||||
async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
|
||||||
|
async def get_user_strategic_llm(
|
||||||
|
session: AsyncSession, user_id: str
|
||||||
|
) -> ChatLiteLLM | None:
|
||||||
"""Get user's strategic LLM instance."""
|
"""Get user's strategic LLM instance."""
|
||||||
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import datetime
|
import datetime
|
||||||
from langchain.schema import HumanMessage, SystemMessage, AIMessage
|
from typing import Any
|
||||||
from app.config import config
|
|
||||||
from app.services.llm_service import get_user_strategic_llm
|
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
from app.services.llm_service import get_user_strategic_llm
|
||||||
|
|
||||||
|
|
||||||
class QueryService:
|
class QueryService:
|
||||||
|
|
@ -16,7 +17,7 @@ class QueryService:
|
||||||
user_query: str,
|
user_query: str,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
chat_history_str: Optional[str] = None
|
chat_history_str: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Reformulate the user query using the user's strategic LLM to make it more
|
Reformulate the user query using the user's strategic LLM to make it more
|
||||||
|
|
@ -38,7 +39,9 @@ class QueryService:
|
||||||
# Get the user's strategic LLM instance
|
# Get the user's strategic LLM instance
|
||||||
llm = await get_user_strategic_llm(session, user_id)
|
llm = await get_user_strategic_llm(session, user_id)
|
||||||
if not llm:
|
if not llm:
|
||||||
print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.")
|
print(
|
||||||
|
f"Warning: No strategic LLM configured for user {user_id}. Using original query."
|
||||||
|
)
|
||||||
return user_query
|
return user_query
|
||||||
|
|
||||||
# Create system message with instructions
|
# Create system message with instructions
|
||||||
|
|
@ -92,9 +95,8 @@ class QueryService:
|
||||||
print(f"Error reformulating query: {e}")
|
print(f"Error reformulating query: {e}")
|
||||||
return user_query
|
return user_query
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def langchain_chat_history_to_str(chat_history: List[Any]) -> str:
|
async def langchain_chat_history_to_str(chat_history: list[Any]) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a list of chat history messages to a string.
|
Convert a list of chat history messages to a string.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from rerankers import Document as RerankerDocument
|
from rerankers import Document as RerankerDocument
|
||||||
|
|
||||||
|
|
||||||
class RerankerService:
|
class RerankerService:
|
||||||
"""
|
"""
|
||||||
Service for reranking documents using a configured reranker
|
Service for reranking documents using a configured reranker
|
||||||
|
|
@ -16,7 +18,9 @@ class RerankerService:
|
||||||
"""
|
"""
|
||||||
self.reranker_instance = reranker_instance
|
self.reranker_instance = reranker_instance
|
||||||
|
|
||||||
def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def rerank_documents(
|
||||||
|
self, query_text: str, documents: list[dict[str, Any]]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Rerank documents using the configured reranker
|
Rerank documents using the configured reranker
|
||||||
|
|
||||||
|
|
@ -44,18 +48,17 @@ class RerankerService:
|
||||||
text=content,
|
text=content,
|
||||||
doc_id=chunk_id,
|
doc_id=chunk_id,
|
||||||
metadata={
|
metadata={
|
||||||
'document_id': document_info.get("id", ""),
|
"document_id": document_info.get("id", ""),
|
||||||
'document_title': document_info.get("title", ""),
|
"document_title": document_info.get("title", ""),
|
||||||
'document_type': document_info.get("document_type", ""),
|
"document_type": document_info.get("document_type", ""),
|
||||||
'rrf_score': score
|
"rrf_score": score,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Rerank using the configured reranker
|
# Rerank using the configured reranker
|
||||||
reranking_results = self.reranker_instance.rank(
|
reranking_results = self.reranker_instance.rank(
|
||||||
query=query_text,
|
query=query_text, docs=reranker_docs
|
||||||
docs=reranker_docs
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process the results from the reranker
|
# Process the results from the reranker
|
||||||
|
|
@ -63,7 +66,14 @@ class RerankerService:
|
||||||
serialized_results = []
|
serialized_results = []
|
||||||
for result in reranking_results.results:
|
for result in reranking_results.results:
|
||||||
# Find the original document by id
|
# Find the original document by id
|
||||||
original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
|
original_doc = next(
|
||||||
|
(
|
||||||
|
doc
|
||||||
|
for doc in documents
|
||||||
|
if doc.get("chunk_id") == result.document.doc_id
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
if original_doc:
|
if original_doc:
|
||||||
# Create a new document with the reranked score
|
# Create a new document with the reranked score
|
||||||
reranked_doc = original_doc.copy()
|
reranked_doc = original_doc.copy()
|
||||||
|
|
@ -75,12 +85,12 @@ class RerankerService:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the error
|
# Log the error
|
||||||
logging.error(f"Error during reranking: {str(e)}")
|
logging.error(f"Error during reranking: {e!s}")
|
||||||
# Fall back to original documents without reranking
|
# Fall back to original documents without reranking
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_reranker_instance() -> Optional['RerankerService']:
|
def get_reranker_instance() -> Optional["RerankerService"]:
|
||||||
"""
|
"""
|
||||||
Get a reranker service instance from the global configuration.
|
Get a reranker service instance from the global configuration.
|
||||||
|
|
||||||
|
|
@ -89,7 +99,6 @@ class RerankerService:
|
||||||
"""
|
"""
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
if hasattr(config, 'reranker_instance') and config.reranker_instance:
|
if hasattr(config, "reranker_instance") and config.reranker_instance:
|
||||||
return RerankerService(config.reranker_instance)
|
return RerankerService(config.reranker_instance)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -1,27 +1,15 @@
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class StreamingService:
|
class StreamingService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.terminal_idx = 1
|
self.terminal_idx = 1
|
||||||
self.message_annotations = [
|
self.message_annotations = [
|
||||||
{
|
{"type": "TERMINAL_INFO", "content": []},
|
||||||
"type": "TERMINAL_INFO",
|
{"type": "SOURCES", "content": []},
|
||||||
"content": []
|
{"type": "ANSWER", "content": []},
|
||||||
},
|
{"type": "FURTHER_QUESTIONS", "content": []},
|
||||||
{
|
|
||||||
"type": "SOURCES",
|
|
||||||
"content": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "ANSWER",
|
|
||||||
"content": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "FURTHER_QUESTIONS",
|
|
||||||
"content": []
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# DEPRECATED: This sends the full annotation array every time (inefficient)
|
# DEPRECATED: This sends the full annotation array every time (inefficient)
|
||||||
|
|
@ -35,7 +23,7 @@ class StreamingService:
|
||||||
Returns:
|
Returns:
|
||||||
str: The formatted annotations string
|
str: The formatted annotations string
|
||||||
"""
|
"""
|
||||||
return f'8:{json.dumps(self.message_annotations)}\n'
|
return f"8:{json.dumps(self.message_annotations)}\n"
|
||||||
|
|
||||||
def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
|
def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -58,7 +46,7 @@ class StreamingService:
|
||||||
annotation = {"type": "TERMINAL_INFO", "content": [message]}
|
annotation = {"type": "TERMINAL_INFO", "content": [message]}
|
||||||
return f"8:[{json.dumps(annotation)}]\n"
|
return f"8:[{json.dumps(annotation)}]\n"
|
||||||
|
|
||||||
def format_sources_delta(self, sources: List[Dict[str, Any]]) -> str:
|
def format_sources_delta(self, sources: list[dict[str, Any]]) -> str:
|
||||||
"""
|
"""
|
||||||
Format sources as a delta annotation
|
Format sources as a delta annotation
|
||||||
|
|
||||||
|
|
@ -95,7 +83,7 @@ class StreamingService:
|
||||||
annotation = {"type": "ANSWER", "content": [answer_chunk]}
|
annotation = {"type": "ANSWER", "content": [answer_chunk]}
|
||||||
return f"8:[{json.dumps(annotation)}]\n"
|
return f"8:[{json.dumps(annotation)}]\n"
|
||||||
|
|
||||||
def format_answer_annotation(self, answer_lines: List[str]) -> str:
|
def format_answer_annotation(self, answer_lines: list[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Format the complete answer as a replacement annotation
|
Format the complete answer as a replacement annotation
|
||||||
|
|
||||||
|
|
@ -113,7 +101,7 @@ class StreamingService:
|
||||||
return f"8:[{json.dumps(annotation)}]\n"
|
return f"8:[{json.dumps(annotation)}]\n"
|
||||||
|
|
||||||
def format_further_questions_delta(
|
def format_further_questions_delta(
|
||||||
self, further_questions: List[Dict[str, Any]]
|
self, further_questions: list[dict[str, Any]]
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format further questions as a delta annotation
|
Format further questions as a delta annotation
|
||||||
|
|
@ -155,7 +143,9 @@ class StreamingService:
|
||||||
"""
|
"""
|
||||||
return f"3:{json.dumps(error_message)}\n"
|
return f"3:{json.dumps(error_message)}\n"
|
||||||
|
|
||||||
def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
|
def format_completion(
|
||||||
|
self, prompt_tokens: int = 156, completion_tokens: int = 204
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format a completion message
|
Format a completion message
|
||||||
|
|
||||||
|
|
@ -172,7 +162,7 @@ class StreamingService:
|
||||||
"usage": {
|
"usage": {
|
||||||
"promptTokens": prompt_tokens,
|
"promptTokens": prompt_tokens,
|
||||||
"completionTokens": completion_tokens,
|
"completionTokens": completion_tokens,
|
||||||
"totalTokens": total_tokens
|
"totalTokens": total_tokens,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
return f"d:{json.dumps(completion_data)}\n"
|
||||||
return f'd:{json.dumps(completion_data)}\n'
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from app.db import Log, LogLevel, LogStatus
|
|
||||||
import logging
|
import logging
|
||||||
import json
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import Log, LogLevel, LogStatus
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TaskLoggingService:
|
class TaskLoggingService:
|
||||||
"""Service for logging background tasks using the database Log model"""
|
"""Service for logging background tasks using the database Log model"""
|
||||||
|
|
||||||
|
|
@ -19,7 +21,7 @@ class TaskLoggingService:
|
||||||
task_name: str,
|
task_name: str,
|
||||||
source: str,
|
source: str,
|
||||||
message: str,
|
message: str,
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> Log:
|
) -> Log:
|
||||||
"""
|
"""
|
||||||
Log the start of a task with IN_PROGRESS status
|
Log the start of a task with IN_PROGRESS status
|
||||||
|
|
@ -34,10 +36,9 @@ class TaskLoggingService:
|
||||||
Log: The created log entry
|
Log: The created log entry
|
||||||
"""
|
"""
|
||||||
log_metadata = metadata or {}
|
log_metadata = metadata or {}
|
||||||
log_metadata.update({
|
log_metadata.update(
|
||||||
"task_name": task_name,
|
{"task_name": task_name, "started_at": datetime.utcnow().isoformat()}
|
||||||
"started_at": datetime.utcnow().isoformat()
|
)
|
||||||
})
|
|
||||||
|
|
||||||
log_entry = Log(
|
log_entry = Log(
|
||||||
level=LogLevel.INFO,
|
level=LogLevel.INFO,
|
||||||
|
|
@ -45,7 +46,7 @@ class TaskLoggingService:
|
||||||
message=message,
|
message=message,
|
||||||
source=source,
|
source=source,
|
||||||
log_metadata=log_metadata,
|
log_metadata=log_metadata,
|
||||||
search_space_id=self.search_space_id
|
search_space_id=self.search_space_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session.add(log_entry)
|
self.session.add(log_entry)
|
||||||
|
|
@ -59,7 +60,7 @@ class TaskLoggingService:
|
||||||
self,
|
self,
|
||||||
log_entry: Log,
|
log_entry: Log,
|
||||||
message: str,
|
message: str,
|
||||||
additional_metadata: Optional[Dict[str, Any]] = None
|
additional_metadata: dict[str, Any] | None = None,
|
||||||
) -> Log:
|
) -> Log:
|
||||||
"""
|
"""
|
||||||
Update a log entry to SUCCESS status
|
Update a log entry to SUCCESS status
|
||||||
|
|
@ -86,7 +87,11 @@ class TaskLoggingService:
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(log_entry)
|
await self.session.refresh(log_entry)
|
||||||
|
|
||||||
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
|
task_name = (
|
||||||
|
log_entry.log_metadata.get("task_name", "unknown")
|
||||||
|
if log_entry.log_metadata
|
||||||
|
else "unknown"
|
||||||
|
)
|
||||||
logger.info(f"Completed task {task_name}: {message}")
|
logger.info(f"Completed task {task_name}: {message}")
|
||||||
return log_entry
|
return log_entry
|
||||||
|
|
||||||
|
|
@ -94,8 +99,8 @@ class TaskLoggingService:
|
||||||
self,
|
self,
|
||||||
log_entry: Log,
|
log_entry: Log,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
error_details: Optional[str] = None,
|
error_details: str | None = None,
|
||||||
additional_metadata: Optional[Dict[str, Any]] = None
|
additional_metadata: dict[str, Any] | None = None,
|
||||||
) -> Log:
|
) -> Log:
|
||||||
"""
|
"""
|
||||||
Update a log entry to FAILED status
|
Update a log entry to FAILED status
|
||||||
|
|
@ -118,10 +123,9 @@ class TaskLoggingService:
|
||||||
if log_entry.log_metadata is None:
|
if log_entry.log_metadata is None:
|
||||||
log_entry.log_metadata = {}
|
log_entry.log_metadata = {}
|
||||||
|
|
||||||
log_entry.log_metadata.update({
|
log_entry.log_metadata.update(
|
||||||
"failed_at": datetime.utcnow().isoformat(),
|
{"failed_at": datetime.utcnow().isoformat(), "error_details": error_details}
|
||||||
"error_details": error_details
|
)
|
||||||
})
|
|
||||||
|
|
||||||
if additional_metadata:
|
if additional_metadata:
|
||||||
log_entry.log_metadata.update(additional_metadata)
|
log_entry.log_metadata.update(additional_metadata)
|
||||||
|
|
@ -129,7 +133,11 @@ class TaskLoggingService:
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(log_entry)
|
await self.session.refresh(log_entry)
|
||||||
|
|
||||||
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
|
task_name = (
|
||||||
|
log_entry.log_metadata.get("task_name", "unknown")
|
||||||
|
if log_entry.log_metadata
|
||||||
|
else "unknown"
|
||||||
|
)
|
||||||
logger.error(f"Failed task {task_name}: {error_message}")
|
logger.error(f"Failed task {task_name}: {error_message}")
|
||||||
if error_details:
|
if error_details:
|
||||||
logger.error(f"Error details: {error_details}")
|
logger.error(f"Error details: {error_details}")
|
||||||
|
|
@ -140,7 +148,7 @@ class TaskLoggingService:
|
||||||
self,
|
self,
|
||||||
log_entry: Log,
|
log_entry: Log,
|
||||||
progress_message: str,
|
progress_message: str,
|
||||||
progress_metadata: Optional[Dict[str, Any]] = None
|
progress_metadata: dict[str, Any] | None = None,
|
||||||
) -> Log:
|
) -> Log:
|
||||||
"""
|
"""
|
||||||
Update a log entry with progress information while keeping IN_PROGRESS status
|
Update a log entry with progress information while keeping IN_PROGRESS status
|
||||||
|
|
@ -159,12 +167,18 @@ class TaskLoggingService:
|
||||||
if log_entry.log_metadata is None:
|
if log_entry.log_metadata is None:
|
||||||
log_entry.log_metadata = {}
|
log_entry.log_metadata = {}
|
||||||
log_entry.log_metadata.update(progress_metadata)
|
log_entry.log_metadata.update(progress_metadata)
|
||||||
log_entry.log_metadata["last_progress_update"] = datetime.utcnow().isoformat()
|
log_entry.log_metadata["last_progress_update"] = (
|
||||||
|
datetime.utcnow().isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(log_entry)
|
await self.session.refresh(log_entry)
|
||||||
|
|
||||||
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
|
task_name = (
|
||||||
|
log_entry.log_metadata.get("task_name", "unknown")
|
||||||
|
if log_entry.log_metadata
|
||||||
|
else "unknown"
|
||||||
|
)
|
||||||
logger.info(f"Progress update for task {task_name}: {progress_message}")
|
logger.info(f"Progress update for task {task_name}: {progress_message}")
|
||||||
return log_entry
|
return log_entry
|
||||||
|
|
||||||
|
|
@ -173,7 +187,7 @@ class TaskLoggingService:
|
||||||
level: LogLevel,
|
level: LogLevel,
|
||||||
source: str,
|
source: str,
|
||||||
message: str,
|
message: str,
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> Log:
|
) -> Log:
|
||||||
"""
|
"""
|
||||||
Log a simple event (not a long-running task)
|
Log a simple event (not a long-running task)
|
||||||
|
|
@ -193,7 +207,7 @@ class TaskLoggingService:
|
||||||
message=message,
|
message=message,
|
||||||
source=source,
|
source=source,
|
||||||
log_metadata=metadata or {},
|
log_metadata=metadata or {},
|
||||||
search_space_id=self.search_space_id
|
search_space_id=self.search_space_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session.add(log_entry)
|
self.session.add(log_entry)
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,33 @@
|
||||||
from typing import Optional, List
|
import logging
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import validators
|
||||||
|
from langchain_community.document_loaders import AsyncChromiumLoader, FireCrawlLoader
|
||||||
|
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||||
|
from langchain_core.documents import Document as LangChainDocument
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from app.db import Document, DocumentType, Chunk
|
from youtube_transcript_api import YouTubeTranscriptApi
|
||||||
from app.schemas import ExtensionDocumentContent
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
from app.db import Chunk, Document, DocumentType
|
||||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||||
from app.utils.document_converters import convert_document_to_markdown, generate_content_hash
|
from app.schemas import ExtensionDocumentContent
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from langchain_core.documents import Document as LangChainDocument
|
from app.utils.document_converters import (
|
||||||
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
|
convert_document_to_markdown,
|
||||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
generate_content_hash,
|
||||||
import validators
|
)
|
||||||
from youtube_transcript_api import YouTubeTranscriptApi
|
|
||||||
from urllib.parse import urlparse, parse_qs
|
|
||||||
import aiohttp
|
|
||||||
import logging
|
|
||||||
|
|
||||||
md = MarkdownifyTransformer()
|
md = MarkdownifyTransformer()
|
||||||
|
|
||||||
|
|
||||||
async def add_crawled_url_document(
|
async def add_crawled_url_document(
|
||||||
session: AsyncSession, url: str, search_space_id: int, user_id: str
|
session: AsyncSession, url: str, search_space_id: int, user_id: str
|
||||||
) -> Optional[Document]:
|
) -> Document | None:
|
||||||
task_logger = TaskLoggingService(session, search_space_id)
|
task_logger = TaskLoggingService(session, search_space_id)
|
||||||
|
|
||||||
# Log task start
|
# Log task start
|
||||||
|
|
@ -30,15 +35,13 @@ async def add_crawled_url_document(
|
||||||
task_name="crawl_url_document",
|
task_name="crawl_url_document",
|
||||||
source="background_task",
|
source="background_task",
|
||||||
message=f"Starting URL crawling process for: {url}",
|
message=f"Starting URL crawling process for: {url}",
|
||||||
metadata={"url": url, "user_id": str(user_id)}
|
metadata={"url": url, "user_id": str(user_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# URL validation step
|
# URL validation step
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry, f"Validating URL: {url}", {"stage": "validation"}
|
||||||
f"Validating URL: {url}",
|
|
||||||
{"stage": "validation"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not validators.url(url):
|
if not validators.url(url):
|
||||||
|
|
@ -48,7 +51,10 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Setting up crawler for URL: {url}",
|
f"Setting up crawler for URL: {url}",
|
||||||
{"stage": "crawler_setup", "firecrawl_available": bool(config.FIRECRAWL_API_KEY)}
|
{
|
||||||
|
"stage": "crawler_setup",
|
||||||
|
"firecrawl_available": bool(config.FIRECRAWL_API_KEY),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.FIRECRAWL_API_KEY:
|
if config.FIRECRAWL_API_KEY:
|
||||||
|
|
@ -68,21 +74,21 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Crawling URL content: {url}",
|
f"Crawling URL content: {url}",
|
||||||
{"stage": "crawling", "crawler_type": type(crawl_loader).__name__}
|
{"stage": "crawling", "crawler_type": type(crawl_loader).__name__},
|
||||||
)
|
)
|
||||||
|
|
||||||
url_crawled = await crawl_loader.aload()
|
url_crawled = await crawl_loader.aload()
|
||||||
|
|
||||||
if type(crawl_loader) == FireCrawlLoader:
|
if isinstance(crawl_loader, FireCrawlLoader):
|
||||||
content_in_markdown = url_crawled[0].page_content
|
content_in_markdown = url_crawled[0].page_content
|
||||||
elif type(crawl_loader) == AsyncChromiumLoader:
|
elif isinstance(crawl_loader, AsyncChromiumLoader):
|
||||||
content_in_markdown = md.transform_documents(url_crawled)[0].page_content
|
content_in_markdown = md.transform_documents(url_crawled)[0].page_content
|
||||||
|
|
||||||
# Format document
|
# Format document
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing crawled content from: {url}",
|
f"Processing crawled content from: {url}",
|
||||||
{"stage": "content_processing", "content_length": len(content_in_markdown)}
|
{"stage": "content_processing", "content_length": len(content_in_markdown)},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format document metadata in a more maintainable way
|
# Format document metadata in a more maintainable way
|
||||||
|
|
@ -117,7 +123,7 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Checking for duplicate content: {url}",
|
f"Checking for duplicate content: {url}",
|
||||||
{"stage": "duplicate_check", "content_hash": content_hash}
|
{"stage": "duplicate_check", "content_hash": content_hash},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if document with this content hash already exists
|
# Check if document with this content hash already exists
|
||||||
|
|
@ -130,16 +136,21 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Document already exists for URL: {url}",
|
f"Document already exists for URL: {url}",
|
||||||
{"duplicate_detected": True, "existing_document_id": existing_document.id}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"existing_document_id": existing_document.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
)
|
)
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# Get LLM for summary generation
|
# Get LLM for summary generation
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Preparing for summary generation: {url}",
|
f"Preparing for summary generation: {url}",
|
||||||
{"stage": "llm_setup"}
|
{"stage": "llm_setup"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
|
|
@ -151,7 +162,7 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Generating summary for URL content: {url}",
|
f"Generating summary for URL content: {url}",
|
||||||
{"stage": "summary_generation"}
|
{"stage": "summary_generation"},
|
||||||
)
|
)
|
||||||
|
|
||||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
|
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
|
||||||
|
|
@ -165,7 +176,7 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing content chunks for URL: {url}",
|
f"Processing content chunks for URL: {url}",
|
||||||
{"stage": "chunk_processing"}
|
{"stage": "chunk_processing"},
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
|
|
@ -180,13 +191,13 @@ async def add_crawled_url_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Creating document in database for URL: {url}",
|
f"Creating document in database for URL: {url}",
|
||||||
{"stage": "document_creation", "chunks_count": len(chunks)}
|
{"stage": "document_creation", "chunks_count": len(chunks)},
|
||||||
)
|
)
|
||||||
|
|
||||||
document = Document(
|
document = Document(
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
title=url_crawled[0].metadata["title"]
|
title=url_crawled[0].metadata["title"]
|
||||||
if type(crawl_loader) == FireCrawlLoader
|
if isinstance(crawl_loader, FireCrawlLoader)
|
||||||
else url_crawled[0].metadata["source"],
|
else url_crawled[0].metadata["source"],
|
||||||
document_type=DocumentType.CRAWLED_URL,
|
document_type=DocumentType.CRAWLED_URL,
|
||||||
document_metadata=url_crawled[0].metadata,
|
document_metadata=url_crawled[0].metadata,
|
||||||
|
|
@ -209,8 +220,8 @@ async def add_crawled_url_document(
|
||||||
"title": document.title,
|
"title": document.title,
|
||||||
"content_hash": content_hash,
|
"content_hash": content_hash,
|
||||||
"chunks_count": len(chunks),
|
"chunks_count": len(chunks),
|
||||||
"summary_length": len(summary_content)
|
"summary_length": len(summary_content),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
@ -221,7 +232,7 @@ async def add_crawled_url_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Database error while processing URL: {url}",
|
f"Database error while processing URL: {url}",
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"}
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -230,14 +241,17 @@ async def add_crawled_url_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to crawl URL: {url}",
|
f"Failed to crawl URL: {url}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
raise RuntimeError(f"Failed to crawl URL: {str(e)}")
|
raise RuntimeError(f"Failed to crawl URL: {e!s}") from e
|
||||||
|
|
||||||
|
|
||||||
async def add_extension_received_document(
|
async def add_extension_received_document(
|
||||||
session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str
|
session: AsyncSession,
|
||||||
) -> Optional[Document]:
|
content: ExtensionDocumentContent,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> Document | None:
|
||||||
"""
|
"""
|
||||||
Process and store document content received from the SurfSense Extension.
|
Process and store document content received from the SurfSense Extension.
|
||||||
|
|
||||||
|
|
@ -259,8 +273,8 @@ async def add_extension_received_document(
|
||||||
metadata={
|
metadata={
|
||||||
"url": content.metadata.VisitedWebPageURL,
|
"url": content.metadata.VisitedWebPageURL,
|
||||||
"title": content.metadata.VisitedWebPageTitle,
|
"title": content.metadata.VisitedWebPageTitle,
|
||||||
"user_id": str(user_id)
|
"user_id": str(user_id),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -306,9 +320,14 @@ async def add_extension_received_document(
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Extension document already exists: {content.metadata.VisitedWebPageTitle}",
|
f"Extension document already exists: {content.metadata.VisitedWebPageTitle}",
|
||||||
{"duplicate_detected": True, "existing_document_id": existing_document.id}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"existing_document_id": existing_document.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
)
|
)
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
|
|
@ -356,8 +375,8 @@ async def add_extension_received_document(
|
||||||
{
|
{
|
||||||
"document_id": document.id,
|
"document_id": document.id,
|
||||||
"content_hash": content_hash,
|
"content_hash": content_hash,
|
||||||
"url": content.metadata.VisitedWebPageURL
|
"url": content.metadata.VisitedWebPageURL,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
@ -368,7 +387,7 @@ async def add_extension_received_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}",
|
f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}",
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"}
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -377,14 +396,18 @@ async def add_extension_received_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}",
|
f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
raise RuntimeError(f"Failed to process extension document: {str(e)}")
|
raise RuntimeError(f"Failed to process extension document: {e!s}") from e
|
||||||
|
|
||||||
|
|
||||||
async def add_received_markdown_file_document(
|
async def add_received_markdown_file_document(
|
||||||
session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str
|
session: AsyncSession,
|
||||||
) -> Optional[Document]:
|
file_name: str,
|
||||||
|
file_in_markdown: str,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> Document | None:
|
||||||
task_logger = TaskLoggingService(session, search_space_id)
|
task_logger = TaskLoggingService(session, search_space_id)
|
||||||
|
|
||||||
# Log task start
|
# Log task start
|
||||||
|
|
@ -392,7 +415,11 @@ async def add_received_markdown_file_document(
|
||||||
task_name="markdown_file_document",
|
task_name="markdown_file_document",
|
||||||
source="background_task",
|
source="background_task",
|
||||||
message=f"Processing markdown file: {file_name}",
|
message=f"Processing markdown file: {file_name}",
|
||||||
metadata={"filename": file_name, "user_id": str(user_id), "content_length": len(file_in_markdown)}
|
metadata={
|
||||||
|
"filename": file_name,
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"content_length": len(file_in_markdown),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -408,9 +435,14 @@ async def add_received_markdown_file_document(
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Markdown file document already exists: {file_name}",
|
f"Markdown file document already exists: {file_name}",
|
||||||
{"duplicate_detected": True, "existing_document_id": existing_document.id}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"existing_document_id": existing_document.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
)
|
)
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
|
|
@ -459,8 +491,8 @@ async def add_received_markdown_file_document(
|
||||||
"document_id": document.id,
|
"document_id": document.id,
|
||||||
"content_hash": content_hash,
|
"content_hash": content_hash,
|
||||||
"chunks_count": len(chunks),
|
"chunks_count": len(chunks),
|
||||||
"summary_length": len(summary_content)
|
"summary_length": len(summary_content),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
@ -470,7 +502,7 @@ async def add_received_markdown_file_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Database error processing markdown file: {file_name}",
|
f"Database error processing markdown file: {file_name}",
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"}
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -479,18 +511,18 @@ async def add_received_markdown_file_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process markdown file: {file_name}",
|
f"Failed to process markdown file: {file_name}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
raise RuntimeError(f"Failed to process file document: {str(e)}")
|
raise RuntimeError(f"Failed to process file document: {e!s}") from e
|
||||||
|
|
||||||
|
|
||||||
async def add_received_file_document_using_unstructured(
|
async def add_received_file_document_using_unstructured(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
file_name: str,
|
file_name: str,
|
||||||
unstructured_processed_elements: List[LangChainDocument],
|
unstructured_processed_elements: list[LangChainDocument],
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> Optional[Document]:
|
) -> Document | None:
|
||||||
try:
|
try:
|
||||||
file_in_markdown = await convert_document_to_markdown(
|
file_in_markdown = await convert_document_to_markdown(
|
||||||
unstructured_processed_elements
|
unstructured_processed_elements
|
||||||
|
|
@ -505,7 +537,9 @@ async def add_received_file_document_using_unstructured(
|
||||||
existing_document = existing_doc_result.scalars().first()
|
existing_document = existing_doc_result.scalars().first()
|
||||||
|
|
||||||
if existing_document:
|
if existing_document:
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
|
)
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# TODO: Check if file_markdown exceeds token limit of embedding model
|
# TODO: Check if file_markdown exceeds token limit of embedding model
|
||||||
|
|
@ -555,7 +589,7 @@ async def add_received_file_document_using_unstructured(
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise RuntimeError(f"Failed to process file document: {str(e)}")
|
raise RuntimeError(f"Failed to process file document: {e!s}") from e
|
||||||
|
|
||||||
|
|
||||||
async def add_received_file_document_using_llamacloud(
|
async def add_received_file_document_using_llamacloud(
|
||||||
|
|
@ -564,7 +598,7 @@ async def add_received_file_document_using_llamacloud(
|
||||||
llamacloud_markdown_document: str,
|
llamacloud_markdown_document: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> Optional[Document]:
|
) -> Document | None:
|
||||||
"""
|
"""
|
||||||
Process and store document content parsed by LlamaCloud.
|
Process and store document content parsed by LlamaCloud.
|
||||||
|
|
||||||
|
|
@ -590,7 +624,9 @@ async def add_received_file_document_using_llamacloud(
|
||||||
existing_document = existing_doc_result.scalars().first()
|
existing_document = existing_doc_result.scalars().first()
|
||||||
|
|
||||||
if existing_document:
|
if existing_document:
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
|
)
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
|
|
@ -638,7 +674,9 @@ async def add_received_file_document_using_llamacloud(
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise RuntimeError(f"Failed to process file document using LlamaCloud: {str(e)}")
|
raise RuntimeError(
|
||||||
|
f"Failed to process file document using LlamaCloud: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
async def add_received_file_document_using_docling(
|
async def add_received_file_document_using_docling(
|
||||||
|
|
@ -647,7 +685,7 @@ async def add_received_file_document_using_docling(
|
||||||
docling_markdown_document: str,
|
docling_markdown_document: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> Optional[Document]:
|
) -> Document | None:
|
||||||
"""
|
"""
|
||||||
Process and store document content parsed by Docling.
|
Process and store document content parsed by Docling.
|
||||||
|
|
||||||
|
|
@ -673,7 +711,9 @@ async def add_received_file_document_using_docling(
|
||||||
existing_document = existing_doc_result.scalars().first()
|
existing_document = existing_doc_result.scalars().first()
|
||||||
|
|
||||||
if existing_document:
|
if existing_document:
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
|
)
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
|
|
@ -683,12 +723,11 @@ async def add_received_file_document_using_docling(
|
||||||
|
|
||||||
# Generate summary using chunked processing for large documents
|
# Generate summary using chunked processing for large documents
|
||||||
from app.services.docling_service import create_docling_service
|
from app.services.docling_service import create_docling_service
|
||||||
|
|
||||||
docling_service = create_docling_service()
|
docling_service = create_docling_service()
|
||||||
|
|
||||||
summary_content = await docling_service.process_large_document_summary(
|
summary_content = await docling_service.process_large_document_summary(
|
||||||
content=file_in_markdown,
|
content=file_in_markdown, llm=user_llm, document_title=file_name
|
||||||
llm=user_llm,
|
|
||||||
document_title=file_name
|
|
||||||
)
|
)
|
||||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||||
|
|
||||||
|
|
@ -726,7 +765,9 @@ async def add_received_file_document_using_docling(
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise RuntimeError(f"Failed to process file document using Docling: {str(e)}")
|
raise RuntimeError(
|
||||||
|
f"Failed to process file document using Docling: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
async def add_youtube_video_document(
|
async def add_youtube_video_document(
|
||||||
|
|
@ -755,7 +796,7 @@ async def add_youtube_video_document(
|
||||||
task_name="youtube_video_document",
|
task_name="youtube_video_document",
|
||||||
source="background_task",
|
source="background_task",
|
||||||
message=f"Starting YouTube video processing for: {url}",
|
message=f"Starting YouTube video processing for: {url}",
|
||||||
metadata={"url": url, "user_id": str(user_id)}
|
metadata={"url": url, "user_id": str(user_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -763,7 +804,7 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Extracting video ID from URL: {url}",
|
f"Extracting video ID from URL: {url}",
|
||||||
{"stage": "video_id_extraction"}
|
{"stage": "video_id_extraction"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_youtube_video_id(url: str):
|
def get_youtube_video_id(url: str):
|
||||||
|
|
@ -790,14 +831,14 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Video ID extracted: {video_id}",
|
f"Video ID extracted: {video_id}",
|
||||||
{"stage": "video_id_extracted", "video_id": video_id}
|
{"stage": "video_id_extracted", "video_id": video_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get video metadata
|
# Get video metadata
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Fetching video metadata for: {video_id}",
|
f"Fetching video metadata for: {video_id}",
|
||||||
{"stage": "metadata_fetch"}
|
{"stage": "metadata_fetch"},
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
|
@ -806,21 +847,27 @@ async def add_youtube_video_document(
|
||||||
}
|
}
|
||||||
oembed_url = "https://www.youtube.com/oembed"
|
oembed_url = "https://www.youtube.com/oembed"
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as http_session:
|
async with (
|
||||||
async with http_session.get(oembed_url, params=params) as response:
|
aiohttp.ClientSession() as http_session,
|
||||||
|
http_session.get(oembed_url, params=params) as response,
|
||||||
|
):
|
||||||
video_data = await response.json()
|
video_data = await response.json()
|
||||||
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Video metadata fetched: {video_data.get('title', 'Unknown')}",
|
f"Video metadata fetched: {video_data.get('title', 'Unknown')}",
|
||||||
{"stage": "metadata_fetched", "title": video_data.get('title'), "author": video_data.get('author_name')}
|
{
|
||||||
|
"stage": "metadata_fetched",
|
||||||
|
"title": video_data.get("title"),
|
||||||
|
"author": video_data.get("author_name"),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get video transcript
|
# Get video transcript
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Fetching transcript for video: {video_id}",
|
f"Fetching transcript for video: {video_id}",
|
||||||
{"stage": "transcript_fetch"}
|
{"stage": "transcript_fetch"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -838,21 +885,25 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Transcript fetched successfully: {len(captions)} segments",
|
f"Transcript fetched successfully: {len(captions)} segments",
|
||||||
{"stage": "transcript_fetched", "segments_count": len(captions), "transcript_length": len(transcript_text)}
|
{
|
||||||
|
"stage": "transcript_fetched",
|
||||||
|
"segments_count": len(captions),
|
||||||
|
"transcript_length": len(transcript_text),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
transcript_text = f"No captions available for this video. Error: {str(e)}"
|
transcript_text = f"No captions available for this video. Error: {e!s}"
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"No transcript available for video: {video_id}",
|
f"No transcript available for video: {video_id}",
|
||||||
{"stage": "transcript_unavailable", "error": str(e)}
|
{"stage": "transcript_unavailable", "error": str(e)},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format document
|
# Format document
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing video content: {video_data.get('title', 'YouTube Video')}",
|
f"Processing video content: {video_data.get('title', 'YouTube Video')}",
|
||||||
{"stage": "content_processing"}
|
{"stage": "content_processing"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format document metadata in a more maintainable way
|
# Format document metadata in a more maintainable way
|
||||||
|
|
@ -890,7 +941,7 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Checking for duplicate video content: {video_id}",
|
f"Checking for duplicate video content: {video_id}",
|
||||||
{"stage": "duplicate_check", "content_hash": content_hash}
|
{"stage": "duplicate_check", "content_hash": content_hash},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if document with this content hash already exists
|
# Check if document with this content hash already exists
|
||||||
|
|
@ -903,16 +954,22 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}",
|
f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}",
|
||||||
{"duplicate_detected": True, "existing_document_id": existing_document.id, "video_id": video_id}
|
{
|
||||||
|
"duplicate_detected": True,
|
||||||
|
"existing_document_id": existing_document.id,
|
||||||
|
"video_id": video_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||||
)
|
)
|
||||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
|
||||||
return existing_document
|
return existing_document
|
||||||
|
|
||||||
# Get LLM for summary generation
|
# Get LLM for summary generation
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}",
|
f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}",
|
||||||
{"stage": "llm_setup"}
|
{"stage": "llm_setup"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get user's long context LLM
|
# Get user's long context LLM
|
||||||
|
|
@ -924,7 +981,7 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Generating summary for video: {video_data.get('title', 'YouTube Video')}",
|
f"Generating summary for video: {video_data.get('title', 'YouTube Video')}",
|
||||||
{"stage": "summary_generation"}
|
{"stage": "summary_generation"},
|
||||||
)
|
)
|
||||||
|
|
||||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
|
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
|
||||||
|
|
@ -938,7 +995,7 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}",
|
f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}",
|
||||||
{"stage": "chunk_processing"}
|
{"stage": "chunk_processing"},
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
|
|
@ -953,7 +1010,7 @@ async def add_youtube_video_document(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}",
|
f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}",
|
||||||
{"stage": "document_creation", "chunks_count": len(chunks)}
|
{"stage": "document_creation", "chunks_count": len(chunks)},
|
||||||
)
|
)
|
||||||
|
|
||||||
document = Document(
|
document = Document(
|
||||||
|
|
@ -988,8 +1045,8 @@ async def add_youtube_video_document(
|
||||||
"content_hash": content_hash,
|
"content_hash": content_hash,
|
||||||
"chunks_count": len(chunks),
|
"chunks_count": len(chunks),
|
||||||
"summary_length": len(summary_content),
|
"summary_length": len(summary_content),
|
||||||
"has_transcript": "No captions available" not in transcript_text
|
"has_transcript": "No captions available" not in transcript_text,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
@ -999,7 +1056,10 @@ async def add_youtube_video_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Database error while processing YouTube video: {url}",
|
f"Database error while processing YouTube video: {url}",
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError", "video_id": video_id if 'video_id' in locals() else None}
|
{
|
||||||
|
"error_type": "SQLAlchemyError",
|
||||||
|
"video_id": video_id if "video_id" in locals() else None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1008,7 +1068,10 @@ async def add_youtube_video_document(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Failed to process YouTube video: {url}",
|
f"Failed to process YouTube video: {url}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__, "video_id": video_id if 'video_id' in locals() else None}
|
{
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
"video_id": video_id if "video_id" in locals() else None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
logging.error(f"Failed to process YouTube video: {str(e)}")
|
logging.error(f"Failed to process YouTube video: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import UTC, datetime, timedelta
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
from slack_sdk.errors import SlackApiError
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.connectors.discord_connector import DiscordConnector
|
from app.connectors.discord_connector import DiscordConnector
|
||||||
|
|
@ -21,10 +25,6 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.document_converters import generate_content_hash
|
from app.utils.document_converters import generate_content_hash
|
||||||
from slack_sdk.errors import SlackApiError
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -35,10 +35,10 @@ async def index_slack_messages(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
update_last_indexed: bool = True,
|
update_last_indexed: bool = True,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> tuple[int, str | None]:
|
||||||
"""
|
"""
|
||||||
Index Slack messages from all accessible channels.
|
Index Slack messages from all accessible channels.
|
||||||
|
|
||||||
|
|
@ -192,7 +192,7 @@ async def index_slack_messages(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": "ChannelFetchError"},
|
{"error_type": "ChannelFetchError"},
|
||||||
)
|
)
|
||||||
return 0, f"Failed to get Slack channels: {str(e)}"
|
return 0, f"Failed to get Slack channels: {e!s}"
|
||||||
|
|
||||||
if not channels:
|
if not channels:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
|
|
@ -400,13 +400,13 @@ async def index_slack_messages(
|
||||||
|
|
||||||
except SlackApiError as slack_error:
|
except SlackApiError as slack_error:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Slack API error for channel {channel_name}: {str(slack_error)}"
|
f"Slack API error for channel {channel_name}: {slack_error!s}"
|
||||||
)
|
)
|
||||||
skipped_channels.append(f"{channel_name} (Slack API error)")
|
skipped_channels.append(f"{channel_name} (Slack API error)")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue # Skip this channel and continue with others
|
continue # Skip this channel and continue with others
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing channel {channel_name}: {str(e)}")
|
logger.error(f"Error processing channel {channel_name}: {e!s}")
|
||||||
skipped_channels.append(f"{channel_name} (processing error)")
|
skipped_channels.append(f"{channel_name} (processing error)")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue # Skip this channel and continue with others
|
continue # Skip this channel and continue with others
|
||||||
|
|
@ -453,8 +453,8 @@ async def index_slack_messages(
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"},
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
logger.error(f"Database error: {str(db_error)}")
|
logger.error(f"Database error: {db_error!s}")
|
||||||
return 0, f"Database error: {str(db_error)}"
|
return 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -463,8 +463,8 @@ async def index_slack_messages(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__},
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to index Slack messages: {str(e)}")
|
logger.error(f"Failed to index Slack messages: {e!s}")
|
||||||
return 0, f"Failed to index Slack messages: {str(e)}"
|
return 0, f"Failed to index Slack messages: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
async def index_notion_pages(
|
async def index_notion_pages(
|
||||||
|
|
@ -472,10 +472,10 @@ async def index_notion_pages(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
update_last_indexed: bool = True,
|
update_last_indexed: bool = True,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> tuple[int, str | None]:
|
||||||
"""
|
"""
|
||||||
Index Notion pages from all accessible pages.
|
Index Notion pages from all accessible pages.
|
||||||
|
|
||||||
|
|
@ -611,8 +611,8 @@ async def index_notion_pages(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": "PageFetchError"},
|
{"error_type": "PageFetchError"},
|
||||||
)
|
)
|
||||||
logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
|
logger.error(f"Error fetching Notion pages: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to get Notion pages: {str(e)}"
|
return 0, f"Failed to get Notion pages: {e!s}"
|
||||||
|
|
||||||
if not pages:
|
if not pages:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
|
|
@ -799,7 +799,7 @@ async def index_notion_pages(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}",
|
f"Error processing Notion page {page.get('title', 'Unknown')}: {e!s}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
skipped_pages.append(
|
skipped_pages.append(
|
||||||
|
|
@ -852,9 +852,9 @@ async def index_notion_pages(
|
||||||
{"error_type": "SQLAlchemyError"},
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Database error during Notion indexing: {str(db_error)}", exc_info=True
|
f"Database error during Notion indexing: {db_error!s}", exc_info=True
|
||||||
)
|
)
|
||||||
return 0, f"Database error: {str(db_error)}"
|
return 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -863,8 +863,8 @@ async def index_notion_pages(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__},
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
|
logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to index Notion pages: {str(e)}"
|
return 0, f"Failed to index Notion pages: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
async def index_github_repos(
|
async def index_github_repos(
|
||||||
|
|
@ -872,10 +872,10 @@ async def index_github_repos(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
update_last_indexed: bool = True,
|
update_last_indexed: bool = True,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> tuple[int, str | None]:
|
||||||
"""
|
"""
|
||||||
Index code and documentation files from accessible GitHub repositories.
|
Index code and documentation files from accessible GitHub repositories.
|
||||||
|
|
||||||
|
|
@ -978,7 +978,7 @@ async def index_github_repos(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": "ClientInitializationError"},
|
{"error_type": "ClientInitializationError"},
|
||||||
)
|
)
|
||||||
return 0, f"Failed to initialize GitHub client: {str(e)}"
|
return 0, f"Failed to initialize GitHub client: {e!s}"
|
||||||
|
|
||||||
# 4. Validate selected repositories
|
# 4. Validate selected repositories
|
||||||
# For simplicity, we'll proceed with the list provided.
|
# For simplicity, we'll proceed with the list provided.
|
||||||
|
|
@ -1097,7 +1097,7 @@ async def index_github_repos(
|
||||||
"url": file_url,
|
"url": file_url,
|
||||||
"sha": file_sha,
|
"sha": file_sha,
|
||||||
"type": file_type,
|
"type": file_type,
|
||||||
"indexed_at": datetime.now(timezone.utc).isoformat(),
|
"indexed_at": datetime.now(UTC).isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create new document
|
# Create new document
|
||||||
|
|
@ -1175,10 +1175,10 @@ async def index_linear_issues(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
update_last_indexed: bool = True,
|
update_last_indexed: bool = True,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> tuple[int, str | None]:
|
||||||
"""
|
"""
|
||||||
Index Linear issues and comments.
|
Index Linear issues and comments.
|
||||||
|
|
||||||
|
|
@ -1339,8 +1339,8 @@ async def index_linear_issues(
|
||||||
logger.info(f"Retrieved {len(issues)} issues from Linear API")
|
logger.info(f"Retrieved {len(issues)} issues from Linear API")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception when calling Linear API: {str(e)}", exc_info=True)
|
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to get Linear issues: {str(e)}"
|
return 0, f"Failed to get Linear issues: {e!s}"
|
||||||
|
|
||||||
if not issues:
|
if not issues:
|
||||||
logger.info("No Linear issues found for the specified date range")
|
logger.info("No Linear issues found for the specified date range")
|
||||||
|
|
@ -1481,7 +1481,7 @@ async def index_linear_issues(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error processing issue {issue.get('identifier', 'Unknown')}: {str(e)}",
|
f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
skipped_issues.append(
|
skipped_issues.append(
|
||||||
|
|
@ -1528,8 +1528,8 @@ async def index_linear_issues(
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"},
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
logger.error(f"Database error: {str(db_error)}", exc_info=True)
|
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||||
return 0, f"Database error: {str(db_error)}"
|
return 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -1538,8 +1538,8 @@ async def index_linear_issues(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__},
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to index Linear issues: {str(e)}", exc_info=True)
|
logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to index Linear issues: {str(e)}"
|
return 0, f"Failed to index Linear issues: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
async def index_discord_messages(
|
async def index_discord_messages(
|
||||||
|
|
@ -1547,10 +1547,10 @@ async def index_discord_messages(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
update_last_indexed: bool = True,
|
update_last_indexed: bool = True,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> tuple[int, str | None]:
|
||||||
"""
|
"""
|
||||||
Index Discord messages from all accessible channels.
|
Index Discord messages from all accessible channels.
|
||||||
|
|
||||||
|
|
@ -1632,13 +1632,11 @@ async def index_discord_messages(
|
||||||
# Calculate date range
|
# Calculate date range
|
||||||
if start_date is None or end_date is None:
|
if start_date is None or end_date is None:
|
||||||
# Fall back to calculating dates based on last_indexed_at
|
# Fall back to calculating dates based on last_indexed_at
|
||||||
calculated_end_date = datetime.now(timezone.utc)
|
calculated_end_date = datetime.now(UTC)
|
||||||
|
|
||||||
# Use last_indexed_at as start date if available, otherwise use 365 days ago
|
# Use last_indexed_at as start date if available, otherwise use 365 days ago
|
||||||
if connector.last_indexed_at:
|
if connector.last_indexed_at:
|
||||||
calculated_start_date = connector.last_indexed_at.replace(
|
calculated_start_date = connector.last_indexed_at.replace(tzinfo=UTC)
|
||||||
tzinfo=timezone.utc
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
|
f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
|
||||||
)
|
)
|
||||||
|
|
@ -1655,7 +1653,7 @@ async def index_discord_messages(
|
||||||
# Convert YYYY-MM-DD to ISO format
|
# Convert YYYY-MM-DD to ISO format
|
||||||
start_date_iso = (
|
start_date_iso = (
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
.replace(tzinfo=timezone.utc)
|
.replace(tzinfo=UTC)
|
||||||
.isoformat()
|
.isoformat()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1665,20 +1663,18 @@ async def index_discord_messages(
|
||||||
# Convert YYYY-MM-DD to ISO format
|
# Convert YYYY-MM-DD to ISO format
|
||||||
end_date_iso = (
|
end_date_iso = (
|
||||||
datetime.strptime(end_date, "%Y-%m-%d")
|
datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
.replace(tzinfo=timezone.utc)
|
.replace(tzinfo=UTC)
|
||||||
.isoformat()
|
.isoformat()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Convert provided dates to ISO format for Discord API
|
# Convert provided dates to ISO format for Discord API
|
||||||
start_date_iso = (
|
start_date_iso = (
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
.replace(tzinfo=timezone.utc)
|
.replace(tzinfo=UTC)
|
||||||
.isoformat()
|
.isoformat()
|
||||||
)
|
)
|
||||||
end_date_iso = (
|
end_date_iso = (
|
||||||
datetime.strptime(end_date, "%Y-%m-%d")
|
datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat()
|
||||||
.replace(tzinfo=timezone.utc)
|
|
||||||
.isoformat()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -1710,9 +1706,9 @@ async def index_discord_messages(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": "GuildFetchError"},
|
{"error_type": "GuildFetchError"},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to get Discord guilds: {str(e)}", exc_info=True)
|
logger.error(f"Failed to get Discord guilds: {e!s}", exc_info=True)
|
||||||
await discord_client.close_bot()
|
await discord_client.close_bot()
|
||||||
return 0, f"Failed to get Discord guilds: {str(e)}"
|
return 0, f"Failed to get Discord guilds: {e!s}"
|
||||||
if not guilds:
|
if not guilds:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry,
|
log_entry,
|
||||||
|
|
@ -1754,7 +1750,7 @@ async def index_discord_messages(
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to get messages for channel {channel_name}: {str(e)}"
|
f"Failed to get messages for channel {channel_name}: {e!s}"
|
||||||
)
|
)
|
||||||
skipped_channels.append(
|
skipped_channels.append(
|
||||||
f"{guild_name}#{channel_name} (fetch error)"
|
f"{guild_name}#{channel_name} (fetch error)"
|
||||||
|
|
@ -1886,7 +1882,9 @@ async def index_discord_messages(
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
Chunk(content=raw_chunk.text, embedding=embedding)
|
Chunk(content=raw_chunk.text, embedding=embedding)
|
||||||
for raw_chunk, embedding in zip(raw_chunks, chunk_embeddings)
|
for raw_chunk, embedding in zip(
|
||||||
|
raw_chunks, chunk_embeddings, strict=False
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Create and store new document
|
# Create and store new document
|
||||||
|
|
@ -1902,7 +1900,7 @@ async def index_discord_messages(
|
||||||
"message_count": len(formatted_messages),
|
"message_count": len(formatted_messages),
|
||||||
"start_date": start_date_iso,
|
"start_date": start_date_iso,
|
||||||
"end_date": end_date_iso,
|
"end_date": end_date_iso,
|
||||||
"indexed_at": datetime.now(timezone.utc).strftime(
|
"indexed_at": datetime.now(UTC).strftime(
|
||||||
"%Y-%m-%d %H:%M:%S"
|
"%Y-%m-%d %H:%M:%S"
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|
@ -1920,14 +1918,14 @@ async def index_discord_messages(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error processing guild {guild_name}: {str(e)}", exc_info=True
|
f"Error processing guild {guild_name}: {e!s}", exc_info=True
|
||||||
)
|
)
|
||||||
skipped_channels.append(f"{guild_name} (processing error)")
|
skipped_channels.append(f"{guild_name} (processing error)")
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if update_last_indexed and documents_indexed > 0:
|
if update_last_indexed and documents_indexed > 0:
|
||||||
connector.last_indexed_at = datetime.now(timezone.utc)
|
connector.last_indexed_at = datetime.now(UTC)
|
||||||
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
|
logger.info(f"Updated last_indexed_at to {connector.last_indexed_at}")
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
@ -1968,9 +1966,9 @@ async def index_discord_messages(
|
||||||
{"error_type": "SQLAlchemyError"},
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Database error during Discord indexing: {str(db_error)}", exc_info=True
|
f"Database error during Discord indexing: {db_error!s}", exc_info=True
|
||||||
)
|
)
|
||||||
return 0, f"Database error: {str(db_error)}"
|
return 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -1979,8 +1977,8 @@ async def index_discord_messages(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__},
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to index Discord messages: {str(e)}", exc_info=True)
|
logger.error(f"Failed to index Discord messages: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to index Discord messages: {str(e)}"
|
return 0, f"Failed to index Discord messages: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
async def index_jira_issues(
|
async def index_jira_issues(
|
||||||
|
|
@ -1988,10 +1986,10 @@ async def index_jira_issues(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
start_date: str = None,
|
start_date: str | None = None,
|
||||||
end_date: str = None,
|
end_date: str | None = None,
|
||||||
update_last_indexed: bool = True,
|
update_last_indexed: bool = True,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> tuple[int, str | None]:
|
||||||
"""
|
"""
|
||||||
Index Jira issues and comments.
|
Index Jira issues and comments.
|
||||||
|
|
||||||
|
|
@ -2161,8 +2159,8 @@ async def index_jira_issues(
|
||||||
logger.info(f"Retrieved {len(issues)} issues from Jira API")
|
logger.info(f"Retrieved {len(issues)} issues from Jira API")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Jira issues: {str(e)}", exc_info=True)
|
logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
|
||||||
return 0, f"Error fetching Jira issues: {str(e)}"
|
return 0, f"Error fetching Jira issues: {e!s}"
|
||||||
|
|
||||||
# Process and index each issue
|
# Process and index each issue
|
||||||
documents_indexed = 0
|
documents_indexed = 0
|
||||||
|
|
@ -2272,7 +2270,7 @@ async def index_jira_issues(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error processing issue {issue.get('identifier', 'Unknown')}: {str(e)}",
|
f"Error processing issue {issue.get('identifier', 'Unknown')}: {e!s}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
skipped_issues.append(
|
skipped_issues.append(
|
||||||
|
|
@ -2319,8 +2317,8 @@ async def index_jira_issues(
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"},
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
logger.error(f"Database error: {str(db_error)}", exc_info=True)
|
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||||
return 0, f"Database error: {str(db_error)}"
|
return 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
|
|
@ -2329,5 +2327,5 @@ async def index_jira_issues(
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__},
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to index JIRA issues: {str(e)}", exc_info=True)
|
logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to index JIRA issues: {str(e)}"
|
return 0, f"Failed to index JIRA issues: {e!s}"
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,26 @@
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||||
from app.agents.podcaster.state import State
|
from app.agents.podcaster.state import State
|
||||||
from app.db import Chat, Podcast
|
from app.db import Chat, Podcast
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_document_podcast(
|
async def generate_document_podcast(
|
||||||
session: AsyncSession,
|
session: AsyncSession, document_id: int, search_space_id: int, user_id: int
|
||||||
document_id: int,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: int
|
|
||||||
):
|
):
|
||||||
# TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
|
# TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_chat_podcast(
|
async def generate_chat_podcast(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
podcast_title: str,
|
podcast_title: str,
|
||||||
user_id: int
|
user_id: int,
|
||||||
):
|
):
|
||||||
task_logger = TaskLoggingService(session, search_space_id)
|
task_logger = TaskLoggingService(session, search_space_id)
|
||||||
|
|
||||||
|
|
@ -37,21 +33,18 @@ async def generate_chat_podcast(
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
"podcast_title": podcast_title,
|
"podcast_title": podcast_title,
|
||||||
"user_id": str(user_id)
|
"user_id": str(user_id),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch the chat with the specified ID
|
# Fetch the chat with the specified ID
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
|
||||||
f"Fetching chat {chat_id} from database",
|
|
||||||
{"stage": "fetch_chat"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query = select(Chat).filter(
|
query = select(Chat).filter(
|
||||||
Chat.id == chat_id,
|
Chat.id == chat_id, Chat.search_space_id == search_space_id
|
||||||
Chat.search_space_id == search_space_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
|
|
@ -62,15 +55,17 @@ async def generate_chat_podcast(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Chat with id {chat_id} not found in search space {search_space_id}",
|
f"Chat with id {chat_id} not found in search space {search_space_id}",
|
||||||
"Chat not found",
|
"Chat not found",
|
||||||
{"error_type": "ChatNotFound"}
|
{"error_type": "ChatNotFound"},
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Chat with id {chat_id} not found in search space {search_space_id}"
|
||||||
)
|
)
|
||||||
raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}")
|
|
||||||
|
|
||||||
# Create chat history structure
|
# Create chat history structure
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing chat history for chat {chat_id}",
|
f"Processing chat history for chat {chat_id}",
|
||||||
{"stage": "process_chat_history", "message_count": len(chat.messages)}
|
{"stage": "process_chat_history", "message_count": len(chat.messages)},
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history_str = "<chat_history>"
|
chat_history_str = "<chat_history>"
|
||||||
|
|
@ -89,7 +84,9 @@ async def generate_chat_podcast(
|
||||||
# If content is a list, join it into a single string
|
# If content is a list, join it into a single string
|
||||||
if isinstance(answer_text, list):
|
if isinstance(answer_text, list):
|
||||||
answer_text = "\n".join(answer_text)
|
answer_text = "\n".join(answer_text)
|
||||||
chat_history_str += f"<assistant_message>{answer_text}</assistant_message>"
|
chat_history_str += (
|
||||||
|
f"<assistant_message>{answer_text}</assistant_message>"
|
||||||
|
)
|
||||||
processed_messages += 1
|
processed_messages += 1
|
||||||
|
|
||||||
chat_history_str += "</chat_history>"
|
chat_history_str += "</chat_history>"
|
||||||
|
|
@ -98,7 +95,11 @@ async def generate_chat_podcast(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Initializing podcast generation for chat {chat_id}",
|
f"Initializing podcast generation for chat {chat_id}",
|
||||||
{"stage": "initialize_podcast_generation", "processed_messages": processed_messages, "content_length": len(chat_history_str)}
|
{
|
||||||
|
"stage": "initialize_podcast_generation",
|
||||||
|
"processed_messages": processed_messages,
|
||||||
|
"content_length": len(chat_history_str),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
|
|
@ -108,16 +109,13 @@ async def generate_chat_podcast(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Initialize state with database session and streaming service
|
# Initialize state with database session and streaming service
|
||||||
initial_state = State(
|
initial_state = State(source_content=chat_history_str, db_session=session)
|
||||||
source_content=chat_history_str,
|
|
||||||
db_session=session
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the graph directly
|
# Run the graph directly
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Running podcast generation graph for chat {chat_id}",
|
f"Running podcast generation graph for chat {chat_id}",
|
||||||
{"stage": "run_podcast_graph"}
|
{"stage": "run_podcast_graph"},
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await podcaster_graph.ainvoke(initial_state, config=config)
|
result = await podcaster_graph.ainvoke(initial_state, config=config)
|
||||||
|
|
@ -126,28 +124,33 @@ async def generate_chat_podcast(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Processing podcast transcript for chat {chat_id}",
|
f"Processing podcast transcript for chat {chat_id}",
|
||||||
{"stage": "process_transcript", "transcript_entries": len(result["podcast_transcript"])}
|
{
|
||||||
|
"stage": "process_transcript",
|
||||||
|
"transcript_entries": len(result["podcast_transcript"]),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
serializable_transcript = []
|
serializable_transcript = []
|
||||||
for entry in result["podcast_transcript"]:
|
for entry in result["podcast_transcript"]:
|
||||||
serializable_transcript.append({
|
serializable_transcript.append(
|
||||||
"speaker_id": entry.speaker_id,
|
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
|
||||||
"dialog": entry.dialog
|
)
|
||||||
})
|
|
||||||
|
|
||||||
# Create a new podcast entry
|
# Create a new podcast entry
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Creating podcast database entry for chat {chat_id}",
|
f"Creating podcast database entry for chat {chat_id}",
|
||||||
{"stage": "create_podcast_entry", "file_location": result.get("final_podcast_file_path")}
|
{
|
||||||
|
"stage": "create_podcast_entry",
|
||||||
|
"file_location": result.get("final_podcast_file_path"),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
podcast = Podcast(
|
podcast = Podcast(
|
||||||
title=f"{podcast_title}",
|
title=f"{podcast_title}",
|
||||||
podcast_transcript=serializable_transcript,
|
podcast_transcript=serializable_transcript,
|
||||||
file_location=result["final_podcast_file_path"],
|
file_location=result["final_podcast_file_path"],
|
||||||
search_space_id=search_space_id
|
search_space_id=search_space_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to session and commit
|
# Add to session and commit
|
||||||
|
|
@ -165,8 +168,8 @@ async def generate_chat_podcast(
|
||||||
"transcript_entries": len(serializable_transcript),
|
"transcript_entries": len(serializable_transcript),
|
||||||
"file_location": result.get("final_podcast_file_path"),
|
"file_location": result.get("final_podcast_file_path"),
|
||||||
"processed_messages": processed_messages,
|
"processed_messages": processed_messages,
|
||||||
"content_length": len(chat_history_str)
|
"content_length": len(chat_history_str),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return podcast
|
return podcast
|
||||||
|
|
@ -178,7 +181,7 @@ async def generate_chat_podcast(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Value error during podcast generation for chat {chat_id}",
|
f"Value error during podcast generation for chat {chat_id}",
|
||||||
str(ve),
|
str(ve),
|
||||||
{"error_type": "ValueError"}
|
{"error_type": "ValueError"},
|
||||||
)
|
)
|
||||||
raise ve
|
raise ve
|
||||||
except SQLAlchemyError as db_error:
|
except SQLAlchemyError as db_error:
|
||||||
|
|
@ -187,7 +190,7 @@ async def generate_chat_podcast(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Database error during podcast generation for chat {chat_id}",
|
f"Database error during podcast generation for chat {chat_id}",
|
||||||
str(db_error),
|
str(db_error),
|
||||||
{"error_type": "SQLAlchemyError"}
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
raise db_error
|
raise db_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -196,7 +199,8 @@ async def generate_chat_podcast(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Unexpected error during podcast generation for chat {chat_id}",
|
f"Unexpected error during podcast generation for chat {chat_id}",
|
||||||
str(e),
|
str(e),
|
||||||
{"error_type": type(e).__name__}
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {str(e)}")
|
raise RuntimeError(
|
||||||
|
f"Failed to generate podcast for chat {chat_id}: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,25 @@
|
||||||
from typing import Any, AsyncGenerator, List, Union
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from app.agents.researcher.graph import graph as researcher_graph
|
|
||||||
from app.agents.researcher.state import State
|
|
||||||
from app.services.streaming_service import StreamingService
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.researcher.configuration import SearchMode
|
from app.agents.researcher.configuration import SearchMode
|
||||||
|
from app.agents.researcher.graph import graph as researcher_graph
|
||||||
|
from app.agents.researcher.state import State
|
||||||
|
from app.services.streaming_service import StreamingService
|
||||||
|
|
||||||
|
|
||||||
async def stream_connector_search_results(
|
async def stream_connector_search_results(
|
||||||
user_query: str,
|
user_query: str,
|
||||||
user_id: Union[str, UUID],
|
user_id: str | UUID,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
research_mode: str,
|
research_mode: str,
|
||||||
selected_connectors: List[str],
|
selected_connectors: list[str],
|
||||||
langchain_chat_history: List[Any],
|
langchain_chat_history: list[Any],
|
||||||
search_mode_str: str,
|
search_mode_str: str,
|
||||||
document_ids_to_add_in_context: List[int]
|
document_ids_to_add_in_context: list[int],
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Stream connector search results to the client
|
Stream connector search results to the client
|
||||||
|
|
@ -37,14 +38,14 @@ async def stream_connector_search_results(
|
||||||
streaming_service = StreamingService()
|
streaming_service = StreamingService()
|
||||||
|
|
||||||
if research_mode == "REPORT_GENERAL":
|
if research_mode == "REPORT_GENERAL":
|
||||||
NUM_SECTIONS = 1
|
num_sections = 1
|
||||||
elif research_mode == "REPORT_DEEP":
|
elif research_mode == "REPORT_DEEP":
|
||||||
NUM_SECTIONS = 3
|
num_sections = 3
|
||||||
elif research_mode == "REPORT_DEEPER":
|
elif research_mode == "REPORT_DEEPER":
|
||||||
NUM_SECTIONS = 6
|
num_sections = 6
|
||||||
else:
|
else:
|
||||||
# Default fallback
|
# Default fallback
|
||||||
NUM_SECTIONS = 1
|
num_sections = 1
|
||||||
|
|
||||||
# Convert UUID to string if needed
|
# Convert UUID to string if needed
|
||||||
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
|
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
|
||||||
|
|
@ -58,20 +59,20 @@ async def stream_connector_search_results(
|
||||||
config = {
|
config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"user_query": user_query,
|
"user_query": user_query,
|
||||||
"num_sections": NUM_SECTIONS,
|
"num_sections": num_sections,
|
||||||
"connectors_to_search": selected_connectors,
|
"connectors_to_search": selected_connectors,
|
||||||
"user_id": user_id_str,
|
"user_id": user_id_str,
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
"search_mode": search_mode,
|
"search_mode": search_mode,
|
||||||
"research_mode": research_mode,
|
"research_mode": research_mode,
|
||||||
"document_ids_to_add_in_context": document_ids_to_add_in_context
|
"document_ids_to_add_in_context": document_ids_to_add_in_context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Initialize state with database session and streaming service
|
# Initialize state with database session and streaming service
|
||||||
initial_state = State(
|
initial_state = State(
|
||||||
db_session=session,
|
db_session=session,
|
||||||
streaming_service=streaming_service,
|
streaming_service=streaming_service,
|
||||||
chat_history=langchain_chat_history
|
chat_history=langchain_chat_history,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the graph directly
|
# Run the graph directly
|
||||||
|
|
@ -83,8 +84,7 @@ async def stream_connector_search_results(
|
||||||
config=config,
|
config=config,
|
||||||
stream_mode="custom",
|
stream_mode="custom",
|
||||||
):
|
):
|
||||||
if isinstance(chunk, dict):
|
if isinstance(chunk, dict) and "yield_value" in chunk:
|
||||||
if "yield_value" in chunk:
|
|
||||||
yield chunk["yield_value"]
|
yield chunk["yield_value"]
|
||||||
|
|
||||||
yield streaming_service.format_completion()
|
yield streaming_service.format_completion()
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import Depends, Request, Response
|
from fastapi import Depends, Request, Response
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||||
from fastapi_users.authentication import (
|
from fastapi_users.authentication import (
|
||||||
AuthenticationBackend,
|
AuthenticationBackend,
|
||||||
|
|
@ -10,16 +9,18 @@ from fastapi_users.authentication import (
|
||||||
JWTStrategy,
|
JWTStrategy,
|
||||||
)
|
)
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from fastapi_users.schemas import model_dump
|
from fastapi_users.schemas import model_dump
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import User, get_user_db
|
from app.db import User, get_user_db
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class BearerResponse(BaseModel):
|
class BearerResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
SECRET = config.SECRET_KEY
|
SECRET = config.SECRET_KEY
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
|
@ -35,19 +36,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
reset_password_token_secret = SECRET
|
reset_password_token_secret = SECRET
|
||||||
verification_token_secret = SECRET
|
verification_token_secret = SECRET
|
||||||
|
|
||||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
async def on_after_register(self, user: User, request: Request | None = None):
|
||||||
print(f"User {user.id} has registered.")
|
print(f"User {user.id} has registered.")
|
||||||
|
|
||||||
async def on_after_forgot_password(
|
async def on_after_forgot_password(
|
||||||
self, user: User, token: str, request: Optional[Request] = None
|
self, user: User, token: str, request: Request | None = None
|
||||||
):
|
):
|
||||||
print(f"User {user.id} has forgot their password. Reset token: {token}")
|
print(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||||
|
|
||||||
async def on_after_request_verify(
|
async def on_after_request_verify(
|
||||||
self, user: User, token: str, request: Optional[Request] = None
|
self, user: User, token: str, request: Request | None = None
|
||||||
):
|
):
|
||||||
print(
|
print(f"Verification requested for user {user.id}. Verification token: {token}")
|
||||||
f"Verification requested for user {user.id}. Verification token: {token}")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
||||||
|
|
@ -77,6 +77,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||||
# get_strategy=get_jwt_strategy,
|
# get_strategy=get_jwt_strategy,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
# BEARER AUTH CODE.
|
# BEARER AUTH CODE.
|
||||||
class CustomBearerTransport(BearerTransport):
|
class CustomBearerTransport(BearerTransport):
|
||||||
async def get_login_response(self, token: str) -> Response:
|
async def get_login_response(self, token: str) -> Response:
|
||||||
|
|
@ -87,6 +88,7 @@ class CustomBearerTransport(BearerTransport):
|
||||||
else:
|
else:
|
||||||
return JSONResponse(model_dump(bearer_response))
|
return JSONResponse(model_dump(bearer_response))
|
||||||
|
|
||||||
|
|
||||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,19 @@
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.db import User
|
from app.db import User
|
||||||
|
|
||||||
|
|
||||||
# Helper function to check user ownership
|
# Helper function to check user ownership
|
||||||
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
|
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
|
||||||
item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
|
item = await session.execute(
|
||||||
|
select(model).filter(model.id == item_id, model.user_id == user.id)
|
||||||
|
)
|
||||||
item = item.scalars().first()
|
item = item.scalars().first()
|
||||||
if not item:
|
if not item:
|
||||||
raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="Item not found or you don't have permission to access it",
|
||||||
|
)
|
||||||
return item
|
return item
|
||||||
|
|
@ -32,7 +32,7 @@ async def convert_element_to_markdown(element) -> str:
|
||||||
"Footer": lambda x: f"*{x}*\n\n",
|
"Footer": lambda x: f"*{x}*\n\n",
|
||||||
"CodeSnippet": lambda x: f"```\n{x}\n```",
|
"CodeSnippet": lambda x: f"```\n{x}\n```",
|
||||||
"PageNumber": lambda x: f"*Page {x}*\n\n",
|
"PageNumber": lambda x: f"*Page {x}*\n\n",
|
||||||
"UncategorizedText": lambda x: f"{x}\n\n"
|
"UncategorizedText": lambda x: f"{x}\n\n",
|
||||||
}
|
}
|
||||||
|
|
||||||
converter = markdown_mapping.get(element_category, lambda x: x)
|
converter = markdown_mapping.get(element_category, lambda x: x)
|
||||||
|
|
@ -74,7 +74,7 @@ def convert_chunks_to_langchain_documents(chunks):
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"LangChain is not installed. Please install it with `pip install langchain langchain-core`"
|
"LangChain is not installed. Please install it with `pip install langchain langchain-core`"
|
||||||
)
|
) from None
|
||||||
|
|
||||||
langchain_docs = []
|
langchain_docs = []
|
||||||
|
|
||||||
|
|
@ -92,17 +92,20 @@ def convert_chunks_to_langchain_documents(chunks):
|
||||||
# Add document information to metadata
|
# Add document information to metadata
|
||||||
if "document" in chunk:
|
if "document" in chunk:
|
||||||
doc = chunk["document"]
|
doc = chunk["document"]
|
||||||
metadata.update({
|
metadata.update(
|
||||||
|
{
|
||||||
"document_id": doc.get("id"),
|
"document_id": doc.get("id"),
|
||||||
"document_title": doc.get("title"),
|
"document_title": doc.get("title"),
|
||||||
"document_type": doc.get("document_type"),
|
"document_type": doc.get("document_type"),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Add document metadata if available
|
# Add document metadata if available
|
||||||
if "metadata" in doc:
|
if "metadata" in doc:
|
||||||
# Prefix document metadata keys to avoid conflicts
|
# Prefix document metadata keys to avoid conflicts
|
||||||
doc_metadata = {f"doc_meta_{k}": v for k,
|
doc_metadata = {
|
||||||
v in doc.get("metadata", {}).items()}
|
f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()
|
||||||
|
}
|
||||||
metadata.update(doc_metadata)
|
metadata.update(doc_metadata)
|
||||||
|
|
||||||
# Add source URL if available in metadata
|
# Add source URL if available in metadata
|
||||||
|
|
@ -131,10 +134,7 @@ def convert_chunks_to_langchain_documents(chunks):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create LangChain Document
|
# Create LangChain Document
|
||||||
langchain_doc = LangChainDocument(
|
langchain_doc = LangChainDocument(page_content=new_content, metadata=metadata)
|
||||||
page_content=new_content,
|
|
||||||
metadata=metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
langchain_docs.append(langchain_doc)
|
langchain_docs.append(langchain_doc)
|
||||||
|
|
||||||
|
|
@ -144,4 +144,4 @@ def convert_chunks_to_langchain_documents(chunks):
|
||||||
def generate_content_hash(content: str, search_space_id: int) -> str:
|
def generate_content_hash(content: str, search_space_id: int) -> str:
|
||||||
"""Generate SHA-256 hash for the given content combined with search space ID."""
|
"""Generate SHA-256 hash for the given content combined with search space ID."""
|
||||||
combined_data = f"{search_space_id}:{content}"
|
combined_data = f"{search_space_id}:{content}"
|
||||||
return hashlib.sha256(combined_data.encode('utf-8')).hexdigest()
|
return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import uvicorn
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from app.config.uvicorn import load_uvicorn_config
|
from app.config.uvicorn import load_uvicorn_config
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|
|
||||||
|
|
@ -36,3 +36,97 @@ dependencies = [
|
||||||
"validators>=0.34.0",
|
"validators>=0.34.0",
|
||||||
"youtube-transcript-api>=1.0.3",
|
"youtube-transcript-api>=1.0.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"ruff>=0.12.5",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
# Exclude a variety of commonly ignored directories.
|
||||||
|
exclude = [
|
||||||
|
".bzr",
|
||||||
|
".direnv",
|
||||||
|
".eggs",
|
||||||
|
".git",
|
||||||
|
".git-rewrite",
|
||||||
|
".hg",
|
||||||
|
".ipynb_checkpoints",
|
||||||
|
".mypy_cache",
|
||||||
|
".nox",
|
||||||
|
".pants.d",
|
||||||
|
".pyenv",
|
||||||
|
".pytest_cache",
|
||||||
|
".pytype",
|
||||||
|
".ruff_cache",
|
||||||
|
".svn",
|
||||||
|
".tox",
|
||||||
|
".venv",
|
||||||
|
".vscode",
|
||||||
|
"__pypackages__",
|
||||||
|
"_build",
|
||||||
|
"buck-out",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"node_modules",
|
||||||
|
"site-packages",
|
||||||
|
"venv",
|
||||||
|
]
|
||||||
|
|
||||||
|
line-length = 88
|
||||||
|
indent-width = 4
|
||||||
|
|
||||||
|
# Python 3.12
|
||||||
|
target-version = "py312"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E4", # pycodestyle errors
|
||||||
|
"E7", # pycodestyle errors
|
||||||
|
"E9", # pycodestyle errors
|
||||||
|
"F", # Pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"N", # pep8-naming
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"T20", # flake8-print
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"RUF", # Ruff-specific rules
|
||||||
|
]
|
||||||
|
|
||||||
|
ignore = [
|
||||||
|
"E501", # Line too long (handled by formatter)
|
||||||
|
"B008", # Do not perform function calls in argument defaults
|
||||||
|
"T201", # Print found (allow print statements)
|
||||||
|
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar`
|
||||||
|
]
|
||||||
|
|
||||||
|
extend-select = ["I"]
|
||||||
|
|
||||||
|
# Allow fix for all enabled rules (when `--fix`) is provided.
|
||||||
|
fixable = ["ALL"]
|
||||||
|
unfixable = []
|
||||||
|
|
||||||
|
# Allow unused variables when underscore-prefixed.
|
||||||
|
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
# Use double quotes for strings.
|
||||||
|
quote-style = "double"
|
||||||
|
|
||||||
|
# Indent with spaces, rather than tabs.
|
||||||
|
indent-style = "space"
|
||||||
|
|
||||||
|
# Respect magic trailing commas.
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
|
||||||
|
# Automatically detect the appropriate line ending.
|
||||||
|
line-ending = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
# Group imports by type
|
||||||
|
known-first-party = ["app"]
|
||||||
|
force-single-line = false
|
||||||
|
combine-as-imports = true
|
||||||
|
|
|
||||||
4507
surfsense_backend/uv.lock
generated
4507
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue