mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-08 20:25:19 +02:00
Merge main to Llamaindex-chatui
This commit is contained in:
commit
f006a76587
104 changed files with 12412 additions and 7680 deletions
224
.github/workflows/code-quality.yml
vendored
Normal file
224
.github/workflows/code-quality.yml
vendored
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
name: Code Quality Checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main, dev]
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
file-quality:
|
||||
name: File Quality Checks
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Fetch base branch
|
||||
run: |
|
||||
# Ensure we have the base branch reference for comparison
|
||||
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit
|
||||
|
||||
- name: Cache pre-commit hooks
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
pre-commit-
|
||||
|
||||
- name: Install hook environments (cache)
|
||||
run: pre-commit install-hooks
|
||||
|
||||
- name: Run file quality checks on changed files
|
||||
run: |
|
||||
# Get list of changed files and run specific hooks on them
|
||||
if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
|
||||
BASE_REF="${{ github.base_ref }}"
|
||||
elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
|
||||
BASE_REF="origin/${{ github.base_ref }}"
|
||||
else
|
||||
echo "Base branch reference not found, running file quality hooks on all files"
|
||||
pre-commit run --all-files check-yaml check-json check-toml check-merge-conflict check-added-large-files debug-statements check-case-conflict
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Running file quality hooks on changed files against $BASE_REF"
|
||||
|
||||
# Run each hook individually on changed files
|
||||
SKIP=detect-secrets,bandit,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
|
||||
pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
|
||||
|
||||
# Exit with the same code as pre-commit
|
||||
exit ${exit_code:-0}
|
||||
|
||||
security-scan:
|
||||
name: Security Scan
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Fetch base branch
|
||||
run: |
|
||||
git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} 2>/dev/null || git fetch origin ${{ github.base_ref }} 2>/dev/null || true
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit
|
||||
|
||||
- name: Cache pre-commit hooks
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-security-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
pre-commit-security-
|
||||
|
||||
- name: Install hook environments (cache)
|
||||
run: pre-commit install-hooks
|
||||
|
||||
- name: Run security scans on changed files
|
||||
run: |
|
||||
# Get base ref for comparison
|
||||
if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
|
||||
BASE_REF="${{ github.base_ref }}"
|
||||
elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
|
||||
BASE_REF="origin/${{ github.base_ref }}"
|
||||
else
|
||||
echo "Base branch reference not found, running security scans on all files"
|
||||
echo "⚠️ This may take longer than normal"
|
||||
pre-commit run --all-files detect-secrets bandit
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Running security scans on changed files against $BASE_REF"
|
||||
|
||||
# Run only security hooks on changed files
|
||||
SKIP=check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,ruff,ruff-format,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
|
||||
pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
|
||||
|
||||
# Exit with the same code as pre-commit
|
||||
exit ${exit_code:-0}
|
||||
|
||||
python-backend:
|
||||
name: Python Backend Quality
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV
|
||||
uses: astral-sh/setup-uv@v3
|
||||
|
||||
- name: Check if backend files changed
|
||||
id: backend-changes
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
backend:
|
||||
- 'surfsense_backend/**'
|
||||
|
||||
- name: Cache dependencies
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
surfsense_backend/.venv
|
||||
key: python-deps-${{ hashFiles('surfsense_backend/uv.lock') }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
working-directory: surfsense_backend
|
||||
run: uv sync
|
||||
|
||||
- name: Install pre-commit for backend checks
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
run: pip install pre-commit
|
||||
|
||||
- name: Cache pre-commit hooks
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-backend-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
pre-commit-backend-
|
||||
|
||||
- name: Install hook environments (cache)
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
run: pre-commit install-hooks
|
||||
|
||||
- name: Run Python backend quality checks
|
||||
if: steps.backend-changes.outputs.backend == 'true'
|
||||
run: |
|
||||
# Get base ref for comparison
|
||||
if git show-ref --verify --quiet refs/heads/${{ github.base_ref }}; then
|
||||
BASE_REF="${{ github.base_ref }}"
|
||||
elif git show-ref --verify --quiet refs/remotes/origin/${{ github.base_ref }}; then
|
||||
BASE_REF="origin/${{ github.base_ref }}"
|
||||
else
|
||||
echo "Base branch reference not found, running Python backend checks on all files"
|
||||
pre-commit run --all-files ruff ruff-format
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Running Python backend checks on changed files against $BASE_REF"
|
||||
|
||||
# Run only ruff hooks on changed Python files
|
||||
SKIP=detect-secrets,bandit,check-yaml,check-json,check-toml,check-merge-conflict,check-added-large-files,debug-statements,check-case-conflict,prettier,eslint,typescript-check-web,typescript-check-extension,commitizen \
|
||||
pre-commit run --from-ref $BASE_REF --to-ref HEAD || exit_code=$?
|
||||
|
||||
# Exit with the same code as pre-commit
|
||||
exit ${exit_code:-0}
|
||||
|
||||
quality-gate:
|
||||
name: Quality Gate
|
||||
runs-on: ubuntu-latest
|
||||
needs: [file-quality, security-scan, python-backend]
|
||||
if: always()
|
||||
|
||||
steps:
|
||||
- name: Check all jobs status
|
||||
run: |
|
||||
if [[ "${{ needs.file-quality.result }}" == "failure" ||
|
||||
"${{ needs.security-scan.result }}" == "failure" ||
|
||||
"${{ needs.python-backend.result }}" == "failure" ]]; then
|
||||
echo "❌ Code quality checks failed"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All code quality checks passed"
|
||||
fi
|
||||
59
.github/workflows/pre-commit.yml
vendored
Normal file
59
.github/workflows/pre-commit.yml
vendored
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
name: pre-commit
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
branches: [main, dev]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Required for detecting diffs
|
||||
|
||||
- name: Fetch main branch
|
||||
run: |
|
||||
# Ensure we have the main branch reference for comparison
|
||||
git fetch origin main:main 2>/dev/null || git fetch origin main 2>/dev/null || true
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Cache pre-commit environments
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
pre-commit-
|
||||
|
||||
- name: Install pre-commit
|
||||
run: |
|
||||
pip install pre-commit
|
||||
|
||||
- name: Install hook environments (cache)
|
||||
run: |
|
||||
pre-commit install-hooks
|
||||
|
||||
- name: Run pre-commit on changed files
|
||||
run: |
|
||||
# Use pre-commit's native diff detection with fallback strategies
|
||||
if git show-ref --verify --quiet refs/heads/main; then
|
||||
# Main branch exists locally, use pre-commit's native diff mode
|
||||
echo "Running pre-commit with native diff detection against main branch"
|
||||
pre-commit run --from-ref main --to-ref HEAD
|
||||
elif git show-ref --verify --quiet refs/remotes/origin/main; then
|
||||
# Origin/main exists, use it as reference
|
||||
echo "Running pre-commit with native diff detection against origin/main"
|
||||
pre-commit run --from-ref origin/main --to-ref HEAD
|
||||
else
|
||||
# Fallback: run on all files (for first commits or when main is unavailable)
|
||||
echo "Main branch reference not found, running pre-commit on all files"
|
||||
echo "⚠️ This may take longer and show more issues than normal"
|
||||
pre-commit run --all-files
|
||||
fi
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,3 +1,5 @@
|
|||
.flashrank_cache*
|
||||
podcasts/
|
||||
.env
|
||||
|
||||
.ruff_cache/
|
||||
112
.pre-commit-config.yaml
Normal file
112
.pre-commit-config.yaml
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
# Pre-commit configuration for SurfSense
|
||||
# See https://pre-commit.com for more information
|
||||
|
||||
repos:
|
||||
# General file quality hooks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
args: [--multi, --unsafe]
|
||||
- id: check-json
|
||||
exclude: '(tsconfig\.json|\.vscode/.*\.json)$'
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=10240] # 10MB limit
|
||||
- id: debug-statements
|
||||
- id: check-case-conflict
|
||||
|
||||
# Security - detect secrets across all file types
|
||||
- repo: https://github.com/Yelp/detect-secrets
|
||||
rev: v1.5.0
|
||||
hooks:
|
||||
- id: detect-secrets
|
||||
args: ['--baseline', '.secrets.baseline']
|
||||
exclude: |
|
||||
(?x)^(
|
||||
.*\.env\.example|
|
||||
.*\.env\.template|
|
||||
.*/tests/.*|
|
||||
.*test.*\.py|
|
||||
test_.*\.py|
|
||||
.github/workflows/.*\.yml|
|
||||
.github/workflows/.*\.yaml|
|
||||
.*pnpm-lock\.yaml|
|
||||
.*alembic\.ini|
|
||||
.*alembic/versions/.*\.py|
|
||||
.*\.mdx$
|
||||
)$
|
||||
|
||||
# Python Backend Hooks (surfsense_backend) - Using Ruff for linting and formatting
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: ruff-check
|
||||
files: ^surfsense_backend/
|
||||
exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
name: ruff-format
|
||||
files: ^surfsense_backend/
|
||||
exclude: ^surfsense_backend/(test_.*\.py|.*test.*\.py)
|
||||
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.6
|
||||
hooks:
|
||||
- id: bandit
|
||||
files: ^surfsense_backend/
|
||||
args: ['-f', 'json', '--severity-level', 'high', '--confidence-level', 'high']
|
||||
exclude: ^surfsense_backend/(tests/|test_.*\.py|.*test.*\.py|alembic/)
|
||||
|
||||
# Frontend/Extension Hooks (TypeScript/JavaScript)
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v4.0.0-alpha.8
|
||||
hooks:
|
||||
- id: prettier
|
||||
files: ^(surfsense_web|surfsense_browser_extension)/
|
||||
types_or: [javascript, jsx, ts, tsx, json, yaml, markdown]
|
||||
exclude: '(package-lock\.json|\.next/|build/|dist/)'
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-eslint
|
||||
rev: v9.31.0
|
||||
hooks:
|
||||
- id: eslint
|
||||
files: ^surfsense_web/
|
||||
types: [file]
|
||||
types_or: [javascript, jsx, ts, tsx]
|
||||
additional_dependencies:
|
||||
- 'eslint@^9'
|
||||
- 'eslint-config-next@15.2.0'
|
||||
- '@eslint/eslintrc@^3'
|
||||
args: [--fix]
|
||||
exclude: '(\.next/|build/|dist/)'
|
||||
|
||||
# TypeScript compilation check
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: typescript-check-web
|
||||
name: TypeScript Check (Web)
|
||||
entry: bash -c 'cd surfsense_web && (command -v pnpm >/dev/null 2>&1 && pnpm build --dry-run || npx next build --dry-run)'
|
||||
language: system
|
||||
files: ^surfsense_web/.*\.(ts|tsx)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: typescript-check-extension
|
||||
name: TypeScript Check (Browser Extension)
|
||||
entry: bash -c 'cd surfsense_browser_extension && npx tsc --noEmit'
|
||||
language: system
|
||||
files: ^surfsense_browser_extension/.*\.(ts|tsx)$
|
||||
pass_filenames: false
|
||||
|
||||
# Commit message linting
|
||||
- repo: https://github.com/commitizen-tools/commitizen
|
||||
rev: v4.8.3
|
||||
hooks:
|
||||
- id: commitizen
|
||||
stages: [commit-msg]
|
||||
|
||||
# Global configuration
|
||||
default_stages: [pre-commit]
|
||||
fail_fast: false
|
||||
115
.secrets.baseline
Normal file
115
.secrets.baseline
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
{
|
||||
"version": "1.5.0",
|
||||
"plugins_used": [
|
||||
{
|
||||
"name": "ArtifactoryDetector"
|
||||
},
|
||||
{
|
||||
"name": "AWSKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "AzureStorageKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "Base64HighEntropyString",
|
||||
"limit": 4.5
|
||||
},
|
||||
{
|
||||
"name": "BasicAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "CloudantDetector"
|
||||
},
|
||||
{
|
||||
"name": "DiscordBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitHubTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "HexHighEntropyString",
|
||||
"limit": 3.0
|
||||
},
|
||||
{
|
||||
"name": "IbmCloudIamDetector"
|
||||
},
|
||||
{
|
||||
"name": "IbmCosHmacDetector"
|
||||
},
|
||||
{
|
||||
"name": "JwtTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "KeywordDetector",
|
||||
"keyword_exclude": ""
|
||||
},
|
||||
{
|
||||
"name": "MailchimpDetector"
|
||||
},
|
||||
{
|
||||
"name": "NpmDetector"
|
||||
},
|
||||
{
|
||||
"name": "PrivateKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "SendGridDetector"
|
||||
},
|
||||
{
|
||||
"name": "SlackDetector"
|
||||
},
|
||||
{
|
||||
"name": "SoftlayerDetector"
|
||||
},
|
||||
{
|
||||
"name": "SquareOAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "StripeDetector"
|
||||
},
|
||||
{
|
||||
"name": "TwilioKeyDetector"
|
||||
}
|
||||
],
|
||||
"filters_used": [
|
||||
{
|
||||
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_baseline_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
|
||||
"min_level": 2
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_lock_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_sequential_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_swagger_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_templated_secret"
|
||||
}
|
||||
],
|
||||
"results": {},
|
||||
"generated_at": "2025-01-20T12:00:00Z"
|
||||
}
|
||||
|
|
@ -76,6 +76,13 @@ SurfSense consists of three main components:
|
|||
|
||||
## 🧪 Development Guidelines
|
||||
|
||||
### Code Quality & Pre-commit Hooks
|
||||
We use pre-commit hooks to maintain code quality, security, and consistency across the codebase. Before you start developing:
|
||||
|
||||
1. **Install and set up pre-commit hooks** - See our detailed [Pre-commit Guide](./PRE_COMMIT.md)
|
||||
2. **Understand the automated checks** that will run on your code
|
||||
3. **Learn about bypassing hooks** when necessary (use sparingly!)
|
||||
|
||||
### Code Style
|
||||
- **Backend**: Follow Python PEP 8 style guidelines
|
||||
- **Frontend**: Use TypeScript and follow the existing code patterns
|
||||
|
|
|
|||
237
PRE_COMMIT.md
Normal file
237
PRE_COMMIT.md
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
# Pre-commit Hooks for SurfSense Contributors
|
||||
|
||||
Welcome to SurfSense! As an open-source project, we use pre-commit hooks to maintain code quality, security, and consistency across our multi-component codebase. This guide will help you set up and work with our pre-commit configuration.
|
||||
|
||||
## 🚀 What is Pre-commit?
|
||||
|
||||
Pre-commit is a framework for managing multi-language pre-commit hooks. It runs automatically before each commit to catch issues early, ensuring high code quality and consistency across the project.
|
||||
|
||||
## 📁 Project Structure
|
||||
|
||||
SurfSense consists of three main components:
|
||||
- **`surfsense_backend/`** - Python backend API
|
||||
- **`surfsense_web/`** - Next.js web application
|
||||
- **`surfsense_browser_extension/`** - TypeScript browser extension
|
||||
|
||||
## 🛠 Installation
|
||||
|
||||
### Prerequisites
|
||||
- Python 3.8 or higher
|
||||
- Node.js 18+ and pnpm (for frontend components)
|
||||
- Git
|
||||
|
||||
### Install Pre-commit
|
||||
|
||||
```bash
|
||||
# Install pre-commit globally
|
||||
pip install pre-commit
|
||||
|
||||
# Or using your preferred package manager
|
||||
# pipx install pre-commit # Recommended for isolation
|
||||
```
|
||||
|
||||
### Setup Pre-commit Hooks
|
||||
|
||||
1. **Clone the repository**:
|
||||
```bash
|
||||
git clone https://github.com/masabinhok/SurfSense.git
|
||||
cd SurfSense
|
||||
```
|
||||
|
||||
2. **Install the pre-commit hooks**:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
3. **Install commit message hooks** (optional, for conventional commits):
|
||||
```bash
|
||||
pre-commit install --hook-type commit-msg
|
||||
```
|
||||
|
||||
## 🔧 Configuration Files Added
|
||||
|
||||
When you install pre-commit, the following files are part of the setup:
|
||||
|
||||
- **`.pre-commit-config.yaml`** - Main pre-commit configuration
|
||||
- **`.secrets.baseline`** - Baseline file for secret detection (prevents false positives)
|
||||
- **`.github/workflows/pre-commit.yml`** - CI workflow that runs pre-commit on PRs
|
||||
|
||||
## 🎯 What Gets Checked
|
||||
|
||||
### All Files
|
||||
- ✅ Trailing whitespace removal
|
||||
- ✅ YAML, JSON, and TOML validation
|
||||
- ✅ Large file detection (>10MB)
|
||||
- ✅ Merge conflict markers
|
||||
- 🔒 **Secret detection** using detect-secrets
|
||||
|
||||
### Python Backend (`surfsense_backend/`)
|
||||
- 🐍 **Black** - Code formatting
|
||||
- 📦 **isort** - Import sorting
|
||||
- ⚡ **Ruff** - Fast linting and formatting
|
||||
- 🔍 **MyPy** - Static type checking
|
||||
- 🛡️ **Bandit** - Security vulnerability scanning
|
||||
|
||||
### Frontend (`surfsense_web/` & `surfsense_browser_extension/`)
|
||||
- 💅 **Prettier** - Code formatting
|
||||
- 🔍 **ESLint** - Linting (Next.js config)
|
||||
- 📝 **TypeScript** - Compilation checks
|
||||
|
||||
### Commit Messages
|
||||
- 📝 **Commitizen** - Conventional commit format validation
|
||||
|
||||
## 🚀 Usage
|
||||
|
||||
### Normal Workflow
|
||||
Pre-commit will run automatically when you commit:
|
||||
|
||||
```bash
|
||||
git add .
|
||||
git commit -m "feat: add new feature"
|
||||
# Pre-commit hooks will run automatically
|
||||
```
|
||||
|
||||
### Manual Execution
|
||||
|
||||
Run on staged files only:
|
||||
```bash
|
||||
pre-commit run
|
||||
```
|
||||
|
||||
Run on specific files:
|
||||
```bash
|
||||
pre-commit run --files path/to/file.py path/to/file.ts
|
||||
```
|
||||
|
||||
Run all hooks on all files:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
⚠️ **Warning**: Running `--all-files` may generate numerous errors as this codebase has existing linting and type issues that are being gradually resolved.
|
||||
|
||||
### Advanced Commands
|
||||
|
||||
Update all hooks to latest versions:
|
||||
```bash
|
||||
pre-commit autoupdate
|
||||
```
|
||||
|
||||
Run only specific hooks:
|
||||
```bash
|
||||
pre-commit run black # Run only black
|
||||
pre-commit run --all-files prettier # Run prettier on all files
|
||||
```
|
||||
|
||||
Clean pre-commit cache:
|
||||
```bash
|
||||
pre-commit clean
|
||||
```
|
||||
|
||||
## 🆘 Bypassing Pre-commit (When Necessary)
|
||||
|
||||
Sometimes you might need to bypass pre-commit hooks (use sparingly!):
|
||||
|
||||
### Skip all hooks for one commit:
|
||||
```bash
|
||||
git commit -m "fix: urgent hotfix" --no-verify
|
||||
```
|
||||
|
||||
### Skip specific hooks:
|
||||
```bash
|
||||
SKIP=mypy,black git commit -m "feat: work in progress"
|
||||
```
|
||||
|
||||
Available hook IDs to skip:
|
||||
- `trailing-whitespace`, `check-yaml`, `check-json`
|
||||
- `detect-secrets`
|
||||
- `black`, `isort`, `ruff`, `ruff-format`, `mypy`, `bandit`
|
||||
- `prettier`, `eslint`
|
||||
- `typescript-check-web`, `typescript-check-extension`
|
||||
- `commitizen`
|
||||
|
||||
## 🐛 Common Issues & Solutions
|
||||
|
||||
### Secret Detection False Positives
|
||||
|
||||
If detect-secrets flags legitimate content as secrets:
|
||||
|
||||
1. **Review the detection** - Ensure it's not actually a secret
|
||||
2. **Update baseline**:
|
||||
```bash
|
||||
detect-secrets scan --baseline .secrets.baseline --update
|
||||
git add .secrets.baseline
|
||||
```
|
||||
|
||||
### TypeScript/Node.js Issues
|
||||
|
||||
Ensure dependencies are installed:
|
||||
```bash
|
||||
cd surfsense_web && pnpm install
|
||||
cd surfsense_browser_extension && pnpm install
|
||||
```
|
||||
|
||||
### Python Environment Issues
|
||||
|
||||
For Python hooks, ensure you're in the correct environment:
|
||||
```bash
|
||||
cd surfsense_backend
|
||||
# If using uv
|
||||
uv sync
|
||||
# Or traditional pip
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Hook Installation Issues
|
||||
|
||||
If hooks aren't running:
|
||||
```bash
|
||||
pre-commit uninstall
|
||||
pre-commit install --install-hooks
|
||||
```
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
- **Incremental runs**: Pre-commit only runs on changed files by default
|
||||
- **Parallel execution**: Many hooks run in parallel for speed
|
||||
- **Caching**: Pre-commit caches environments to speed up subsequent runs
|
||||
|
||||
## 🔄 CI Integration
|
||||
|
||||
Pre-commit also runs in our GitHub Actions CI pipeline on every PR to `main`. The CI:
|
||||
- Runs only on changed files for efficiency
|
||||
- Provides the same feedback as local pre-commit
|
||||
- Prevents merging code that doesn't pass quality checks
|
||||
|
||||
## 📋 Best Practices
|
||||
|
||||
1. **Install pre-commit early** in your development setup
|
||||
2. **Fix issues incrementally** rather than bypassing hooks
|
||||
3. **Update your branch regularly** to avoid conflicts with formatting changes
|
||||
4. **Run `--all-files` periodically** on feature branches (in small chunks)
|
||||
5. **Keep the `.secrets.baseline` updated** when legitimate secrets-like strings are added
|
||||
|
||||
## 💡 Contributing to Pre-commit Config
|
||||
|
||||
To modify the pre-commit configuration:
|
||||
|
||||
1. Edit `.pre-commit-config.yaml`
|
||||
2. Test your changes:
|
||||
```bash
|
||||
pre-commit run --all-files # Test with caution!
|
||||
```
|
||||
3. Update the baseline if needed:
|
||||
```bash
|
||||
detect-secrets scan --baseline .secrets.baseline --update
|
||||
```
|
||||
4. Submit a PR with your changes
|
||||
|
||||
## 🆘 Getting Help
|
||||
|
||||
- **Pre-commit docs**: https://pre-commit.com/
|
||||
- **Project issues**: Open an issue on GitHub
|
||||
- **Hook-specific help**: Check individual tool documentation (Black, Ruff, ESLint, etc.)
|
||||
|
||||
---
|
||||
|
||||
Thank you for contributing to SurfSense! 🏄♀️ Quality code makes everyone's surfing experience smoother.
|
||||
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
|
||||
# SurfSense
|
||||
While tools like NotebookLM and Perplexity are impressive and highly effective for conducting research on any topic/query, SurfSense elevates this capability by integrating with your personal knowledge base. It is a highly customizable AI research agent, connected to external sources such as search engines (Tavily, LinkUp), Slack, Linear, Notion, YouTube, GitHub, Discord and more to come.
|
||||
While tools like NotebookLM and Perplexity are impressive and highly effective for conducting research on any topic/query, SurfSense elevates this capability by integrating with your personal knowledge base. It is a highly customizable AI research agent, connected to external sources such as search engines (Tavily, LinkUp), Slack, Linear, Jira, Notion, YouTube, GitHub, Discord and more to come.
|
||||
|
||||
<div align="center">
|
||||
<a href="https://trendshift.io/repositories/13606" target="_blank"><img src="https://trendshift.io/api/badge/repositories/13606" alt="MODSetter%2FSurfSense | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
|
@ -63,6 +63,7 @@ Open source and easy to deploy locally.
|
|||
- Search Engines (Tavily, LinkUp)
|
||||
- Slack
|
||||
- Linear
|
||||
- Jira
|
||||
- Notion
|
||||
- Youtube Videos
|
||||
- GitHub
|
||||
|
|
|
|||
1
node_modules/.cache/prettier/.prettier-caches/a2ecb2962bf19c1099cfe708e42daa0097f94976.json
generated
vendored
Normal file
1
node_modules/.cache/prettier/.prettier-caches/a2ecb2962bf19c1099cfe708e42daa0097f94976.json
generated
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
{"2d0ec64d93969318101ee479b664221b32241665":{"files":{"surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx":["EHKKvlOK0vfy0GgHwlG/J2Bx5rw=",true]},"modified":1753426633288}}
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
|
@ -11,10 +11,10 @@ from alembic import context
|
|||
|
||||
# Ensure the app directory is in the Python path
|
||||
# This allows Alembic to find your models
|
||||
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
# Import your models base
|
||||
from app.db import Base # Assuming your Base is defined in app.db
|
||||
from app.db import Base # Assuming your Base is defined in app.db
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
|
|||
|
|
@ -4,17 +4,15 @@ Revision ID: 10
|
|||
Revises: 9
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "10"
|
||||
down_revision: Union[str, None] = "9"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "9"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Define the ENUM type name
|
||||
CHAT_TYPE_ENUM = "chattype"
|
||||
|
|
@ -22,87 +20,101 @@ CHAT_TYPE_ENUM = "chattype"
|
|||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - replace ChatType enum values with new QNA/REPORT structure."""
|
||||
|
||||
|
||||
# Old enum name for temporary storage
|
||||
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
||||
|
||||
|
||||
# New enum values
|
||||
new_values = (
|
||||
"QNA",
|
||||
"REPORT_GENERAL",
|
||||
"REPORT_DEEP",
|
||||
"REPORT_DEEPER"
|
||||
)
|
||||
new_values = ("QNA", "REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER")
|
||||
new_values_sql = ", ".join([f"'{v}'" for v in new_values])
|
||||
|
||||
|
||||
# Table and column info
|
||||
table_name = "chats"
|
||||
column_name = "type"
|
||||
|
||||
|
||||
# Step 1: Rename the current enum type
|
||||
op.execute(f"ALTER TYPE {CHAT_TYPE_ENUM} RENAME TO {old_enum_name}")
|
||||
|
||||
|
||||
# Step 2: Create the new enum type with new values
|
||||
op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({new_values_sql})")
|
||||
|
||||
|
||||
# Step 3: Add a temporary column with the new type
|
||||
op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
|
||||
|
||||
op.execute(
|
||||
f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
|
||||
)
|
||||
|
||||
# Step 4: Update the temporary column with mapped values
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'")
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'")
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'")
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'")
|
||||
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'QNA' WHERE {column_name}::text = 'GENERAL'"
|
||||
)
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEP' WHERE {column_name}::text = 'DEEP'"
|
||||
)
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPER'"
|
||||
)
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'REPORT_DEEPER' WHERE {column_name}::text = 'DEEPEST'"
|
||||
)
|
||||
|
||||
# Step 5: Drop the old column
|
||||
op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")
|
||||
|
||||
|
||||
# Step 6: Rename the new column to the original name
|
||||
op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
|
||||
|
||||
op.execute(
|
||||
f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
|
||||
)
|
||||
|
||||
# Step 7: Drop the old enum type
|
||||
op.execute(f"DROP TYPE {old_enum_name}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure."""
|
||||
|
||||
|
||||
# Old enum name for temporary storage
|
||||
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
||||
|
||||
|
||||
# Original enum values
|
||||
original_values = (
|
||||
"GENERAL",
|
||||
"DEEP",
|
||||
"DEEPER",
|
||||
"DEEPEST"
|
||||
)
|
||||
original_values = ("GENERAL", "DEEP", "DEEPER", "DEEPEST")
|
||||
original_values_sql = ", ".join([f"'{v}'" for v in original_values])
|
||||
|
||||
|
||||
# Table and column info
|
||||
table_name = "chats"
|
||||
column_name = "type"
|
||||
|
||||
|
||||
# Step 1: Rename the current enum type
|
||||
op.execute(f"ALTER TYPE {CHAT_TYPE_ENUM} RENAME TO {old_enum_name}")
|
||||
|
||||
|
||||
# Step 2: Create the new enum type with original values
|
||||
op.execute(f"CREATE TYPE {CHAT_TYPE_ENUM} AS ENUM({original_values_sql})")
|
||||
|
||||
|
||||
# Step 3: Add a temporary column with the original type
|
||||
op.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}")
|
||||
|
||||
op.execute(
|
||||
f"ALTER TABLE {table_name} ADD COLUMN {column_name}_new {CHAT_TYPE_ENUM}"
|
||||
)
|
||||
|
||||
# Step 4: Update the temporary column with mapped values back to old values
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'")
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'")
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'")
|
||||
op.execute(f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'")
|
||||
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'QNA'"
|
||||
)
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'GENERAL' WHERE {column_name}::text = 'REPORT_GENERAL'"
|
||||
)
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'DEEP' WHERE {column_name}::text = 'REPORT_DEEP'"
|
||||
)
|
||||
op.execute(
|
||||
f"UPDATE {table_name} SET {column_name}_new = 'DEEPER' WHERE {column_name}::text = 'REPORT_DEEPER'"
|
||||
)
|
||||
|
||||
# Step 5: Drop the old column
|
||||
op.execute(f"ALTER TABLE {table_name} DROP COLUMN {column_name}")
|
||||
|
||||
|
||||
# Step 6: Rename the new column to the original name
|
||||
op.execute(f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}")
|
||||
|
||||
op.execute(
|
||||
f"ALTER TABLE {table_name} RENAME COLUMN {column_name}_new TO {column_name}"
|
||||
)
|
||||
|
||||
# Step 7: Drop the old enum type
|
||||
op.execute(f"DROP TYPE {old_enum_name}")
|
||||
op.execute(f"DROP TYPE {old_enum_name}")
|
||||
|
|
|
|||
|
|
@ -4,83 +4,160 @@ Revision ID: 11
|
|||
Revises: 10
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "11"
|
||||
down_revision: Union[str, None] = "10"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "10"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - add LiteLLMProvider enum, LLMConfig table and user LLM preferences."""
|
||||
|
||||
# Check if enum type exists and create if it doesn't
|
||||
op.execute("""
|
||||
|
||||
# Create enum only if not exists
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'litellmprovider') THEN
|
||||
CREATE TYPE litellmprovider AS ENUM ('OPENAI', 'ANTHROPIC', 'GROQ', 'COHERE', 'HUGGINGFACE', 'AZURE_OPENAI', 'GOOGLE', 'AWS_BEDROCK', 'OLLAMA', 'MISTRAL', 'TOGETHER_AI', 'REPLICATE', 'PALM', 'VERTEX_AI', 'ANYSCALE', 'PERPLEXITY', 'DEEPINFRA', 'AI21', 'NLPCLOUD', 'ALEPH_ALPHA', 'PETALS', 'CUSTOM');
|
||||
CREATE TYPE litellmprovider AS ENUM (
|
||||
'OPENAI', 'ANTHROPIC', 'GROQ', 'COHERE', 'HUGGINGFACE',
|
||||
'AZURE_OPENAI', 'GOOGLE', 'AWS_BEDROCK', 'OLLAMA', 'MISTRAL',
|
||||
'TOGETHER_AI', 'REPLICATE', 'PALM', 'VERTEX_AI', 'ANYSCALE',
|
||||
'PERPLEXITY', 'DEEPINFRA', 'AI21', 'NLPCLOUD', 'ALEPH_ALPHA',
|
||||
'PETALS', 'CUSTOM'
|
||||
);
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
# Create llm_configs table using raw SQL to avoid enum creation conflicts
|
||||
op.execute("""
|
||||
CREATE TABLE llm_configs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
name VARCHAR(100) NOT NULL,
|
||||
provider litellmprovider NOT NULL,
|
||||
custom_provider VARCHAR(100),
|
||||
model_name VARCHAR(100) NOT NULL,
|
||||
api_key TEXT NOT NULL,
|
||||
api_base VARCHAR(500),
|
||||
litellm_params JSONB,
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_llm_configs_created_at'), 'llm_configs', ['created_at'], unique=False)
|
||||
op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
|
||||
|
||||
# Add LLM preference columns to user table
|
||||
op.add_column('user', sa.Column('long_context_llm_id', sa.Integer(), nullable=True))
|
||||
op.add_column('user', sa.Column('fast_llm_id', sa.Integer(), nullable=True))
|
||||
op.add_column('user', sa.Column('strategic_llm_id', sa.Integer(), nullable=True))
|
||||
|
||||
# Create foreign key constraints for LLM preferences
|
||||
op.create_foreign_key(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', 'llm_configs', ['long_context_llm_id'], ['id'], ondelete='SET NULL')
|
||||
op.create_foreign_key(op.f('fk_user_fast_llm_id_llm_configs'), 'user', 'llm_configs', ['fast_llm_id'], ['id'], ondelete='SET NULL')
|
||||
op.create_foreign_key(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', 'llm_configs', ['strategic_llm_id'], ['id'], ondelete='SET NULL')
|
||||
"""
|
||||
)
|
||||
|
||||
# Create llm_configs table only if it doesn't already exist
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'llm_configs'
|
||||
) THEN
|
||||
CREATE TABLE llm_configs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
name VARCHAR(100) NOT NULL,
|
||||
provider litellmprovider NOT NULL,
|
||||
custom_provider VARCHAR(100),
|
||||
model_name VARCHAR(100) NOT NULL,
|
||||
api_key TEXT NOT NULL,
|
||||
api_base VARCHAR(500),
|
||||
litellm_params JSONB,
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE
|
||||
);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create indexes if they don't exist
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'llm_configs' AND indexname = 'ix_llm_configs_id'
|
||||
) THEN
|
||||
CREATE INDEX ix_llm_configs_id ON llm_configs(id);
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'llm_configs' AND indexname = 'ix_llm_configs_created_at'
|
||||
) THEN
|
||||
CREATE INDEX ix_llm_configs_created_at ON llm_configs(created_at);
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'llm_configs' AND indexname = 'ix_llm_configs_name'
|
||||
) THEN
|
||||
CREATE INDEX ix_llm_configs_name ON llm_configs(name);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Safely add columns to user table
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
existing_columns = [col["name"] for col in inspector.get_columns("user")]
|
||||
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
if "long_context_llm_id" not in existing_columns:
|
||||
batch_op.add_column(
|
||||
sa.Column("long_context_llm_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
op.f("fk_user_long_context_llm_id_llm_configs"),
|
||||
"llm_configs",
|
||||
["long_context_llm_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
if "fast_llm_id" not in existing_columns:
|
||||
batch_op.add_column(sa.Column("fast_llm_id", sa.Integer(), nullable=True))
|
||||
batch_op.create_foreign_key(
|
||||
op.f("fk_user_fast_llm_id_llm_configs"),
|
||||
"llm_configs",
|
||||
["fast_llm_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
if "strategic_llm_id" not in existing_columns:
|
||||
batch_op.add_column(
|
||||
sa.Column("strategic_llm_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
op.f("fk_user_strategic_llm_id_llm_configs"),
|
||||
"llm_configs",
|
||||
["strategic_llm_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema - remove LLMConfig table and user LLM preferences."""
|
||||
|
||||
|
||||
# Drop foreign key constraints
|
||||
op.drop_constraint(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', type_='foreignkey')
|
||||
op.drop_constraint(op.f('fk_user_fast_llm_id_llm_configs'), 'user', type_='foreignkey')
|
||||
op.drop_constraint(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', type_='foreignkey')
|
||||
|
||||
op.drop_constraint(
|
||||
op.f("fk_user_strategic_llm_id_llm_configs"), "user", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
op.f("fk_user_fast_llm_id_llm_configs"), "user", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
op.f("fk_user_long_context_llm_id_llm_configs"), "user", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop LLM preference columns from user table
|
||||
op.drop_column('user', 'strategic_llm_id')
|
||||
op.drop_column('user', 'fast_llm_id')
|
||||
op.drop_column('user', 'long_context_llm_id')
|
||||
|
||||
op.drop_column("user", "strategic_llm_id")
|
||||
op.drop_column("user", "fast_llm_id")
|
||||
op.drop_column("user", "long_context_llm_id")
|
||||
|
||||
# Drop indexes and table
|
||||
op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
|
||||
op.drop_index(op.f('ix_llm_configs_created_at'), table_name='llm_configs')
|
||||
op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
|
||||
op.drop_table('llm_configs')
|
||||
|
||||
op.drop_index(op.f("ix_llm_configs_name"), table_name="llm_configs")
|
||||
op.drop_index(op.f("ix_llm_configs_created_at"), table_name="llm_configs")
|
||||
op.drop_index(op.f("ix_llm_configs_id"), table_name="llm_configs")
|
||||
op.drop_table("llm_configs")
|
||||
|
||||
# Drop LiteLLMProvider enum
|
||||
op.execute("DROP TYPE IF EXISTS litellmprovider")
|
||||
op.execute("DROP TYPE IF EXISTS litellmprovider")
|
||||
|
|
|
|||
|
|
@ -4,68 +4,93 @@ Revision ID: 12
|
|||
Revises: 11
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "12"
|
||||
down_revision: Union[str, None] = "11"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "11"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - add LogLevel and LogStatus enums and logs table."""
|
||||
|
||||
# Create LogLevel enum
|
||||
op.execute("""
|
||||
CREATE TYPE loglevel AS ENUM ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL')
|
||||
""")
|
||||
|
||||
# Create LogStatus enum
|
||||
op.execute("""
|
||||
CREATE TYPE logstatus AS ENUM ('IN_PROGRESS', 'SUCCESS', 'FAILED')
|
||||
""")
|
||||
|
||||
# Create logs table
|
||||
op.execute("""
|
||||
CREATE TABLE logs (
|
||||
|
||||
# Create LogLevel enum if it doesn't exist
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'loglevel') THEN
|
||||
CREATE TYPE loglevel AS ENUM ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL');
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create LogStatus enum if it doesn't exist
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'logstatus') THEN
|
||||
CREATE TYPE logstatus AS ENUM ('IN_PROGRESS', 'SUCCESS', 'FAILED');
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create logs table if it doesn't exist
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
level loglevel NOT NULL,
|
||||
status logstatus NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
source VARCHAR(200),
|
||||
log_metadata JSONB DEFAULT '{}',
|
||||
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
op.create_index(op.f('ix_logs_id'), 'logs', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_logs_created_at'), 'logs', ['created_at'], unique=False)
|
||||
op.create_index(op.f('ix_logs_level'), 'logs', ['level'], unique=False)
|
||||
op.create_index(op.f('ix_logs_status'), 'logs', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_logs_source'), 'logs', ['source'], unique=False)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Get existing indexes
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
existing_indexes = [idx["name"] for idx in inspector.get_indexes("logs")]
|
||||
|
||||
# Create indexes only if they don't already exist
|
||||
if "ix_logs_id" not in existing_indexes:
|
||||
op.create_index("ix_logs_id", "logs", ["id"])
|
||||
if "ix_logs_created_at" not in existing_indexes:
|
||||
op.create_index("ix_logs_created_at", "logs", ["created_at"])
|
||||
if "ix_logs_level" not in existing_indexes:
|
||||
op.create_index("ix_logs_level", "logs", ["level"])
|
||||
if "ix_logs_status" not in existing_indexes:
|
||||
op.create_index("ix_logs_status", "logs", ["status"])
|
||||
if "ix_logs_source" not in existing_indexes:
|
||||
op.create_index("ix_logs_source", "logs", ["source"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema - remove logs table and enums."""
|
||||
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_logs_source'), table_name='logs')
|
||||
op.drop_index(op.f('ix_logs_status'), table_name='logs')
|
||||
op.drop_index(op.f('ix_logs_level'), table_name='logs')
|
||||
op.drop_index(op.f('ix_logs_created_at'), table_name='logs')
|
||||
op.drop_index(op.f('ix_logs_id'), table_name='logs')
|
||||
|
||||
op.drop_index("ix_logs_source", table_name="logs")
|
||||
op.drop_index("ix_logs_status", table_name="logs")
|
||||
op.drop_index("ix_logs_level", table_name="logs")
|
||||
op.drop_index("ix_logs_created_at", table_name="logs")
|
||||
op.drop_index("ix_logs_id", table_name="logs")
|
||||
|
||||
# Drop logs table
|
||||
op.drop_table('logs')
|
||||
|
||||
op.drop_table("logs")
|
||||
|
||||
# Drop enums
|
||||
op.execute("DROP TYPE IF EXISTS logstatus")
|
||||
op.execute("DROP TYPE IF EXISTS loglevel")
|
||||
op.execute("DROP TYPE IF EXISTS loglevel")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
"""Add JIRA_CONNECTOR to enums
|
||||
|
||||
Revision ID: 13
|
||||
Revises: 12
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "13"
|
||||
down_revision: str | None = "12"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Safely add 'JIRA_CONNECTOR' to enum types if missing."""
|
||||
|
||||
# Add to searchsourceconnectortype enum
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = 'searchsourceconnectortype' AND e.enumlabel = 'JIRA_CONNECTOR'
|
||||
) THEN
|
||||
ALTER TYPE searchsourceconnectortype ADD VALUE 'JIRA_CONNECTOR';
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add to documenttype enum
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = 'documenttype' AND e.enumlabel = 'JIRA_CONNECTOR'
|
||||
) THEN
|
||||
ALTER TYPE documenttype ADD VALUE 'JIRA_CONNECTOR';
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Downgrade logic not implemented since PostgreSQL
|
||||
does not support removing enum values.
|
||||
"""
|
||||
pass
|
||||
|
|
@ -1,31 +1,48 @@
|
|||
"""Add GITHUB_CONNECTOR to SearchSourceConnectorType enum
|
||||
|
||||
Revision ID: 1
|
||||
Revises:
|
||||
Revises:
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# Import pgvector if needed for other types, though not for this ENUM change
|
||||
# import pgvector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "1"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
|
||||
# Manually add the command to add the enum value
|
||||
# Note: It's generally better to let autogenerate handle this, but we're bypassing it
|
||||
op.execute("ALTER TYPE searchsourceconnectortype ADD VALUE 'GITHUB_CONNECTOR'")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_enum
|
||||
WHERE enumlabel = 'GITHUB_CONNECTOR'
|
||||
AND enumtypid = (
|
||||
SELECT oid FROM pg_type WHERE typname = 'searchsourceconnectortype'
|
||||
)
|
||||
) THEN
|
||||
ALTER TYPE searchsourceconnectortype ADD VALUE 'GITHUB_CONNECTOR';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Pass for the rest, as autogenerate didn't run to add other schema details
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -33,20 +50,23 @@ def upgrade() -> None:
|
|||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
|
||||
# Downgrading removal of an enum value is complex and potentially dangerous
|
||||
# if the value is in use. Often omitted or requires manual SQL based on context.
|
||||
# For now, we'll just pass. If you needed to reverse this, you'd likely
|
||||
# For now, we'll just pass. If you needed to reverse this, you'd likely
|
||||
# have to manually check if 'GITHUB_CONNECTOR' is used in the table
|
||||
# and then potentially recreate the type without it.
|
||||
op.execute("ALTER TYPE searchsourceconnectortype RENAME TO searchsourceconnectortype_old")
|
||||
op.execute("CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR')")
|
||||
op.execute((
|
||||
op.execute(
|
||||
"ALTER TYPE searchsourceconnectortype RENAME TO searchsourceconnectortype_old"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR')"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
||||
"connector_type::text::searchsourceconnectortype"
|
||||
))
|
||||
)
|
||||
op.execute("DROP TYPE searchsourceconnectortype_old")
|
||||
|
||||
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -4,41 +4,55 @@ Revision ID: 2
|
|||
Revises: e55302644c51
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2'
|
||||
down_revision: Union[str, None] = 'e55302644c51'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "2"
|
||||
down_revision: str | None = "e55302644c51"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Manually add the command to add the enum value
|
||||
op.execute("ALTER TYPE searchsourceconnectortype ADD VALUE 'LINEAR_CONNECTOR'")
|
||||
|
||||
# Pass for the rest, as autogenerate didn't run to add other schema details
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = 'LINEAR_CONNECTOR'
|
||||
AND enumtypid = (
|
||||
SELECT oid FROM pg_type WHERE typname = 'searchsourceconnectortype'
|
||||
)
|
||||
) THEN
|
||||
ALTER TYPE searchsourceconnectortype ADD VALUE 'LINEAR_CONNECTOR';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
|
||||
# Downgrading removal of an enum value requires recreating the type
|
||||
op.execute("ALTER TYPE searchsourceconnectortype RENAME TO searchsourceconnectortype_old")
|
||||
op.execute("CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR')")
|
||||
op.execute((
|
||||
op.execute(
|
||||
"ALTER TYPE searchsourceconnectortype RENAME TO searchsourceconnectortype_old"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR')"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
||||
"connector_type::text::searchsourceconnectortype"
|
||||
))
|
||||
)
|
||||
op.execute("DROP TYPE searchsourceconnectortype_old")
|
||||
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -4,26 +4,41 @@ Revision ID: 3
|
|||
Revises: 2
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '3'
|
||||
down_revision: Union[str, None] = '2'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "3"
|
||||
down_revision: str | None = "2"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Define the ENUM type name and the new value
|
||||
ENUM_NAME = 'documenttype' # Make sure this matches the name in your DB (usually lowercase class name)
|
||||
NEW_VALUE = 'LINEAR_CONNECTOR'
|
||||
ENUM_NAME = "documenttype" # Make sure this matches the name in your DB (usually lowercase class name)
|
||||
NEW_VALUE = "LINEAR_CONNECTOR"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.execute(f"ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}'")
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{NEW_VALUE}'
|
||||
AND enumtypid = (
|
||||
SELECT oid FROM pg_type WHERE typname = '{ENUM_NAME}'
|
||||
)
|
||||
) THEN
|
||||
ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Warning: This will delete all rows with the new value
|
||||
def downgrade() -> None:
|
||||
|
|
@ -34,19 +49,19 @@ def downgrade() -> None:
|
|||
|
||||
# Enum values *before* LINEAR_CONNECTOR was added
|
||||
old_values = (
|
||||
'EXTENSION',
|
||||
'CRAWLED_URL',
|
||||
'FILE',
|
||||
'SLACK_CONNECTOR',
|
||||
'NOTION_CONNECTOR',
|
||||
'YOUTUBE_VIDEO',
|
||||
'GITHUB_CONNECTOR'
|
||||
"EXTENSION",
|
||||
"CRAWLED_URL",
|
||||
"FILE",
|
||||
"SLACK_CONNECTOR",
|
||||
"NOTION_CONNECTOR",
|
||||
"YOUTUBE_VIDEO",
|
||||
"GITHUB_CONNECTOR",
|
||||
)
|
||||
old_values_sql = ", ".join([f"'{v}'" for v in old_values])
|
||||
|
||||
# Table and column names (adjust if different)
|
||||
table_name = 'documents'
|
||||
column_name = 'document_type'
|
||||
table_name = "documents"
|
||||
column_name = "document_type"
|
||||
|
||||
# 1. Rename the current enum type
|
||||
op.execute(f"ALTER TYPE {ENUM_NAME} RENAME TO {old_enum_name}")
|
||||
|
|
@ -54,10 +69,8 @@ def downgrade() -> None:
|
|||
# 2. Create the new enum type with the old values
|
||||
op.execute(f"CREATE TYPE {ENUM_NAME} AS ENUM({old_values_sql})")
|
||||
|
||||
# 3. Update the table:
|
||||
op.execute(
|
||||
f"DELETE FROM {table_name} WHERE {column_name}::text = '{NEW_VALUE}'"
|
||||
)
|
||||
# 3. Update the table:
|
||||
op.execute(f"DELETE FROM {table_name} WHERE {column_name}::text = '{NEW_VALUE}'")
|
||||
|
||||
# 4. Alter the column to use the new enum type (casting old values)
|
||||
op.execute(
|
||||
|
|
@ -67,4 +80,4 @@ def downgrade() -> None:
|
|||
|
||||
# 5. Drop the old enum type
|
||||
op.execute(f"DROP TYPE {old_enum_name}")
|
||||
# ### end Alembic commands ###
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -4,25 +4,24 @@ Revision ID: 4
|
|||
Revises: 3
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4'
|
||||
down_revision: Union[str, None] = '3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "4"
|
||||
down_revision: str | None = "3"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
|
||||
# Manually add the command to add the enum value
|
||||
op.execute("ALTER TYPE searchsourceconnectortype ADD VALUE 'LINKUP_API'")
|
||||
|
||||
|
||||
# Pass for the rest, as autogenerate didn't run to add other schema details
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -30,15 +29,19 @@ def upgrade() -> None:
|
|||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
|
||||
# Downgrading removal of an enum value requires recreating the type
|
||||
op.execute("ALTER TYPE searchsourceconnectortype RENAME TO searchsourceconnectortype_old")
|
||||
op.execute("CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR', 'LINEAR_CONNECTOR')")
|
||||
op.execute((
|
||||
op.execute(
|
||||
"ALTER TYPE searchsourceconnectortype RENAME TO searchsourceconnectortype_old"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TYPE searchsourceconnectortype AS ENUM('SERPER_API', 'TAVILY_API', 'SLACK_CONNECTOR', 'NOTION_CONNECTOR', 'GITHUB_CONNECTOR', 'LINEAR_CONNECTOR')"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE search_source_connectors ALTER COLUMN connector_type TYPE searchsourceconnectortype USING "
|
||||
"connector_type::text::searchsourceconnectortype"
|
||||
))
|
||||
)
|
||||
op.execute("DROP TYPE searchsourceconnectortype_old")
|
||||
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -4,54 +4,73 @@ Revision ID: 5
|
|||
Revises: 4
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '5'
|
||||
down_revision: Union[str, None] = '4'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "5"
|
||||
down_revision: str | None = "4"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Alter Chat table
|
||||
op.alter_column('chats', 'title',
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Alter Document table
|
||||
op.alter_column('documents', 'title',
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Alter Podcast table
|
||||
op.alter_column('podcasts', 'title',
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False)
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert Chat table
|
||||
op.alter_column('chats', 'title',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Revert Document table
|
||||
op.alter_column('documents', 'title',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False)
|
||||
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Revert Podcast table
|
||||
op.alter_column('podcasts', 'title',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False)
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,40 +4,59 @@ Revision ID: 6
|
|||
Revises: 5
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '6'
|
||||
down_revision: Union[str, None] = '5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "6"
|
||||
down_revision: str | None = "5"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the old column and create a new one with the new name and type
|
||||
# We need to do this because PostgreSQL doesn't support direct column renames with type changes
|
||||
op.add_column('podcasts', sa.Column('podcast_transcript', JSON, nullable=False, server_default='{}'))
|
||||
|
||||
# Copy data from old column to new column
|
||||
# Convert text to JSON by storing it as a JSON string value
|
||||
op.execute("UPDATE podcasts SET podcast_transcript = jsonb_build_object('text', podcast_content) WHERE podcast_content != ''")
|
||||
|
||||
# Drop the old column
|
||||
op.drop_column('podcasts', 'podcast_content')
|
||||
bind = op.get_bind()
|
||||
inspector = inspect(bind)
|
||||
|
||||
columns = [col["name"] for col in inspector.get_columns("podcasts")]
|
||||
if "podcast_transcript" not in columns:
|
||||
op.add_column(
|
||||
"podcasts",
|
||||
sa.Column("podcast_transcript", JSON, nullable=False, server_default="{}"),
|
||||
)
|
||||
|
||||
# Copy data from old column to new column
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE podcasts
|
||||
SET podcast_transcript = jsonb_build_object('text', podcast_content)
|
||||
WHERE podcast_content != ''
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the old column only if it exists
|
||||
if "podcast_content" in columns:
|
||||
op.drop_column("podcasts", "podcast_content")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the original column
|
||||
op.add_column('podcasts', sa.Column('podcast_content', sa.Text(), nullable=False, server_default=''))
|
||||
|
||||
op.add_column(
|
||||
"podcasts",
|
||||
sa.Column("podcast_content", sa.Text(), nullable=False, server_default=""),
|
||||
)
|
||||
|
||||
# Copy data from JSON column back to text column
|
||||
# Extract the 'text' field if it exists, otherwise use empty string
|
||||
op.execute("UPDATE podcasts SET podcast_content = COALESCE((podcast_transcript->>'text'), '')")
|
||||
|
||||
op.execute(
|
||||
"UPDATE podcasts SET podcast_content = COALESCE((podcast_transcript->>'text'), '')"
|
||||
)
|
||||
|
||||
# Drop the new column
|
||||
op.drop_column('podcasts', 'podcast_transcript')
|
||||
op.drop_column("podcasts", "podcast_transcript")
|
||||
|
|
|
|||
|
|
@ -4,24 +4,35 @@ Revision ID: 7
|
|||
Revises: 6
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7'
|
||||
down_revision: Union[str, None] = '6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "7"
|
||||
down_revision: str | None = "6"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the is_generated column
|
||||
op.drop_column('podcasts', 'is_generated')
|
||||
# Get the current database connection
|
||||
bind = op.get_bind()
|
||||
inspector = inspect(bind)
|
||||
|
||||
# Check if the column exists before attempting to drop it
|
||||
columns = [col["name"] for col in inspector.get_columns("podcasts")]
|
||||
if "is_generated" in columns:
|
||||
op.drop_column("podcasts", "is_generated")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the is_generated column with its original constraints
|
||||
op.add_column('podcasts', sa.Column('is_generated', sa.Boolean(), nullable=False, server_default='false'))
|
||||
op.add_column(
|
||||
"podcasts",
|
||||
sa.Column("is_generated", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,54 +3,69 @@
|
|||
Revision ID: 8
|
||||
Revises: 7
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '8'
|
||||
down_revision: Union[str, None] = '7'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "8"
|
||||
down_revision: str | None = "7"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add content_hash column as nullable first to handle existing data
|
||||
op.add_column('documents', sa.Column('content_hash', sa.String(), nullable=True))
|
||||
|
||||
# Update existing documents to generate content hashes
|
||||
# Using SHA-256 hash of the content column with proper UTF-8 encoding
|
||||
op.execute("""
|
||||
UPDATE documents
|
||||
SET content_hash = encode(sha256(convert_to(content, 'UTF8')), 'hex')
|
||||
WHERE content_hash IS NULL
|
||||
""")
|
||||
|
||||
# Handle duplicate content hashes by keeping only the oldest document for each hash
|
||||
# Delete newer documents with duplicate content hashes
|
||||
op.execute("""
|
||||
DELETE FROM documents
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
FROM documents
|
||||
GROUP BY content_hash
|
||||
bind = op.get_bind()
|
||||
inspector = inspect(bind)
|
||||
columns = [col["name"] for col in inspector.get_columns("documents")]
|
||||
|
||||
# Only add the column if it doesn't already exist
|
||||
if "content_hash" not in columns:
|
||||
op.add_column(
|
||||
"documents", sa.Column("content_hash", sa.String(), nullable=True)
|
||||
)
|
||||
""")
|
||||
|
||||
# Now alter the column to match the model: nullable=False, index=True, unique=True
|
||||
op.alter_column('documents', 'content_hash',
|
||||
existing_type=sa.String(),
|
||||
nullable=False)
|
||||
op.create_index(op.f('ix_documents_content_hash'), 'documents', ['content_hash'], unique=False)
|
||||
op.create_unique_constraint(op.f('uq_documents_content_hash'), 'documents', ['content_hash'])
|
||||
|
||||
# Populate the content_hash column
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE documents
|
||||
SET content_hash = encode(sha256(convert_to(content, 'UTF8')), 'hex')
|
||||
WHERE content_hash IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM documents
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
FROM documents
|
||||
GROUP BY content_hash
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"documents", "content_hash", existing_type=sa.String(), nullable=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_documents_content_hash"),
|
||||
"documents",
|
||||
["content_hash"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
op.f("uq_documents_content_hash"), "documents", ["content_hash"]
|
||||
)
|
||||
else:
|
||||
print("Column 'content_hash' already exists. Skipping column creation.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove constraints and index first
|
||||
op.drop_constraint(op.f('uq_documents_content_hash'), 'documents', type_='unique')
|
||||
op.drop_index(op.f('ix_documents_content_hash'), table_name='documents')
|
||||
|
||||
# Remove content_hash column from documents table
|
||||
op.drop_column('documents', 'content_hash')
|
||||
op.drop_constraint(op.f("uq_documents_content_hash"), "documents", type_="unique")
|
||||
op.drop_index(op.f("ix_documents_content_hash"), table_name="documents")
|
||||
op.drop_column("documents", "content_hash")
|
||||
|
|
|
|||
|
|
@ -4,17 +4,15 @@ Revision ID: 9
|
|||
Revises: 8
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9"
|
||||
down_revision: Union[str, None] = "8"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
down_revision: str | None = "8"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Define the ENUM type name and the new value
|
||||
CONNECTOR_ENUM = "searchsourceconnectortype"
|
||||
|
|
@ -24,11 +22,38 @@ DOCUMENT_NEW_VALUE = "DISCORD_CONNECTOR"
|
|||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - add DISCORD_CONNECTOR to connector and document enum."""
|
||||
# Add DISCORD_CONNECTOR to searchsourceconnectortype
|
||||
op.execute(f"ALTER TYPE {CONNECTOR_ENUM} ADD VALUE '{CONNECTOR_NEW_VALUE}'")
|
||||
# Add DISCORD_CONNECTOR to documenttype
|
||||
op.execute(f"ALTER TYPE {DOCUMENT_ENUM} ADD VALUE '{DOCUMENT_NEW_VALUE}'")
|
||||
"""Upgrade schema - add DISCORD_CONNECTOR to connector and document enum safely."""
|
||||
# Add DISCORD_CONNECTOR to searchsourceconnectortype only if not exists
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{CONNECTOR_NEW_VALUE}'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = '{CONNECTOR_ENUM}')
|
||||
) THEN
|
||||
ALTER TYPE {CONNECTOR_ENUM} ADD VALUE '{CONNECTOR_NEW_VALUE}';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add DISCORD_CONNECTOR to documenttype only if not exists
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{DOCUMENT_NEW_VALUE}'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = '{DOCUMENT_ENUM}')
|
||||
) THEN
|
||||
ALTER TYPE {DOCUMENT_ENUM} ADD VALUE '{DOCUMENT_NEW_VALUE}';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
|
@ -85,7 +110,6 @@ def downgrade() -> None:
|
|||
# 4. Drop the old connector enum type
|
||||
op.execute(f"DROP TYPE {old_connector_enum_name}")
|
||||
|
||||
|
||||
# Document Enum Downgrade Steps
|
||||
# 1. Rename the current document enum type
|
||||
op.execute(f"ALTER TYPE {DOCUMENT_ENUM} RENAME TO {old_document_enum_name}")
|
||||
|
|
|
|||
|
|
@ -1,69 +1,67 @@
|
|||
"""Add GITHUB_CONNECTOR to DocumentType enum
|
||||
|
||||
Revision ID: e55302644c51
|
||||
Revises: 1
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e55302644c51'
|
||||
down_revision: Union[str, None] = '1'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "e55302644c51"
|
||||
down_revision: str | None = "1"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
# Define the ENUM type name and the new value
|
||||
ENUM_NAME = 'documenttype' # Make sure this matches the name in your DB (usually lowercase class name)
|
||||
NEW_VALUE = 'GITHUB_CONNECTOR'
|
||||
ENUM_NAME = "documenttype"
|
||||
NEW_VALUE = "GITHUB_CONNECTOR"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.execute(f"ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}'")
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{NEW_VALUE}'
|
||||
AND enumtypid = (
|
||||
SELECT oid FROM pg_type WHERE typname = '{ENUM_NAME}'
|
||||
)
|
||||
) THEN
|
||||
ALTER TYPE {ENUM_NAME} ADD VALUE '{NEW_VALUE}';
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Warning: This will delete all rows with the new value
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema - remove GITHUB_CONNECTOR from enum."""
|
||||
|
||||
# The old type name
|
||||
old_enum_name = f"{ENUM_NAME}_old"
|
||||
|
||||
# Enum values *before* GITHUB_CONNECTOR was added
|
||||
old_values = (
|
||||
'EXTENSION',
|
||||
'CRAWLED_URL',
|
||||
'FILE',
|
||||
'SLACK_CONNECTOR',
|
||||
'NOTION_CONNECTOR',
|
||||
'YOUTUBE_VIDEO'
|
||||
"EXTENSION",
|
||||
"CRAWLED_URL",
|
||||
"FILE",
|
||||
"SLACK_CONNECTOR",
|
||||
"NOTION_CONNECTOR",
|
||||
"YOUTUBE_VIDEO",
|
||||
)
|
||||
old_values_sql = ", ".join([f"'{v}'" for v in old_values])
|
||||
|
||||
# Table and column names (adjust if different)
|
||||
table_name = 'documents'
|
||||
column_name = 'document_type'
|
||||
table_name = "documents"
|
||||
column_name = "document_type"
|
||||
|
||||
# 1. Rename the current enum type
|
||||
op.execute(f"ALTER TYPE {ENUM_NAME} RENAME TO {old_enum_name}")
|
||||
# 1. Create the new enum type with the old values
|
||||
op.execute(f"CREATE TYPE {old_enum_name} AS ENUM({old_values_sql})")
|
||||
|
||||
# 2. Create the new enum type with the old values
|
||||
op.execute(f"CREATE TYPE {ENUM_NAME} AS ENUM({old_values_sql})")
|
||||
# 2. Delete rows using the new value
|
||||
op.execute(f"DELETE FROM {table_name} WHERE {column_name}::text = '{NEW_VALUE}'")
|
||||
|
||||
# 3. Update the table:
|
||||
op.execute(
|
||||
f"DELETE FROM {table_name} WHERE {column_name}::text = '{NEW_VALUE}'"
|
||||
)
|
||||
|
||||
# 4. Alter the column to use the new enum type (casting old values)
|
||||
# 3. Alter the column to use the old enum type
|
||||
op.execute(
|
||||
f"ALTER TABLE {table_name} ALTER COLUMN {column_name} "
|
||||
f"TYPE {ENUM_NAME} USING {column_name}::text::{ENUM_NAME}"
|
||||
f"TYPE {old_enum_name} USING {column_name}::text::{old_enum_name}"
|
||||
)
|
||||
|
||||
# 5. Drop the old enum type
|
||||
op.execute(f"DROP TYPE {old_enum_name}")
|
||||
# ### end Alembic commands ###
|
||||
# 4. Drop the current enum type and rename the old one
|
||||
op.execute(f"DROP TYPE {ENUM_NAME}")
|
||||
op.execute(f"ALTER TYPE {old_enum_name} RENAME TO {ENUM_NAME}")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
|
@ -17,11 +16,11 @@ class Configuration:
|
|||
# create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
|
||||
# and when you invoke the graph
|
||||
podcast_title: str
|
||||
user_id: str
|
||||
user_id: str
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
from langgraph.graph import StateGraph
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import create_merged_podcast_audio, create_podcast_transcript
|
||||
from .state import State
|
||||
|
||||
|
||||
from .nodes import create_merged_podcast_audio, create_podcast_transcript
|
||||
|
||||
|
||||
def build_graph():
|
||||
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
|
|
@ -24,8 +21,9 @@ def build_graph():
|
|||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith
|
||||
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Compile the graph once when the module is loaded
|
||||
graph = build_graph()
|
||||
|
|
|
|||
|
|
@ -1,148 +1,154 @@
|
|||
from typing import Any, Dict
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from litellm import aspeech
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
|
||||
from .configuration import Configuration
|
||||
from .state import PodcastTranscriptEntry, State, PodcastTranscripts
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
from app.config import config as app_config
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
|
||||
|
||||
async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
|
||||
async def create_podcast_transcript(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Each node does work."""
|
||||
|
||||
|
||||
# Get configuration from runnable config
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
user_id = configuration.user_id
|
||||
|
||||
|
||||
# Get user's long context LLM
|
||||
llm = await get_user_long_context_llm(state.db_session, user_id)
|
||||
if not llm:
|
||||
error_message = f"No long context LLM configured for user {user_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
||||
# Get the prompt
|
||||
prompt = get_podcast_generation_prompt()
|
||||
|
||||
|
||||
# Create the messages
|
||||
messages = [
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content=f"<source_content>{state.source_content}</source_content>")
|
||||
HumanMessage(
|
||||
content=f"<source_content>{state.source_content}</source_content>"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Generate the podcast transcript
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
|
||||
# First try the direct approach
|
||||
try:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(llm_response.content))
|
||||
podcast_transcript = PodcastTranscripts.model_validate(
|
||||
json.loads(llm_response.content)
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {str(e)}")
|
||||
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
# Fallback: Parse the JSON response manually
|
||||
try:
|
||||
# Extract JSON content from the response
|
||||
content = llm_response.content
|
||||
|
||||
|
||||
# Find the JSON in the content (handle case where LLM might add additional text)
|
||||
json_start = content.find('{')
|
||||
json_end = content.rfind('}') + 1
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
|
||||
# Convert to Pydantic model
|
||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||
|
||||
print(f"Successfully parsed podcast transcript using fallback approach")
|
||||
|
||||
print("Successfully parsed podcast transcript using fallback approach")
|
||||
else:
|
||||
# If JSON structure not found, raise a clear error
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e2:
|
||||
# Log the error and re-raise it
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {str(e2)}"
|
||||
print(f"Error parsing LLM response: {str(e2)}")
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {llm_response.content}")
|
||||
raise
|
||||
|
||||
return {
|
||||
"podcast_transcript": podcast_transcript.podcast_transcripts
|
||||
}
|
||||
|
||||
|
||||
async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
|
||||
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||
|
||||
|
||||
async def create_merged_podcast_audio(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate audio for each transcript and merge them into a single podcast file."""
|
||||
|
||||
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
|
||||
|
||||
starting_transcript = PodcastTranscriptEntry(
|
||||
speaker_id=1,
|
||||
dialog=f"Welcome to {configuration.podcast_title} Podcast."
|
||||
speaker_id=1, dialog=f"Welcome to {configuration.podcast_title} Podcast."
|
||||
)
|
||||
|
||||
|
||||
transcript = state.podcast_transcript
|
||||
|
||||
|
||||
# Merge the starting transcript with the podcast transcript
|
||||
# Check if transcript is a PodcastTranscripts object or already a list
|
||||
if hasattr(transcript, 'podcast_transcripts'):
|
||||
if hasattr(transcript, "podcast_transcripts"):
|
||||
transcript_entries = transcript.podcast_transcripts
|
||||
else:
|
||||
transcript_entries = transcript
|
||||
|
||||
merged_transcript = [starting_transcript] + transcript_entries
|
||||
|
||||
|
||||
merged_transcript = [starting_transcript, *transcript_entries]
|
||||
|
||||
# Create a temporary directory for audio files
|
||||
temp_dir = Path("temp_audio")
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# Generate a unique session ID for this podcast
|
||||
session_id = str(uuid.uuid4())
|
||||
output_path = f"podcasts/{session_id}_podcast.mp3"
|
||||
os.makedirs("podcasts", exist_ok=True)
|
||||
|
||||
|
||||
# Map of speaker_id to voice
|
||||
voice_mapping = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
1: "echo", # First speaker
|
||||
# 2: "fable", # Second speaker
|
||||
# 3: "onyx", # Third speaker
|
||||
# 4: "nova", # Fourth speaker
|
||||
# 5: "shimmer" # Fifth speaker
|
||||
}
|
||||
|
||||
|
||||
# Generate audio for each transcript segment
|
||||
audio_files = []
|
||||
|
||||
|
||||
async def generate_speech_for_segment(segment, index):
|
||||
# Handle both dictionary and PodcastTranscriptEntry objects
|
||||
if hasattr(segment, 'speaker_id'):
|
||||
if hasattr(segment, "speaker_id"):
|
||||
speaker_id = segment.speaker_id
|
||||
dialog = segment.dialog
|
||||
else:
|
||||
speaker_id = segment.get("speaker_id", 0)
|
||||
dialog = segment.get("dialog", "")
|
||||
|
||||
|
||||
# Select voice based on speaker_id
|
||||
voice = voice_mapping.get(speaker_id, "alloy")
|
||||
|
||||
|
||||
# Generate a unique filename for this segment
|
||||
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
||||
|
||||
|
||||
try:
|
||||
if app_config.TTS_SERVICE_API_BASE:
|
||||
response = await aspeech(
|
||||
|
|
@ -163,55 +169,61 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
|
|||
max_retries=2,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
|
||||
# Save the audio to a file - use proper streaming method
|
||||
with open(filename, 'wb') as f:
|
||||
with open(filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
|
||||
return filename
|
||||
except Exception as e:
|
||||
print(f"Error generating speech for segment {index}: {str(e)}")
|
||||
print(f"Error generating speech for segment {index}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# Generate all audio files concurrently
|
||||
tasks = [generate_speech_for_segment(segment, i) for i, segment in enumerate(merged_transcript)]
|
||||
tasks = [
|
||||
generate_speech_for_segment(segment, i)
|
||||
for i, segment in enumerate(merged_transcript)
|
||||
]
|
||||
audio_files = await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
# Merge audio files using ffmpeg
|
||||
try:
|
||||
# Create FFmpeg instance with the first input
|
||||
ffmpeg = FFmpeg().option("y")
|
||||
|
||||
|
||||
# Add each audio file as input
|
||||
for audio_file in audio_files:
|
||||
ffmpeg = ffmpeg.input(audio_file)
|
||||
|
||||
|
||||
# Configure the concatenation and output
|
||||
filter_complex = []
|
||||
for i in range(len(audio_files)):
|
||||
filter_complex.append(f"[{i}:0]")
|
||||
|
||||
filter_complex_str = "".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||
|
||||
filter_complex_str = (
|
||||
"".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||
)
|
||||
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
||||
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
||||
|
||||
|
||||
# Execute FFmpeg
|
||||
await ffmpeg.execute()
|
||||
|
||||
|
||||
print(f"Successfully created podcast audio: {output_path}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error merging audio files: {str(e)}")
|
||||
print(f"Error merging audio files: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
for audio_file in audio_files:
|
||||
try:
|
||||
os.remove(audio_file)
|
||||
except:
|
||||
except Exception as e:
|
||||
print(f"Error removing audio file {audio_file}: {e!s}")
|
||||
pass
|
||||
|
||||
|
||||
return {
|
||||
"podcast_transcript": merged_transcript,
|
||||
"final_podcast_file_path": output_path
|
||||
"final_podcast_file_path": output_path,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,4 +108,4 @@ Output:
|
|||
|
||||
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration.
|
||||
</podcast_generation_system>
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,14 +3,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class PodcastTranscriptEntry(BaseModel):
|
||||
"""
|
||||
Represents a single entry in a podcast transcript.
|
||||
"""
|
||||
|
||||
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
|
||||
dialog: str = Field(..., description="The dialog text spoken by the speaker")
|
||||
|
||||
|
|
@ -19,10 +21,11 @@ class PodcastTranscripts(BaseModel):
|
|||
"""
|
||||
Represents the full podcast transcript structure.
|
||||
"""
|
||||
podcast_transcripts: List[PodcastTranscriptEntry] = Field(
|
||||
...,
|
||||
description="List of transcript entries with alternating speakers"
|
||||
)
|
||||
|
||||
podcast_transcripts: list[PodcastTranscriptEntry] = Field(
|
||||
..., description="List of transcript entries with alternating speakers"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
|
|
@ -32,8 +35,9 @@ class State:
|
|||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
source_content: str
|
||||
podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
|
||||
final_podcast_file_path: Optional[str] = None
|
||||
podcast_transcript: list[PodcastTranscriptEntry] | None = None
|
||||
final_podcast_file_path: str | None = None
|
||||
|
|
|
|||
|
|
@ -4,17 +4,20 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
class SearchMode(Enum):
|
||||
|
||||
class SearchMode(Enum):
|
||||
"""Enum defining the type of search mode."""
|
||||
|
||||
CHUNKS = "CHUNKS"
|
||||
DOCUMENTS = "DOCUMENTS"
|
||||
|
||||
|
||||
class ResearchMode(Enum):
|
||||
"""Enum defining the type of research mode."""
|
||||
|
||||
QNA = "QNA"
|
||||
REPORT_GENERAL = "REPORT_GENERAL"
|
||||
REPORT_DEEP = "REPORT_DEEP"
|
||||
|
|
@ -28,16 +31,16 @@ class Configuration:
|
|||
# Input parameters provided at invocation
|
||||
user_query: str
|
||||
num_sections: int
|
||||
connectors_to_search: List[str]
|
||||
connectors_to_search: list[str]
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
search_mode: SearchMode
|
||||
research_mode: ResearchMode
|
||||
document_ids_to_add_in_context: List[int]
|
||||
document_ids_to_add_in_context: list[int]
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,31 +1,41 @@
|
|||
from typing import Any, TypedDict
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import reformulate_user_query, write_answer_outline, process_sections, handle_qna_workflow, generate_further_questions
|
||||
|
||||
from .configuration import Configuration, ResearchMode
|
||||
from typing import TypedDict, List, Dict, Any, Optional
|
||||
from .nodes import (
|
||||
generate_further_questions,
|
||||
handle_qna_workflow,
|
||||
process_sections,
|
||||
reformulate_user_query,
|
||||
write_answer_outline,
|
||||
)
|
||||
from .state import State
|
||||
|
||||
|
||||
# Define what keys are in our state dict
|
||||
class GraphState(TypedDict):
|
||||
# Intermediate data produced during workflow
|
||||
answer_outline: Optional[Any]
|
||||
answer_outline: Any | None
|
||||
# Final output
|
||||
final_written_report: Optional[str]
|
||||
final_written_report: str | None
|
||||
|
||||
|
||||
def build_graph():
|
||||
"""
|
||||
Build and return the LangGraph workflow.
|
||||
|
||||
|
||||
This function constructs the researcher agent graph with conditional routing
|
||||
based on research_mode - QNA mode uses a direct Q&A workflow while other modes
|
||||
use the full report generation pipeline. Both paths generate follow-up questions
|
||||
at the end using the reranked documents from the sub-agents.
|
||||
|
||||
|
||||
Returns:
|
||||
A compiled LangGraph workflow
|
||||
"""
|
||||
# Define a new graph with state class
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
|
||||
# Add nodes to the graph
|
||||
workflow.add_node("reformulate_user_query", reformulate_user_query)
|
||||
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
|
||||
|
|
@ -35,41 +45,42 @@ def build_graph():
|
|||
|
||||
# Define the edges
|
||||
workflow.add_edge("__start__", "reformulate_user_query")
|
||||
|
||||
|
||||
# Add conditional edges from reformulate_user_query based on research mode
|
||||
def route_after_reformulate(state: State, config) -> str:
|
||||
"""Route based on research_mode after reformulating the query."""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
|
||||
|
||||
if configuration.research_mode == ResearchMode.QNA.value:
|
||||
return "handle_qna_workflow"
|
||||
else:
|
||||
return "write_answer_outline"
|
||||
|
||||
|
||||
workflow.add_conditional_edges(
|
||||
"reformulate_user_query",
|
||||
route_after_reformulate,
|
||||
{
|
||||
"handle_qna_workflow": "handle_qna_workflow",
|
||||
"write_answer_outline": "write_answer_outline"
|
||||
}
|
||||
"write_answer_outline": "write_answer_outline",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# QNA workflow path: handle_qna_workflow -> generate_further_questions -> __end__
|
||||
workflow.add_edge("handle_qna_workflow", "generate_further_questions")
|
||||
|
||||
|
||||
# Report generation workflow path: write_answer_outline -> process_sections -> generate_further_questions -> __end__
|
||||
workflow.add_edge("write_answer_outline", "process_sections")
|
||||
workflow.add_edge("process_sections", "generate_further_questions")
|
||||
|
||||
|
||||
# Both paths end after generating further questions
|
||||
workflow.add_edge("generate_further_questions", "__end__")
|
||||
|
||||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
|
||||
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Compile the graph once when the module is loaded
|
||||
graph = build_graph()
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -221,4 +221,4 @@ Output:
|
|||
}}
|
||||
</examples>
|
||||
</further_questions_system>
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""QnA Agent.
|
||||
"""
|
||||
"""QnA Agent."""
|
||||
|
||||
from .graph import graph
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional, List, Any
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
|
@ -15,13 +15,15 @@ class Configuration:
|
|||
# Configuration parameters for the Q&A agent
|
||||
user_query: str # The user's question to answer
|
||||
reformulated_query: str # The reformulated query
|
||||
relevant_documents: List[Any] # Documents provided directly to the agent for answering
|
||||
relevant_documents: list[
|
||||
Any
|
||||
] # Documents provided directly to the agent for answering
|
||||
user_id: str # User identifier
|
||||
search_space_id: int # Search space identifier
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import rerank_documents, answer_question
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import answer_question, rerank_documents
|
||||
from .state import State
|
||||
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,21 @@
|
|||
from app.services.reranker_service import RerankerService
|
||||
from .configuration import Configuration
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
from ..utils import (
|
||||
optimize_documents_for_token_limit,
|
||||
calculate_token_count,
|
||||
format_documents_section,
|
||||
optimize_documents_for_token_limit,
|
||||
)
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
|
||||
from .state import State
|
||||
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Rerank the documents based on relevance to the user's question.
|
||||
|
||||
|
|
@ -71,13 +74,13 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {str(e)}")
|
||||
print(f"Error during reranking: {e!s}")
|
||||
# Use original docs if reranking fails
|
||||
|
||||
return {"reranked_documents": reranked_docs}
|
||||
|
||||
|
||||
async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Answer the user's question using the provided documents.
|
||||
|
||||
|
|
@ -122,7 +125,8 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
|||
|
||||
# Use initial system prompt for token calculation
|
||||
initial_system_prompt = get_qna_citation_system_prompt()
|
||||
base_messages = state.chat_history + [
|
||||
base_messages = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=initial_system_prompt),
|
||||
HumanMessage(content=base_human_message_template),
|
||||
]
|
||||
|
|
@ -173,7 +177,8 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
|
|||
"""
|
||||
|
||||
# Create final messages for the LLM
|
||||
messages_with_chat_history = state.chat_history + [
|
||||
messages_with_chat_history = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=human_message_content),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ You are SurfSense, an advanced AI research assistant that provides detailed, wel
|
|||
- YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos)
|
||||
- GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions)
|
||||
- LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management)
|
||||
- DISCORD_CONNECTOR: "Discord server messages and channels" (personal community interactions)
|
||||
- JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking)
|
||||
- DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications)
|
||||
- TAVILY_API: "Tavily search API results" (personalized search results)
|
||||
- LINKUP_API: "Linkup search API results" (personalized search results)
|
||||
</knowledge_sources>
|
||||
|
|
@ -71,7 +72,7 @@ You are SurfSense, an advanced AI research assistant that provides detailed, wel
|
|||
Python's asyncio library provides tools for writing concurrent code using the async/await syntax. It's particularly useful for I/O-bound and high-level structured network code.
|
||||
</content>
|
||||
</document>
|
||||
|
||||
|
||||
<document>
|
||||
<metadata>
|
||||
<source_id>12</source_id>
|
||||
|
|
|
|||
|
|
@ -3,14 +3,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the Q&A agent during execution.
|
||||
|
||||
This state tracks the database session, chat history, and the outputs
|
||||
This state tracks the database session, chat history, and the outputs
|
||||
generated by the agent's nodes during question answering.
|
||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
|
|
@ -18,8 +20,8 @@ class State:
|
|||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
|
||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: Optional[List[Any]] = None
|
||||
final_answer: Optional[str] = None
|
||||
reranked_documents: list[Any] | None = None
|
||||
final_answer: str | None = None
|
||||
|
|
|
|||
|
|
@ -3,10 +3,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the agent during execution.
|
||||
|
|
@ -15,23 +18,23 @@ class State:
|
|||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context (not part of actual graph state)
|
||||
db_session: AsyncSession
|
||||
|
||||
|
||||
# Streaming service
|
||||
streaming_service: StreamingService
|
||||
|
||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
||||
|
||||
reformulated_query: Optional[str] = field(default=None)
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
|
||||
reformulated_query: str | None = field(default=None)
|
||||
# Using field to explicitly mark as part of state
|
||||
answer_outline: Optional[Any] = field(default=None)
|
||||
further_questions: Optional[Any] = field(default=None)
|
||||
|
||||
answer_outline: Any | None = field(default=None)
|
||||
further_questions: Any | None = field(default=None)
|
||||
|
||||
# Temporary field to hold reranked documents from sub-agents for further question generation
|
||||
reranked_documents: Optional[List[Any]] = field(default=None)
|
||||
|
||||
reranked_documents: list[Any] | None = field(default=None)
|
||||
|
||||
# OUTPUT: Populated by agent nodes
|
||||
# Using field to explicitly mark as part of state
|
||||
final_written_report: Optional[str] = field(default=None)
|
||||
|
||||
final_written_report: str | None = field(default=None)
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Any
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
class SubSectionType(Enum):
|
||||
"""Enum defining the type of sub-section."""
|
||||
|
||||
START = "START"
|
||||
MIDDLE = "MIDDLE"
|
||||
END = "END"
|
||||
|
|
@ -22,17 +23,16 @@ class Configuration:
|
|||
|
||||
# Input parameters provided at invocation
|
||||
sub_section_title: str
|
||||
sub_section_questions: List[str]
|
||||
sub_section_questions: list[str]
|
||||
sub_section_type: SubSectionType
|
||||
user_query: str
|
||||
relevant_documents: List[Any] # Documents provided directly to the agent
|
||||
relevant_documents: list[Any] # Documents provided directly to the agent
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import write_sub_section, rerank_documents
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import rerank_documents, write_sub_section
|
||||
from .state import State
|
||||
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from .configuration import Configuration
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from app.services.reranker_service import RerankerService
|
||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from .configuration import SubSectionType
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
from ..utils import (
|
||||
optimize_documents_for_token_limit,
|
||||
calculate_token_count,
|
||||
format_documents_section,
|
||||
optimize_documents_for_token_limit,
|
||||
)
|
||||
from .configuration import Configuration, SubSectionType
|
||||
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
|
||||
from .state import State
|
||||
|
||||
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Rerank the documents based on relevance to the sub-section title.
|
||||
|
||||
|
|
@ -79,13 +81,13 @@ async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, An
|
|||
f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error during reranking: {str(e)}")
|
||||
print(f"Error during reranking: {e!s}")
|
||||
# Use original docs if reranking fails
|
||||
|
||||
return {"reranked_documents": reranked_docs}
|
||||
|
||||
|
||||
async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||
async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""
|
||||
Write the sub-section using the provided documents.
|
||||
|
||||
|
|
@ -159,7 +161,8 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
|
||||
# Use initial system prompt for token calculation
|
||||
initial_system_prompt = get_citation_system_prompt()
|
||||
base_messages = state.chat_history + [
|
||||
base_messages = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=initial_system_prompt),
|
||||
HumanMessage(content=base_human_message_template),
|
||||
]
|
||||
|
|
@ -219,7 +222,8 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
|
|||
"""
|
||||
|
||||
# Create final messages for the LLM
|
||||
messages_with_chat_history = state.chat_history + [
|
||||
messages_with_chat_history = [
|
||||
*state.chat_history,
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=human_message_content),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the dynamic state for the agent during execution.
|
||||
|
|
@ -14,11 +16,11 @@ class State:
|
|||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
|
||||
chat_history: Optional[List[Any]] = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: Optional[List[Any]] = None
|
||||
final_answer: Optional[str] = None
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: list[Any] | None = None
|
||||
final_answer: str | None = None
|
||||
|
|
|
|||
|
|
@ -1,27 +1,37 @@
|
|||
from typing import List, Dict, Any, Tuple, NamedTuple
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from litellm import get_model_info, token_counter
|
||||
from pydantic import BaseModel, Field
|
||||
from litellm import token_counter, get_model_info
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""A section in the answer outline."""
|
||||
|
||||
section_id: int = Field(..., description="The zero-based index of the section")
|
||||
section_title: str = Field(..., description="The title of the section")
|
||||
questions: List[str] = Field(..., description="Questions to research for this section")
|
||||
questions: list[str] = Field(
|
||||
..., description="Questions to research for this section"
|
||||
)
|
||||
|
||||
|
||||
class AnswerOutline(BaseModel):
|
||||
"""The complete answer outline with all sections."""
|
||||
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
|
||||
|
||||
answer_outline: list[Section] = Field(
|
||||
..., description="List of sections in the answer outline"
|
||||
)
|
||||
|
||||
|
||||
class DocumentTokenInfo(NamedTuple):
|
||||
"""Information about a document and its token cost."""
|
||||
|
||||
index: int
|
||||
document: Dict[str, Any]
|
||||
document: dict[str, Any]
|
||||
formatted_content: str
|
||||
token_count: int
|
||||
|
||||
|
||||
|
||||
|
||||
def get_connector_emoji(connector_name: str) -> str:
|
||||
"""Get an appropriate emoji for a connector type."""
|
||||
connector_emojis = {
|
||||
|
|
@ -33,8 +43,10 @@ def get_connector_emoji(connector_name: str) -> str:
|
|||
"NOTION_CONNECTOR": "📘",
|
||||
"GITHUB_CONNECTOR": "🐙",
|
||||
"LINEAR_CONNECTOR": "📊",
|
||||
"JIRA_CONNECTOR": "🎫",
|
||||
"DISCORD_CONNECTOR": "🗨️",
|
||||
"TAVILY_API": "🔍",
|
||||
"LINKUP_API": "🔗"
|
||||
"LINKUP_API": "🔗",
|
||||
}
|
||||
return connector_emojis.get(connector_name, "🔎")
|
||||
|
||||
|
|
@ -50,32 +62,29 @@ def get_connector_friendly_name(connector_name: str) -> str:
|
|||
"NOTION_CONNECTOR": "Notion",
|
||||
"GITHUB_CONNECTOR": "GitHub",
|
||||
"LINEAR_CONNECTOR": "Linear",
|
||||
"JIRA_CONNECTOR": "Jira",
|
||||
"DISCORD_CONNECTOR": "Discord",
|
||||
"TAVILY_API": "Tavily Search",
|
||||
"LINKUP_API": "Linkup Search"
|
||||
"LINKUP_API": "Linkup Search",
|
||||
}
|
||||
return connector_friendly_names.get(connector_name, connector_name)
|
||||
|
||||
|
||||
def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
||||
def convert_langchain_messages_to_dict(
|
||||
messages: list[BaseMessage],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Convert LangChain messages to format expected by token_counter."""
|
||||
role_mapping = {
|
||||
'system': 'system',
|
||||
'human': 'user',
|
||||
'ai': 'assistant'
|
||||
}
|
||||
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
|
||||
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = role_mapping.get(getattr(msg, 'type', None), 'user')
|
||||
converted_messages.append({
|
||||
"role": role,
|
||||
"content": str(msg.content)
|
||||
})
|
||||
role = role_mapping.get(getattr(msg, "type", None), "user")
|
||||
converted_messages.append({"role": role, "content": str(msg.content)})
|
||||
|
||||
return converted_messages
|
||||
|
||||
|
||||
def format_document_for_citation(document: Dict[str, Any]) -> str:
|
||||
def format_document_for_citation(document: dict[str, Any]) -> str:
|
||||
"""Format a single document for citation in the standard XML format."""
|
||||
content = document.get("content", "")
|
||||
doc_info = document.get("document", {})
|
||||
|
|
@ -93,7 +102,9 @@ def format_document_for_citation(document: Dict[str, Any]) -> str:
|
|||
</document>"""
|
||||
|
||||
|
||||
def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
|
||||
def format_documents_section(
|
||||
documents: list[dict[str, Any]], section_title: str = "Source material"
|
||||
) -> str:
|
||||
"""Format multiple documents into a complete documents section."""
|
||||
if not documents:
|
||||
return ""
|
||||
|
|
@ -106,7 +117,9 @@ def format_documents_section(documents: List[Dict[str, Any]], section_title: str
|
|||
</documents>"""
|
||||
|
||||
|
||||
def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
|
||||
def calculate_document_token_costs(
|
||||
documents: list[dict[str, Any]], model: str
|
||||
) -> list[DocumentTokenInfo]:
|
||||
"""Pre-calculate token costs for each document."""
|
||||
document_token_info = []
|
||||
|
||||
|
|
@ -115,24 +128,24 @@ def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str)
|
|||
|
||||
# Calculate token count for this document
|
||||
token_count = token_counter(
|
||||
messages=[{"role": "user", "content": formatted_doc}],
|
||||
model=model
|
||||
messages=[{"role": "user", "content": formatted_doc}], model=model
|
||||
)
|
||||
|
||||
document_token_info.append(DocumentTokenInfo(
|
||||
index=i,
|
||||
document=doc,
|
||||
formatted_content=formatted_doc,
|
||||
token_count=token_count
|
||||
))
|
||||
document_token_info.append(
|
||||
DocumentTokenInfo(
|
||||
index=i,
|
||||
document=doc,
|
||||
formatted_content=formatted_doc,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
return document_token_info
|
||||
|
||||
|
||||
def find_optimal_documents_with_binary_search(
|
||||
document_tokens: List[DocumentTokenInfo],
|
||||
available_tokens: int
|
||||
) -> List[DocumentTokenInfo]:
|
||||
document_tokens: list[DocumentTokenInfo], available_tokens: int
|
||||
) -> list[DocumentTokenInfo]:
|
||||
"""Use binary search to find the maximum number of documents that fit within token limit."""
|
||||
if not document_tokens or available_tokens <= 0:
|
||||
return []
|
||||
|
|
@ -143,8 +156,7 @@ def find_optimal_documents_with_binary_search(
|
|||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
current_docs = document_tokens[:mid]
|
||||
current_token_sum = sum(
|
||||
doc_info.token_count for doc_info in current_docs)
|
||||
current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
|
||||
|
||||
if current_token_sum <= available_tokens:
|
||||
optimal_docs = current_docs
|
||||
|
|
@ -159,20 +171,18 @@ def get_model_context_window(model_name: str) -> int:
|
|||
"""Get the total context window size for a model (input + output tokens)."""
|
||||
try:
|
||||
model_info = get_model_info(model_name)
|
||||
context_window = model_info.get(
|
||||
'max_input_tokens', 4096) # Default fallback
|
||||
context_window = model_info.get("max_input_tokens", 4096) # Default fallback
|
||||
return context_window
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
|
||||
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
|
||||
)
|
||||
return 4096 # Conservative fallback
|
||||
|
||||
|
||||
def optimize_documents_for_token_limit(
|
||||
documents: List[Dict[str, Any]],
|
||||
base_messages: List[BaseMessage],
|
||||
model_name: str
|
||||
) -> Tuple[List[Dict[str, Any]], bool]:
|
||||
documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
|
||||
) -> tuple[list[dict[str, Any]], bool]:
|
||||
"""
|
||||
Optimize documents to fit within token limits using binary search.
|
||||
|
||||
|
|
@ -197,7 +207,8 @@ def optimize_documents_for_token_limit(
|
|||
available_tokens_for_docs = context_window - base_tokens
|
||||
|
||||
print(
|
||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")
|
||||
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
|
||||
)
|
||||
|
||||
if available_tokens_for_docs <= 0:
|
||||
print("No tokens available for documents after base content and output buffer")
|
||||
|
|
@ -208,8 +219,7 @@ def optimize_documents_for_token_limit(
|
|||
|
||||
# Find optimal number of documents using binary search
|
||||
optimal_doc_info = find_optimal_documents_with_binary_search(
|
||||
document_token_info,
|
||||
available_tokens_for_docs
|
||||
document_token_info, available_tokens_for_docs
|
||||
)
|
||||
|
||||
# Extract the original document objects
|
||||
|
|
@ -217,12 +227,13 @@ def optimize_documents_for_token_limit(
|
|||
has_documents_remaining = len(optimized_documents) > 0
|
||||
|
||||
print(
|
||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")
|
||||
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
|
||||
)
|
||||
|
||||
return optimized_documents, has_documents_remaining
|
||||
|
||||
|
||||
def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
|
||||
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
|
||||
"""Calculate token count for a list of LangChain messages."""
|
||||
model = model_name
|
||||
messages_dict = convert_langchain_messages_to_dict(messages)
|
||||
|
|
|
|||
|
|
@ -2,22 +2,13 @@ from contextlib import asynccontextmanager
|
|||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
|
||||
|
||||
from app.routes import router as crud_router
|
||||
from app.config import config
|
||||
|
||||
from app.users import (
|
||||
SECRET,
|
||||
auth_backend,
|
||||
fastapi_users,
|
||||
current_active_user
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -64,12 +55,10 @@ app.include_router(
|
|||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from app.users import google_oauth_client
|
||||
|
||||
app.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
google_oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
is_verified_by_default=True
|
||||
google_oauth_client, auth_backend, SECRET, is_verified_by_default=True
|
||||
),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
|
|
@ -79,5 +68,8 @@ app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
|||
|
||||
|
||||
@app.get("/verify-token")
|
||||
async def authenticated_route(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)):
|
||||
async def authenticated_route(
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
return {"message": "Token is valid"}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from rerankers import Reranker
|
||||
|
||||
|
||||
# Get the base directory of the project
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
|
|
@ -18,37 +16,37 @@ load_dotenv(env_file)
|
|||
def is_ffmpeg_installed():
|
||||
"""
|
||||
Check if ffmpeg is installed on the current system.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if ffmpeg is installed, False otherwise.
|
||||
"""
|
||||
return shutil.which("ffmpeg") is not None
|
||||
|
||||
|
||||
|
||||
class Config:
|
||||
# Check if ffmpeg is installed
|
||||
if not is_ffmpeg_installed():
|
||||
import static_ffmpeg
|
||||
|
||||
# ffmpeg installed on first call to add_paths(), threadsafe.
|
||||
static_ffmpeg.add_paths()
|
||||
# check if ffmpeg is installed again
|
||||
if not is_ffmpeg_installed():
|
||||
raise ValueError("FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster.")
|
||||
|
||||
raise ValueError(
|
||||
"FFmpeg is not installed on the system. Please install it to use the Surfsense Podcaster."
|
||||
)
|
||||
|
||||
# Database
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
|
||||
NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")
|
||||
|
||||
|
||||
|
||||
# AUTH: Google OAuth
|
||||
AUTH_TYPE = os.getenv("AUTH_TYPE")
|
||||
if AUTH_TYPE == "GOOGLE":
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
|
||||
|
||||
|
||||
# LLM instances are now managed per-user through the LLMConfig system
|
||||
# Legacy environment variables removed in favor of user-specific configurations
|
||||
|
||||
|
|
@ -56,12 +54,12 @@ class Config:
|
|||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
||||
embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
|
||||
chunker_instance = RecursiveChunker(
|
||||
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
|
||||
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||
)
|
||||
code_chunker_instance = CodeChunker(
|
||||
chunk_size=getattr(embedding_model_instance, 'max_seq_length', 512)
|
||||
chunk_size=getattr(embedding_model_instance, "max_seq_length", 512)
|
||||
)
|
||||
|
||||
|
||||
# Reranker's Configuration | Pinecode, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
|
||||
RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
|
||||
RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
|
||||
|
|
@ -69,45 +67,46 @@ class Config:
|
|||
model_name=RERANKERS_MODEL_NAME,
|
||||
model_type=RERANKERS_MODEL_TYPE,
|
||||
)
|
||||
|
||||
|
||||
# OAuth JWT
|
||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||
|
||||
|
||||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
||||
|
||||
if ETL_SERVICE == "UNSTRUCTURED":
|
||||
# Unstructured API Key
|
||||
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
|
||||
|
||||
|
||||
elif ETL_SERVICE == "LLAMACLOUD":
|
||||
# LlamaCloud API Key
|
||||
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
|
||||
|
||||
# Firecrawl API Key
|
||||
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
|
||||
|
||||
FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY", None)
|
||||
|
||||
# Litellm TTS Configuration
|
||||
TTS_SERVICE = os.getenv("TTS_SERVICE")
|
||||
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
|
||||
TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY")
|
||||
|
||||
|
||||
# Litellm STT Configuration
|
||||
STT_SERVICE = os.getenv("STT_SERVICE")
|
||||
STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
|
||||
STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")
|
||||
|
||||
|
||||
|
||||
# Validation Checks
|
||||
# Check embedding dimension
|
||||
if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
|
||||
if (
|
||||
hasattr(embedding_model_instance, "dimension")
|
||||
and embedding_model_instance.dimension > 2000
|
||||
):
|
||||
raise ValueError(
|
||||
f"Embedding dimension for Model: {EMBEDDING_MODEL} "
|
||||
f"has {embedding_model_instance.dimension} dimensions, which "
|
||||
f"exceeds the maximum of 2000 allowed by PGVector."
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
"""Get all settings as a dictionary."""
|
||||
|
|
|
|||
|
|
@ -1,26 +1,25 @@
|
|||
import os
|
||||
|
||||
|
||||
def _parse_bool(value):
|
||||
"""Parse boolean value from string."""
|
||||
return value.lower() == "true" if value else False
|
||||
|
||||
|
||||
def _parse_int(value, var_name):
|
||||
"""Parse integer value with error handling."""
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid integer value for {var_name}: {value}")
|
||||
raise ValueError(f"Invalid integer value for {var_name}: {value}") from None
|
||||
|
||||
|
||||
def _parse_headers(value):
|
||||
"""Parse headers from comma-separated string."""
|
||||
try:
|
||||
return [
|
||||
tuple(h.split(":", 1))
|
||||
for h in value.split(",")
|
||||
if ":" in h
|
||||
]
|
||||
return [tuple(h.split(":", 1)) for h in value.split(",") if ":" in h]
|
||||
except Exception:
|
||||
raise ValueError(f"Invalid headers format: {value}")
|
||||
raise ValueError(f"Invalid headers format: {value}") from None
|
||||
|
||||
|
||||
def load_uvicorn_config(args=None):
|
||||
|
|
@ -28,16 +27,16 @@ def load_uvicorn_config(args=None):
|
|||
Load Uvicorn configuration from environment variables and CLI args.
|
||||
Returns a dict suitable for passing to uvicorn.Config.
|
||||
"""
|
||||
config_kwargs = dict(
|
||||
app="app.app:app",
|
||||
host=os.getenv("UVICORN_HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("UVICORN_PORT", 8000)),
|
||||
log_level=os.getenv("UVICORN_LOG_LEVEL", "info"),
|
||||
reload=args.reload if args else False,
|
||||
reload_dirs=["app"] if (args and args.reload) else None,
|
||||
)
|
||||
|
||||
# Configuration mapping for advanced options
|
||||
config_kwargs = {
|
||||
"app": "app.app:app",
|
||||
"host": os.getenv("UVICORN_HOST", "0.0.0.0"),
|
||||
"port": int(os.getenv("UVICORN_PORT", 8000)),
|
||||
"log_level": os.getenv("UVICORN_LOG_LEVEL", "info"),
|
||||
"reload": args.reload if args else False,
|
||||
"reload_dirs": ["app"] if (args and args.reload) else None,
|
||||
}
|
||||
|
||||
# Configuration mapping for advanced options
|
||||
config_mapping = {
|
||||
"UVICORN_PROXY_HEADERS": ("proxy_headers", _parse_bool),
|
||||
"UVICORN_FORWARDED_ALLOW_IPS": ("forwarded_allow_ips", str),
|
||||
|
|
@ -51,15 +50,33 @@ def load_uvicorn_config(args=None):
|
|||
"UVICORN_LOG_CONFIG": ("log_config", str),
|
||||
"UVICORN_SERVER_HEADER": ("server_header", _parse_bool),
|
||||
"UVICORN_DATE_HEADER": ("date_header", _parse_bool),
|
||||
"UVICORN_LIMIT_CONCURRENCY": ("limit_concurrency", lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY")),
|
||||
"UVICORN_LIMIT_MAX_REQUESTS": ("limit_max_requests", lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS")),
|
||||
"UVICORN_TIMEOUT_KEEP_ALIVE": ("timeout_keep_alive", lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE")),
|
||||
"UVICORN_TIMEOUT_NOTIFY": ("timeout_notify", lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY")),
|
||||
"UVICORN_LIMIT_CONCURRENCY": (
|
||||
"limit_concurrency",
|
||||
lambda x: _parse_int(x, "UVICORN_LIMIT_CONCURRENCY"),
|
||||
),
|
||||
"UVICORN_LIMIT_MAX_REQUESTS": (
|
||||
"limit_max_requests",
|
||||
lambda x: _parse_int(x, "UVICORN_LIMIT_MAX_REQUESTS"),
|
||||
),
|
||||
"UVICORN_TIMEOUT_KEEP_ALIVE": (
|
||||
"timeout_keep_alive",
|
||||
lambda x: _parse_int(x, "UVICORN_TIMEOUT_KEEP_ALIVE"),
|
||||
),
|
||||
"UVICORN_TIMEOUT_NOTIFY": (
|
||||
"timeout_notify",
|
||||
lambda x: _parse_int(x, "UVICORN_TIMEOUT_NOTIFY"),
|
||||
),
|
||||
"UVICORN_SSL_KEYFILE": ("ssl_keyfile", str),
|
||||
"UVICORN_SSL_CERTFILE": ("ssl_certfile", str),
|
||||
"UVICORN_SSL_KEYFILE_PASSWORD": ("ssl_keyfile_password", str),
|
||||
"UVICORN_SSL_VERSION": ("ssl_version", lambda x: _parse_int(x, "UVICORN_SSL_VERSION")),
|
||||
"UVICORN_SSL_CERT_REQS": ("ssl_cert_reqs", lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS")),
|
||||
"UVICORN_SSL_VERSION": (
|
||||
"ssl_version",
|
||||
lambda x: _parse_int(x, "UVICORN_SSL_VERSION"),
|
||||
),
|
||||
"UVICORN_SSL_CERT_REQS": (
|
||||
"ssl_cert_reqs",
|
||||
lambda x: _parse_int(x, "UVICORN_SSL_CERT_REQS"),
|
||||
),
|
||||
"UVICORN_SSL_CA_CERTS": ("ssl_ca_certs", str),
|
||||
"UVICORN_SSL_CIPHERS": ("ssl_ciphers", str),
|
||||
"UVICORN_HEADERS": ("headers", _parse_headers),
|
||||
|
|
@ -76,7 +93,6 @@ def load_uvicorn_config(args=None):
|
|||
try:
|
||||
config_kwargs[config_key] = parser(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Configuration error for {env_var}: {e}")
|
||||
|
||||
raise ValueError(f"Configuration error for {env_var}: {e}") from e
|
||||
|
||||
return config_kwargs
|
||||
|
|
|
|||
|
|
@ -6,11 +6,12 @@ A module for interacting with Discord's HTTP API to retrieve guilds, channels, a
|
|||
Requires a Discord bot token.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||
class DiscordConnector(commands.Bot):
|
||||
"""Class for retrieving guild, channel, and message history from Discord."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
def __init__(self, token: str | None = None):
|
||||
"""
|
||||
Initialize the DiscordConnector with a bot token.
|
||||
|
||||
|
|
@ -30,7 +31,9 @@ class DiscordConnector(commands.Bot):
|
|||
intents.messages = True # Required to fetch messages
|
||||
intents.message_content = True # Required to read message content
|
||||
intents.members = True # Required to fetch member information
|
||||
super().__init__(command_prefix="!", intents=intents) # command_prefix is required but not strictly used here
|
||||
super().__init__(
|
||||
command_prefix="!", intents=intents
|
||||
) # command_prefix is required but not strictly used here
|
||||
self.token = token
|
||||
self._bot_task = None # Holds the async bot task
|
||||
self._is_running = False # Flag to track if the bot is running
|
||||
|
|
@ -48,7 +51,7 @@ class DiscordConnector(commands.Bot):
|
|||
@self.event
|
||||
async def on_disconnect():
|
||||
logger.debug("Bot disconnected from Discord gateway.")
|
||||
self._is_running = False # Reset flag on disconnect
|
||||
self._is_running = False # Reset flag on disconnect
|
||||
|
||||
@self.event
|
||||
async def on_resumed():
|
||||
|
|
@ -63,17 +66,23 @@ class DiscordConnector(commands.Bot):
|
|||
|
||||
try:
|
||||
if self._is_running:
|
||||
logger.warning("Bot is already running. Use close_bot() to stop it before starting again.")
|
||||
logger.warning(
|
||||
"Bot is already running. Use close_bot() to stop it before starting again."
|
||||
)
|
||||
return
|
||||
|
||||
await self.start(self.token)
|
||||
logger.info("Discord bot started successfully.")
|
||||
except discord.LoginFailure:
|
||||
logger.error("Failed to log in: Invalid token was provided. Please check your bot token.")
|
||||
logger.error(
|
||||
"Failed to log in: Invalid token was provided. Please check your bot token."
|
||||
)
|
||||
self._is_running = False
|
||||
raise
|
||||
except discord.PrivilegedIntentsRequired as e:
|
||||
logger.error(f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page.")
|
||||
logger.error(
|
||||
f"Privileged Intents Required: {e}. Make sure all required intents are enabled in your bot's application page."
|
||||
)
|
||||
self._is_running = False
|
||||
raise
|
||||
except discord.ConnectionClosed as e:
|
||||
|
|
@ -96,7 +105,6 @@ class DiscordConnector(commands.Bot):
|
|||
else:
|
||||
logger.info("Bot is not running or already disconnected.")
|
||||
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the discord bot token.
|
||||
|
|
@ -106,8 +114,10 @@ class DiscordConnector(commands.Bot):
|
|||
"""
|
||||
logger.info("Setting Discord bot token.")
|
||||
self.token = token
|
||||
logger.info("Token set successfully. You can now start the bot with start_bot().")
|
||||
|
||||
logger.info(
|
||||
"Token set successfully. You can now start the bot with start_bot()."
|
||||
)
|
||||
|
||||
async def _wait_until_ready(self):
|
||||
"""Helper to wait until the bot is connected and ready."""
|
||||
logger.info("Waiting for the bot to be ready...")
|
||||
|
|
@ -115,16 +125,20 @@ class DiscordConnector(commands.Bot):
|
|||
# Give the event loop a chance to switch to the bot's startup task.
|
||||
# This allows self.start() to begin initializing the client.
|
||||
# Terrible solution, but necessary to avoid blocking the event loop.
|
||||
await asyncio.sleep(1) # Yield control to the event loop
|
||||
|
||||
await asyncio.sleep(1) # Yield control to the event loop
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.wait_until_ready(), timeout=60.0)
|
||||
logger.info("Bot is ready.")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Bot did not become ready within 60 seconds. Connection may have failed.")
|
||||
except TimeoutError:
|
||||
logger.error(
|
||||
"Bot did not become ready within 60 seconds. Connection may have failed."
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred while waiting for the bot to be ready: {e}")
|
||||
logger.error(
|
||||
f"An unexpected error occurred while waiting for the bot to be ready: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_guilds(self) -> list[dict]:
|
||||
|
|
@ -143,7 +157,9 @@ class DiscordConnector(commands.Bot):
|
|||
|
||||
guilds_data = []
|
||||
for guild in self.guilds:
|
||||
member_count = guild.member_count if guild.member_count is not None else "N/A"
|
||||
member_count = (
|
||||
guild.member_count if guild.member_count is not None else "N/A"
|
||||
)
|
||||
guilds_data.append(
|
||||
{
|
||||
"id": str(guild.id),
|
||||
|
|
@ -183,15 +199,17 @@ class DiscordConnector(commands.Bot):
|
|||
channels_data.append(
|
||||
{"id": str(channel.id), "name": channel.name, "type": "text"}
|
||||
)
|
||||
|
||||
logger.info(f"Fetched {len(channels_data)} text channels from guild {guild_id}.")
|
||||
|
||||
logger.info(
|
||||
f"Fetched {len(channels_data)} text channels from guild {guild_id}."
|
||||
)
|
||||
return channels_data
|
||||
|
||||
async def get_channel_history(
|
||||
self,
|
||||
channel_id: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Fetch message history from a text channel.
|
||||
|
|
@ -227,20 +245,26 @@ class DiscordConnector(commands.Bot):
|
|||
|
||||
if start_date:
|
||||
try:
|
||||
start_datetime = datetime.datetime.fromisoformat(start_date).replace(tzinfo=datetime.timezone.utc)
|
||||
start_datetime = datetime.datetime.fromisoformat(start_date).replace(
|
||||
tzinfo=datetime.UTC
|
||||
)
|
||||
after = start_datetime
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid start_date format: {start_date}. Ignoring.")
|
||||
|
||||
if end_date:
|
||||
try:
|
||||
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(tzinfo=datetime.timezone.utc)
|
||||
end_datetime = datetime.datetime.fromisoformat(f"{end_date}").replace(
|
||||
tzinfo=datetime.UTC
|
||||
)
|
||||
before = end_datetime
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid end_date format: {end_date}. Ignoring.")
|
||||
|
||||
try:
|
||||
async for message in channel.history(limit=None, before=before, after=after):
|
||||
async for message in channel.history(
|
||||
limit=None, before=before, after=after
|
||||
):
|
||||
messages_data.append(
|
||||
{
|
||||
"id": str(message.id),
|
||||
|
|
@ -251,12 +275,14 @@ class DiscordConnector(commands.Bot):
|
|||
}
|
||||
)
|
||||
except discord.Forbidden:
|
||||
logger.error(f"Bot does not have permissions to read message history in channel {channel_id}.")
|
||||
logger.error(
|
||||
f"Bot does not have permissions to read message history in channel {channel_id}."
|
||||
)
|
||||
raise
|
||||
except discord.HTTPException as e:
|
||||
logger.error(f"Failed to fetch messages from channel {channel_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
logger.info(f"Fetched {len(messages_data)} messages from channel {channel_id}.")
|
||||
return messages_data
|
||||
|
||||
|
|
@ -278,7 +304,9 @@ class DiscordConnector(commands.Bot):
|
|||
permissions to view members.
|
||||
"""
|
||||
await self._wait_until_ready()
|
||||
logger.info(f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}")
|
||||
logger.info(
|
||||
f"Fetching user info for user ID: {user_id} in guild ID: {guild_id}"
|
||||
)
|
||||
|
||||
guild = self.get_guild(int(guild_id))
|
||||
if not guild:
|
||||
|
|
@ -294,7 +322,9 @@ class DiscordConnector(commands.Bot):
|
|||
return {
|
||||
"id": str(member.id),
|
||||
"name": member.name,
|
||||
"joined_at": member.joined_at.isoformat() if member.joined_at else None,
|
||||
"joined_at": member.joined_at.isoformat()
|
||||
if member.joined_at
|
||||
else None,
|
||||
"roles": roles,
|
||||
}
|
||||
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
||||
|
|
@ -303,8 +333,12 @@ class DiscordConnector(commands.Bot):
|
|||
logger.warning(f"User {user_id} not found in guild {guild_id}.")
|
||||
return None
|
||||
except discord.Forbidden:
|
||||
logger.error(f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled.")
|
||||
logger.error(
|
||||
f"Bot does not have permissions to fetch members in guild {guild_id}. Ensure GUILD_MEMBERS intent is enabled."
|
||||
)
|
||||
raise
|
||||
except discord.HTTPException as e:
|
||||
logger.error(f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}")
|
||||
logger.error(
|
||||
f"Failed to fetch user info for {user_id} in guild {guild_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,54 +1,91 @@
|
|||
import base64
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from github3 import login as github_login, exceptions as github_exceptions
|
||||
from github3.repos.contents import Contents
|
||||
from typing import Any
|
||||
|
||||
from github3 import exceptions as github_exceptions, login as github_login
|
||||
from github3.exceptions import ForbiddenError, NotFoundError
|
||||
from github3.repos.contents import Contents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of common code file extensions to target
|
||||
CODE_EXTENSIONS = {
|
||||
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
|
||||
'.cs', '.go', '.rb', '.php', '.swift', '.kt', '.scala', '.rs', '.m',
|
||||
'.sh', '.bash', '.ps1', '.lua', '.pl', '.pm', '.r', '.dart', '.sql'
|
||||
".py",
|
||||
".js",
|
||||
".jsx",
|
||||
".ts",
|
||||
".tsx",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".rs",
|
||||
".m",
|
||||
".sh",
|
||||
".bash",
|
||||
".ps1",
|
||||
".lua",
|
||||
".pl",
|
||||
".pm",
|
||||
".r",
|
||||
".dart",
|
||||
".sql",
|
||||
}
|
||||
|
||||
# List of common documentation/text file extensions
|
||||
DOC_EXTENSIONS = {
|
||||
'.md', '.txt', '.rst', '.adoc', '.html', '.htm', '.xml', '.json', '.yaml', '.yml', '.toml'
|
||||
".md",
|
||||
".txt",
|
||||
".rst",
|
||||
".adoc",
|
||||
".html",
|
||||
".htm",
|
||||
".xml",
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
}
|
||||
|
||||
# Maximum file size in bytes (e.g., 1MB)
|
||||
MAX_FILE_SIZE = 1 * 1024 * 1024
|
||||
|
||||
|
||||
class GitHubConnector:
|
||||
"""Connector for interacting with the GitHub API."""
|
||||
|
||||
# Directories to skip during file traversal
|
||||
SKIPPED_DIRS = {
|
||||
# Version control
|
||||
'.git',
|
||||
".git",
|
||||
# Dependencies
|
||||
'node_modules',
|
||||
'vendor',
|
||||
"node_modules",
|
||||
"vendor",
|
||||
# Build artifacts / Caches
|
||||
'build',
|
||||
'dist',
|
||||
'target',
|
||||
'__pycache__',
|
||||
"build",
|
||||
"dist",
|
||||
"target",
|
||||
"__pycache__",
|
||||
# Virtual environments
|
||||
'venv',
|
||||
'.venv',
|
||||
'env',
|
||||
"venv",
|
||||
".venv",
|
||||
"env",
|
||||
# IDE/Editor config
|
||||
'.vscode',
|
||||
'.idea',
|
||||
'.project',
|
||||
'.settings',
|
||||
".vscode",
|
||||
".idea",
|
||||
".project",
|
||||
".settings",
|
||||
# Temporary / Logs
|
||||
'tmp',
|
||||
'logs',
|
||||
"tmp",
|
||||
"logs",
|
||||
# Add other project-specific irrelevant directories if needed
|
||||
}
|
||||
|
||||
|
|
@ -68,35 +105,39 @@ class GitHubConnector:
|
|||
logger.info("Successfully authenticated with GitHub API.")
|
||||
except (github_exceptions.AuthenticationFailed, ForbiddenError) as e:
|
||||
logger.error(f"GitHub authentication failed: {e}")
|
||||
raise ValueError("Invalid GitHub token or insufficient permissions.")
|
||||
raise ValueError("Invalid GitHub token or insufficient permissions.") from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize GitHub client: {e}")
|
||||
raise
|
||||
raise e
|
||||
|
||||
def get_user_repositories(self) -> List[Dict[str, Any]]:
|
||||
def get_user_repositories(self) -> list[dict[str, Any]]:
|
||||
"""Fetches repositories accessible by the authenticated user."""
|
||||
repos_data = []
|
||||
try:
|
||||
# type='owner' fetches repos owned by the user
|
||||
# type='member' fetches repos the user is a collaborator on (including orgs)
|
||||
# type='all' fetches both
|
||||
for repo in self.gh.repositories(type='all', sort='updated'):
|
||||
repos_data.append({
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"full_name": repo.full_name,
|
||||
"private": repo.private,
|
||||
"url": repo.html_url,
|
||||
"description": repo.description or "",
|
||||
"last_updated": repo.updated_at if repo.updated_at else None,
|
||||
})
|
||||
for repo in self.gh.repositories(type="all", sort="updated"):
|
||||
repos_data.append(
|
||||
{
|
||||
"id": repo.id,
|
||||
"name": repo.name,
|
||||
"full_name": repo.full_name,
|
||||
"private": repo.private,
|
||||
"url": repo.html_url,
|
||||
"description": repo.description or "",
|
||||
"last_updated": repo.updated_at if repo.updated_at else None,
|
||||
}
|
||||
)
|
||||
logger.info(f"Fetched {len(repos_data)} repositories.")
|
||||
return repos_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch GitHub repositories: {e}")
|
||||
return [] # Return empty list on error
|
||||
return [] # Return empty list on error
|
||||
|
||||
def get_repository_files(self, repo_full_name: str, path: str = '') -> List[Dict[str, Any]]:
|
||||
def get_repository_files(
|
||||
self, repo_full_name: str, path: str = ""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Recursively fetches details of relevant files (code, docs) within a repository path.
|
||||
|
||||
|
|
@ -110,54 +151,72 @@ class GitHubConnector:
|
|||
"""
|
||||
files_list = []
|
||||
try:
|
||||
owner, repo_name = repo_full_name.split('/')
|
||||
owner, repo_name = repo_full_name.split("/")
|
||||
repo = self.gh.repository(owner, repo_name)
|
||||
if not repo:
|
||||
logger.warning(f"Repository '{repo_full_name}' not found.")
|
||||
return []
|
||||
contents = repo.directory_contents(directory_path=path) # Use directory_contents for clarity
|
||||
|
||||
contents = repo.directory_contents(
|
||||
directory_path=path
|
||||
) # Use directory_contents for clarity
|
||||
|
||||
# contents returns a list of tuples (name, content_obj)
|
||||
for item_name, content_item in contents:
|
||||
for _item_name, content_item in contents:
|
||||
if not isinstance(content_item, Contents):
|
||||
continue
|
||||
|
||||
if content_item.type == 'dir':
|
||||
if content_item.type == "dir":
|
||||
# Check if the directory name is in the skipped list
|
||||
if content_item.name in self.SKIPPED_DIRS:
|
||||
logger.debug(f"Skipping directory: {content_item.path}")
|
||||
continue # Skip recursion for this directory
|
||||
|
||||
continue # Skip recursion for this directory
|
||||
|
||||
# Recursively fetch contents of subdirectory
|
||||
files_list.extend(self.get_repository_files(repo_full_name, path=content_item.path))
|
||||
elif content_item.type == 'file':
|
||||
files_list.extend(
|
||||
self.get_repository_files(
|
||||
repo_full_name, path=content_item.path
|
||||
)
|
||||
)
|
||||
elif content_item.type == "file":
|
||||
# Check if the file extension is relevant and size is within limits
|
||||
file_extension = '.' + content_item.name.split('.')[-1].lower() if '.' in content_item.name else ''
|
||||
file_extension = (
|
||||
"." + content_item.name.split(".")[-1].lower()
|
||||
if "." in content_item.name
|
||||
else ""
|
||||
)
|
||||
is_code = file_extension in CODE_EXTENSIONS
|
||||
is_doc = file_extension in DOC_EXTENSIONS
|
||||
|
||||
|
||||
if (is_code or is_doc) and content_item.size <= MAX_FILE_SIZE:
|
||||
files_list.append({
|
||||
"path": content_item.path,
|
||||
"sha": content_item.sha,
|
||||
"url": content_item.html_url,
|
||||
"size": content_item.size,
|
||||
"type": "code" if is_code else "doc"
|
||||
})
|
||||
files_list.append(
|
||||
{
|
||||
"path": content_item.path,
|
||||
"sha": content_item.sha,
|
||||
"url": content_item.html_url,
|
||||
"size": content_item.size,
|
||||
"type": "code" if is_code else "doc",
|
||||
}
|
||||
)
|
||||
elif content_item.size > MAX_FILE_SIZE:
|
||||
logger.debug(f"Skipping large file: {content_item.path} ({content_item.size} bytes)")
|
||||
logger.debug(
|
||||
f"Skipping large file: {content_item.path} ({content_item.size} bytes)"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Skipping irrelevant file type: {content_item.path}")
|
||||
logger.debug(
|
||||
f"Skipping irrelevant file type: {content_item.path}"
|
||||
)
|
||||
|
||||
except (NotFoundError, ForbiddenError) as e:
|
||||
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
|
||||
logger.warning(f"Cannot access path '{path}' in '{repo_full_name}': {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get files for {repo_full_name} at path '{path}': {e}")
|
||||
logger.error(
|
||||
f"Failed to get files for {repo_full_name} at path '{path}': {e}"
|
||||
)
|
||||
# Return what we have collected so far in case of partial failure
|
||||
|
||||
|
||||
return files_list
|
||||
|
||||
def get_file_content(self, repo_full_name: str, file_path: str) -> Optional[str]:
|
||||
def get_file_content(self, repo_full_name: str, file_path: str) -> str | None:
|
||||
"""
|
||||
Fetches the decoded content of a specific file.
|
||||
|
||||
|
|
@ -169,43 +228,69 @@ class GitHubConnector:
|
|||
The decoded file content as a string, or None if fetching fails or file is too large.
|
||||
"""
|
||||
try:
|
||||
owner, repo_name = repo_full_name.split('/')
|
||||
owner, repo_name = repo_full_name.split("/")
|
||||
repo = self.gh.repository(owner, repo_name)
|
||||
if not repo:
|
||||
logger.warning(f"Repository '{repo_full_name}' not found when fetching file '{file_path}'.")
|
||||
logger.warning(
|
||||
f"Repository '{repo_full_name}' not found when fetching file '{file_path}'."
|
||||
)
|
||||
return None
|
||||
|
||||
content_item = repo.file_contents(path=file_path) # Use file_contents for clarity
|
||||
|
||||
if not content_item or not isinstance(content_item, Contents) or content_item.type != 'file':
|
||||
logger.warning(f"File '{file_path}' not found or is not a file in '{repo_full_name}'.")
|
||||
content_item = repo.file_contents(
|
||||
path=file_path
|
||||
) # Use file_contents for clarity
|
||||
|
||||
if (
|
||||
not content_item
|
||||
or not isinstance(content_item, Contents)
|
||||
or content_item.type != "file"
|
||||
):
|
||||
logger.warning(
|
||||
f"File '{file_path}' not found or is not a file in '{repo_full_name}'."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
if content_item.size > MAX_FILE_SIZE:
|
||||
logger.warning(f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch.")
|
||||
logger.warning(
|
||||
f"File '{file_path}' in '{repo_full_name}' exceeds max size ({content_item.size} > {MAX_FILE_SIZE}). Skipping content fetch."
|
||||
)
|
||||
return None
|
||||
|
||||
# Content is base64 encoded
|
||||
if content_item.content:
|
||||
try:
|
||||
decoded_content = base64.b64decode(content_item.content).decode('utf-8')
|
||||
decoded_content = base64.b64decode(content_item.content).decode(
|
||||
"utf-8"
|
||||
)
|
||||
return decoded_content
|
||||
except UnicodeDecodeError:
|
||||
logger.warning(f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'.")
|
||||
logger.warning(
|
||||
f"Could not decode file '{file_path}' in '{repo_full_name}' as UTF-8. Trying with 'latin-1'."
|
||||
)
|
||||
try:
|
||||
# Try a fallback encoding
|
||||
decoded_content = base64.b64decode(content_item.content).decode('latin-1')
|
||||
decoded_content = base64.b64decode(content_item.content).decode(
|
||||
"latin-1"
|
||||
)
|
||||
return decoded_content
|
||||
except Exception as decode_err:
|
||||
logger.error(f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}")
|
||||
return None # Give up if fallback fails
|
||||
logger.error(
|
||||
f"Failed to decode file '{file_path}' with fallback encoding: {decode_err}"
|
||||
)
|
||||
return None # Give up if fallback fails
|
||||
else:
|
||||
logger.warning(f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty.")
|
||||
return "" # Return empty string for empty files
|
||||
logger.warning(
|
||||
f"No content returned for file '{file_path}' in '{repo_full_name}'. It might be empty."
|
||||
)
|
||||
return "" # Return empty string for empty files
|
||||
|
||||
except (NotFoundError, ForbiddenError) as e:
|
||||
logger.warning(f"Cannot access file '{file_path}' in '{repo_full_name}': {e}")
|
||||
return None
|
||||
logger.warning(
|
||||
f"Cannot access file '{file_path}' in '{repo_full_name}': {e}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}")
|
||||
return None
|
||||
logger.error(
|
||||
f"Failed to get content for file '{file_path}' in '{repo_full_name}': {e}"
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
487
surfsense_backend/app/connectors/jira_connector.py
Normal file
487
surfsense_backend/app/connectors/jira_connector.py
Normal file
|
|
@ -0,0 +1,487 @@
|
|||
"""
|
||||
Jira Connector Module
|
||||
|
||||
A module for retrieving data from Jira.
|
||||
Allows fetching issue lists and their comments, projects and more.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class JiraConnector:
|
||||
"""Class for retrieving data from Jira."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
email: str | None = None,
|
||||
api_token: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the JiraConnector class.
|
||||
|
||||
Args:
|
||||
base_url: Jira instance base URL (e.g., 'https://yourcompany.atlassian.net') (optional)
|
||||
email: Jira account email address (optional)
|
||||
api_token: Jira API token (optional)
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/") if base_url else None
|
||||
self.email = email
|
||||
self.api_token = api_token
|
||||
self.api_version = "3" # Jira Cloud API version
|
||||
|
||||
def set_credentials(self, base_url: str, email: str, api_token: str) -> None:
|
||||
"""
|
||||
Set the Jira credentials.
|
||||
|
||||
Args:
|
||||
base_url: Jira instance base URL
|
||||
email: Jira account email address
|
||||
api_token: Jira API token
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.email = email
|
||||
self.api_token = api_token
|
||||
|
||||
def set_email(self, email: str) -> None:
|
||||
"""
|
||||
Set the Jira account email.
|
||||
|
||||
Args:
|
||||
email: Jira account email address
|
||||
"""
|
||||
self.email = email
|
||||
|
||||
def set_api_token(self, api_token: str) -> None:
|
||||
"""
|
||||
Set the Jira API token.
|
||||
|
||||
Args:
|
||||
api_token: Jira API token
|
||||
"""
|
||||
self.api_token = api_token
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""
|
||||
Get headers for Jira API requests using Basic Authentication.
|
||||
|
||||
Returns:
|
||||
Dictionary of headers
|
||||
|
||||
Raises:
|
||||
ValueError: If email, api_token, or base_url have not been set
|
||||
"""
|
||||
if not all([self.base_url, self.email, self.api_token]):
|
||||
raise ValueError(
|
||||
"Jira credentials not initialized. Call set_credentials() first."
|
||||
)
|
||||
|
||||
# Create Basic Auth header using email:api_token
|
||||
auth_str = f"{self.email}:{self.api_token}"
|
||||
auth_bytes = auth_str.encode("utf-8")
|
||||
auth_header = "Basic " + base64.b64encode(auth_bytes).decode("ascii")
|
||||
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": auth_header,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def make_api_request(
|
||||
self, endpoint: str, params: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Make a request to the Jira API.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint (without base URL)
|
||||
params: Query parameters for the request (optional)
|
||||
|
||||
Returns:
|
||||
Response data from the API
|
||||
|
||||
Raises:
|
||||
ValueError: If email, api_token, or base_url have not been set
|
||||
Exception: If the API request fails
|
||||
"""
|
||||
if not all([self.base_url, self.email, self.api_token]):
|
||||
raise ValueError(
|
||||
"Jira credentials not initialized. Call set_credentials() first."
|
||||
)
|
||||
|
||||
url = f"{self.base_url}/rest/api/{self.api_version}/{endpoint}"
|
||||
headers = self.get_headers()
|
||||
|
||||
response = requests.get(url, headers=headers, params=params, timeout=500)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(
|
||||
f"API request failed with status code {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
def get_all_projects(self) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch all projects from Jira.
|
||||
|
||||
Returns:
|
||||
List of project objects
|
||||
|
||||
Raises:
|
||||
ValueError: If credentials have not been set
|
||||
Exception: If the API request fails
|
||||
"""
|
||||
return self.make_api_request("project/search")
|
||||
|
||||
def get_all_issues(self, project_key: str | None = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all issues from Jira.
|
||||
|
||||
Args:
|
||||
project_key: Optional project key to filter issues (e.g., 'PROJ')
|
||||
|
||||
Returns:
|
||||
List of issue objects
|
||||
|
||||
Raises:
|
||||
ValueError: If credentials have not been set
|
||||
Exception: If the API request fails
|
||||
"""
|
||||
jql = "ORDER BY created DESC"
|
||||
if project_key:
|
||||
jql = f'project = "{project_key}" ' + jql
|
||||
|
||||
fields = [
|
||||
"summary",
|
||||
"description",
|
||||
"status",
|
||||
"assignee",
|
||||
"reporter",
|
||||
"created",
|
||||
"updated",
|
||||
"priority",
|
||||
"issuetype",
|
||||
"project",
|
||||
]
|
||||
|
||||
params = {
|
||||
"jql": jql,
|
||||
"fields": ",".join(fields),
|
||||
"maxResults": 100,
|
||||
"startAt": 0,
|
||||
}
|
||||
|
||||
all_issues = []
|
||||
start_at = 0
|
||||
|
||||
while True:
|
||||
params["startAt"] = start_at
|
||||
result = self.make_api_request("search", params)
|
||||
|
||||
if not isinstance(result, dict) or "issues" not in result:
|
||||
raise Exception("Invalid response from Jira API")
|
||||
|
||||
issues = result["issues"]
|
||||
all_issues.extend(issues)
|
||||
|
||||
print(f"Fetched {len(issues)} issues (startAt={start_at})")
|
||||
|
||||
total = result.get("total", 0)
|
||||
if start_at + len(issues) >= total:
|
||||
break
|
||||
|
||||
start_at += len(issues)
|
||||
|
||||
return all_issues
|
||||
|
||||
def get_issues_by_date_range(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
include_comments: bool = True,
|
||||
project_key: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
Fetch issues within a date range.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format (inclusive)
|
||||
include_comments: Whether to include comments in the response
|
||||
project_key: Optional project key to filter issues
|
||||
|
||||
Returns:
|
||||
Tuple containing (issues list, error message or None)
|
||||
"""
|
||||
try:
|
||||
# Build JQL query for date range
|
||||
# Query issues that were either created OR updated within the date range
|
||||
date_filter = (
|
||||
f"(createdDate >= '{start_date}' AND createdDate <= '{end_date}')"
|
||||
)
|
||||
# TODO : This JQL needs some improvement to work as expected
|
||||
|
||||
_jql = f"{date_filter}"
|
||||
if project_key:
|
||||
_jql = (
|
||||
f'project = "{project_key}" AND {date_filter} ORDER BY created DESC'
|
||||
)
|
||||
|
||||
# Define fields to retrieve
|
||||
fields = [
|
||||
"summary",
|
||||
"description",
|
||||
"status",
|
||||
"assignee",
|
||||
"reporter",
|
||||
"created",
|
||||
"updated",
|
||||
"priority",
|
||||
"issuetype",
|
||||
"project",
|
||||
]
|
||||
|
||||
if include_comments:
|
||||
fields.append("comment")
|
||||
|
||||
params = {
|
||||
# "jql": "", TODO : Add a JQL query to filter from a date range
|
||||
"fields": ",".join(fields),
|
||||
"maxResults": 100,
|
||||
"startAt": 0,
|
||||
}
|
||||
|
||||
all_issues = []
|
||||
start_at = 0
|
||||
|
||||
while True:
|
||||
params["startAt"] = start_at
|
||||
|
||||
result = self.make_api_request("search", params)
|
||||
|
||||
if not isinstance(result, dict) or "issues" not in result:
|
||||
return [], "Invalid response from Jira API"
|
||||
|
||||
issues = result["issues"]
|
||||
all_issues.extend(issues)
|
||||
|
||||
# Check if there are more issues to fetch
|
||||
total = result.get("total", 0)
|
||||
if start_at + len(issues) >= total:
|
||||
break
|
||||
|
||||
start_at += len(issues)
|
||||
|
||||
if not all_issues:
|
||||
return [], "No issues found in the specified date range."
|
||||
|
||||
return all_issues, None
|
||||
|
||||
except Exception as e:
|
||||
return [], f"Error fetching issues: {e!s}"
|
||||
|
||||
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format an issue for easier consumption.
|
||||
|
||||
Args:
|
||||
issue: The issue object from Jira API
|
||||
|
||||
Returns:
|
||||
Formatted issue dictionary
|
||||
"""
|
||||
fields = issue.get("fields", {})
|
||||
|
||||
# Extract basic issue details
|
||||
formatted = {
|
||||
"id": issue.get("id", ""),
|
||||
"key": issue.get("key", ""),
|
||||
"title": fields.get("summary", ""),
|
||||
"description": fields.get("description", ""),
|
||||
"status": (
|
||||
fields.get("status", {}).get("name", "Unknown")
|
||||
if fields.get("status")
|
||||
else "Unknown"
|
||||
),
|
||||
"status_category": (
|
||||
fields.get("status", {})
|
||||
.get("statusCategory", {})
|
||||
.get("name", "Unknown")
|
||||
if fields.get("status")
|
||||
else "Unknown"
|
||||
),
|
||||
"priority": (
|
||||
fields.get("priority", {}).get("name", "Unknown")
|
||||
if fields.get("priority")
|
||||
else "Unknown"
|
||||
),
|
||||
"issue_type": (
|
||||
fields.get("issuetype", {}).get("name", "Unknown")
|
||||
if fields.get("issuetype")
|
||||
else "Unknown"
|
||||
),
|
||||
"project": (
|
||||
fields.get("project", {}).get("key", "Unknown")
|
||||
if fields.get("project")
|
||||
else "Unknown"
|
||||
),
|
||||
"created_at": fields.get("created", ""),
|
||||
"updated_at": fields.get("updated", ""),
|
||||
"reporter": (
|
||||
{
|
||||
"account_id": (
|
||||
fields.get("reporter", {}).get("accountId", "")
|
||||
if fields.get("reporter")
|
||||
else ""
|
||||
),
|
||||
"display_name": (
|
||||
fields.get("reporter", {}).get("displayName", "Unknown")
|
||||
if fields.get("reporter")
|
||||
else "Unknown"
|
||||
),
|
||||
"email": (
|
||||
fields.get("reporter", {}).get("emailAddress", "")
|
||||
if fields.get("reporter")
|
||||
else ""
|
||||
),
|
||||
}
|
||||
if fields.get("reporter")
|
||||
else {"account_id": "", "display_name": "Unknown", "email": ""}
|
||||
),
|
||||
"assignee": (
|
||||
{
|
||||
"account_id": fields.get("assignee", {}).get("accountId", ""),
|
||||
"display_name": fields.get("assignee", {}).get(
|
||||
"displayName", "Unknown"
|
||||
),
|
||||
"email": fields.get("assignee", {}).get("emailAddress", ""),
|
||||
}
|
||||
if fields.get("assignee")
|
||||
else None
|
||||
),
|
||||
"comments": [],
|
||||
}
|
||||
|
||||
# Extract comments if available
|
||||
if "comment" in fields and "comments" in fields["comment"]:
|
||||
for comment in fields["comment"]["comments"]:
|
||||
formatted_comment = {
|
||||
"id": comment.get("id", ""),
|
||||
"body": comment.get("body", ""),
|
||||
"created_at": comment.get("created", ""),
|
||||
"updated_at": comment.get("updated", ""),
|
||||
"author": (
|
||||
{
|
||||
"account_id": (
|
||||
comment.get("author", {}).get("accountId", "")
|
||||
if comment.get("author")
|
||||
else ""
|
||||
),
|
||||
"display_name": (
|
||||
comment.get("author", {}).get("displayName", "Unknown")
|
||||
if comment.get("author")
|
||||
else "Unknown"
|
||||
),
|
||||
"email": (
|
||||
comment.get("author", {}).get("emailAddress", "")
|
||||
if comment.get("author")
|
||||
else ""
|
||||
),
|
||||
}
|
||||
if comment.get("author")
|
||||
else {"account_id": "", "display_name": "Unknown", "email": ""}
|
||||
),
|
||||
}
|
||||
formatted["comments"].append(formatted_comment)
|
||||
|
||||
return formatted
|
||||
|
||||
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
|
||||
"""
|
||||
Convert an issue to markdown format.
|
||||
|
||||
Args:
|
||||
issue: The issue object (either raw or formatted)
|
||||
|
||||
Returns:
|
||||
Markdown string representation of the issue
|
||||
"""
|
||||
# Format the issue if it's not already formatted
|
||||
if "key" not in issue:
|
||||
issue = self.format_issue(issue)
|
||||
|
||||
# Build the markdown content
|
||||
markdown = (
|
||||
f"# {issue.get('key', 'No Key')}: {issue.get('title', 'No Title')}\n\n"
|
||||
)
|
||||
|
||||
if issue.get("status"):
|
||||
markdown += f"**Status:** {issue['status']}\n"
|
||||
|
||||
if issue.get("priority"):
|
||||
markdown += f"**Priority:** {issue['priority']}\n"
|
||||
|
||||
if issue.get("issue_type"):
|
||||
markdown += f"**Type:** {issue['issue_type']}\n"
|
||||
|
||||
if issue.get("project"):
|
||||
markdown += f"**Project:** {issue['project']}\n\n"
|
||||
|
||||
if issue.get("assignee") and issue["assignee"].get("display_name"):
|
||||
markdown += f"**Assignee:** {issue['assignee']['display_name']}\n"
|
||||
|
||||
if issue.get("reporter") and issue["reporter"].get("display_name"):
|
||||
markdown += f"**Reporter:** {issue['reporter']['display_name']}\n"
|
||||
|
||||
if issue.get("created_at"):
|
||||
created_date = self.format_date(issue["created_at"])
|
||||
markdown += f"**Created:** {created_date}\n"
|
||||
|
||||
if issue.get("updated_at"):
|
||||
updated_date = self.format_date(issue["updated_at"])
|
||||
markdown += f"**Updated:** {updated_date}\n\n"
|
||||
|
||||
if issue.get("description"):
|
||||
markdown += f"## Description\n\n{issue['description']}\n\n"
|
||||
|
||||
if issue.get("comments"):
|
||||
markdown += f"## Comments ({len(issue['comments'])})\n\n"
|
||||
|
||||
for comment in issue["comments"]:
|
||||
author_name = "Unknown"
|
||||
if comment.get("author") and comment["author"].get("display_name"):
|
||||
author_name = comment["author"]["display_name"]
|
||||
|
||||
comment_date = "Unknown date"
|
||||
if comment.get("created_at"):
|
||||
comment_date = self.format_date(comment["created_at"])
|
||||
|
||||
markdown += f"### {author_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
|
||||
|
||||
return markdown
|
||||
|
||||
@staticmethod
|
||||
def format_date(iso_date: str) -> str:
|
||||
"""
|
||||
Format an ISO date string to a more readable format.
|
||||
|
||||
Args:
|
||||
iso_date: ISO format date string
|
||||
|
||||
Returns:
|
||||
Formatted date string
|
||||
"""
|
||||
if not iso_date or not isinstance(iso_date, str):
|
||||
return "Unknown date"
|
||||
|
||||
try:
|
||||
# Jira dates are typically in format: 2023-01-01T12:00:00.000+0000
|
||||
dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
return iso_date
|
||||
|
|
@ -5,96 +5,94 @@ A module for retrieving issues and comments from Linear.
|
|||
Allows fetching issue lists and their comments with date range filtering.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class LinearConnector:
|
||||
"""Class for retrieving issues and comments from Linear."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
|
||||
def __init__(self, token: str | None = None):
|
||||
"""
|
||||
Initialize the LinearConnector class.
|
||||
|
||||
|
||||
Args:
|
||||
token: Linear API token (optional, can be set later with set_token)
|
||||
"""
|
||||
self.token = token
|
||||
self.api_url = "https://api.linear.app/graphql"
|
||||
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the Linear API token.
|
||||
|
||||
|
||||
Args:
|
||||
token: Linear API token
|
||||
"""
|
||||
self.token = token
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""
|
||||
Get headers for Linear API requests.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary of headers
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Linear token has been set
|
||||
"""
|
||||
if not self.token:
|
||||
raise ValueError("Linear token not initialized. Call set_token() first.")
|
||||
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': self.token
|
||||
}
|
||||
|
||||
def execute_graphql_query(self, query: str, variables: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
|
||||
return {"Content-Type": "application/json", "Authorization": self.token}
|
||||
|
||||
def execute_graphql_query(
|
||||
self, query: str, variables: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a GraphQL query against the Linear API.
|
||||
|
||||
|
||||
Args:
|
||||
query: GraphQL query string
|
||||
variables: Variables for the GraphQL query (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
Response data from the API
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Linear token has been set
|
||||
Exception: If the API request fails
|
||||
"""
|
||||
if not self.token:
|
||||
raise ValueError("Linear token not initialized. Call set_token() first.")
|
||||
|
||||
|
||||
headers = self.get_headers()
|
||||
payload = {'query': query}
|
||||
|
||||
payload = {"query": query}
|
||||
|
||||
if variables:
|
||||
payload['variables'] = variables
|
||||
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
payload["variables"] = variables
|
||||
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(f"Query failed with status code {response.status_code}: {response.text}")
|
||||
|
||||
def get_all_issues(self, include_comments: bool = True) -> List[Dict[str, Any]]:
|
||||
raise Exception(
|
||||
f"Query failed with status code {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all issues from Linear.
|
||||
|
||||
|
||||
Args:
|
||||
include_comments: Whether to include comments in the response
|
||||
|
||||
|
||||
Returns:
|
||||
List of issue objects
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Linear token has been set
|
||||
Exception: If the API request fails
|
||||
|
|
@ -116,7 +114,7 @@ class LinearConnector:
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
query = f"""
|
||||
query {{
|
||||
issues {{
|
||||
|
|
@ -147,29 +145,30 @@ class LinearConnector:
|
|||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
result = self.execute_graphql_query(query)
|
||||
|
||||
|
||||
# Extract issues from the response
|
||||
if "data" in result and "issues" in result["data"] and "nodes" in result["data"]["issues"]:
|
||||
if (
|
||||
"data" in result
|
||||
and "issues" in result["data"]
|
||||
and "nodes" in result["data"]["issues"]
|
||||
):
|
||||
return result["data"]["issues"]["nodes"]
|
||||
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_issues_by_date_range(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
include_comments: bool = True
|
||||
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
self, start_date: str, end_date: str, include_comments: bool = True
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
Fetch issues within a date range.
|
||||
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format (inclusive)
|
||||
include_comments: Whether to include comments in the response
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple containing (issues list, error message or None)
|
||||
"""
|
||||
|
|
@ -194,7 +193,7 @@ class LinearConnector:
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# Query issues that were either created OR updated within the date range
|
||||
# This ensures we catch both new issues and updated existing issues
|
||||
query = f"""
|
||||
|
|
@ -250,58 +249,65 @@ class LinearConnector:
|
|||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
all_issues = []
|
||||
has_next_page = True
|
||||
cursor = None
|
||||
|
||||
|
||||
# Handle pagination to get all issues
|
||||
while has_next_page:
|
||||
variables = {"after": cursor} if cursor else {}
|
||||
result = self.execute_graphql_query(query, variables)
|
||||
|
||||
|
||||
# Check for errors
|
||||
if "errors" in result:
|
||||
error_message = "; ".join([error.get("message", "Unknown error") for error in result["errors"]])
|
||||
error_message = "; ".join(
|
||||
[
|
||||
error.get("message", "Unknown error")
|
||||
for error in result["errors"]
|
||||
]
|
||||
)
|
||||
return [], f"GraphQL errors: {error_message}"
|
||||
|
||||
|
||||
# Extract issues from the response
|
||||
if "data" in result and "issues" in result["data"]:
|
||||
issues_page = result["data"]["issues"]
|
||||
|
||||
|
||||
# Add issues from this page
|
||||
if "nodes" in issues_page:
|
||||
all_issues.extend(issues_page["nodes"])
|
||||
|
||||
|
||||
# Check if there are more pages
|
||||
if "pageInfo" in issues_page:
|
||||
page_info = issues_page["pageInfo"]
|
||||
has_next_page = page_info.get("hasNextPage", False)
|
||||
cursor = page_info.get("endCursor") if has_next_page else None
|
||||
cursor = (
|
||||
page_info.get("endCursor") if has_next_page else None
|
||||
)
|
||||
else:
|
||||
has_next_page = False
|
||||
else:
|
||||
has_next_page = False
|
||||
|
||||
|
||||
if not all_issues:
|
||||
return [], "No issues found in the specified date range."
|
||||
|
||||
|
||||
return all_issues, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return [], f"Error fetching issues: {str(e)}"
|
||||
|
||||
return [], f"Error fetching issues: {e!s}"
|
||||
|
||||
except ValueError as e:
|
||||
return [], f"Invalid date format: {str(e)}. Please use YYYY-MM-DD."
|
||||
|
||||
def format_issue(self, issue: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return [], f"Invalid date format: {e!s}. Please use YYYY-MM-DD."
|
||||
|
||||
def format_issue(self, issue: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format an issue for easier consumption.
|
||||
|
||||
|
||||
Args:
|
||||
issue: The issue object from Linear API
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted issue dictionary
|
||||
"""
|
||||
|
|
@ -311,23 +317,37 @@ class LinearConnector:
|
|||
"identifier": issue.get("identifier", ""),
|
||||
"title": issue.get("title", ""),
|
||||
"description": issue.get("description", ""),
|
||||
"state": issue.get("state", {}).get("name", "Unknown") if issue.get("state") else "Unknown",
|
||||
"state_type": issue.get("state", {}).get("type", "Unknown") if issue.get("state") else "Unknown",
|
||||
"state": issue.get("state", {}).get("name", "Unknown")
|
||||
if issue.get("state")
|
||||
else "Unknown",
|
||||
"state_type": issue.get("state", {}).get("type", "Unknown")
|
||||
if issue.get("state")
|
||||
else "Unknown",
|
||||
"created_at": issue.get("createdAt", ""),
|
||||
"updated_at": issue.get("updatedAt", ""),
|
||||
"creator": {
|
||||
"id": issue.get("creator", {}).get("id", "") if issue.get("creator") else "",
|
||||
"name": issue.get("creator", {}).get("name", "Unknown") if issue.get("creator") else "Unknown",
|
||||
"email": issue.get("creator", {}).get("email", "") if issue.get("creator") else ""
|
||||
} if issue.get("creator") else {"id": "", "name": "Unknown", "email": ""},
|
||||
"id": issue.get("creator", {}).get("id", "")
|
||||
if issue.get("creator")
|
||||
else "",
|
||||
"name": issue.get("creator", {}).get("name", "Unknown")
|
||||
if issue.get("creator")
|
||||
else "Unknown",
|
||||
"email": issue.get("creator", {}).get("email", "")
|
||||
if issue.get("creator")
|
||||
else "",
|
||||
}
|
||||
if issue.get("creator")
|
||||
else {"id": "", "name": "Unknown", "email": ""},
|
||||
"assignee": {
|
||||
"id": issue.get("assignee", {}).get("id", ""),
|
||||
"name": issue.get("assignee", {}).get("name", "Unknown"),
|
||||
"email": issue.get("assignee", {}).get("email", "")
|
||||
} if issue.get("assignee") else None,
|
||||
"comments": []
|
||||
"email": issue.get("assignee", {}).get("email", ""),
|
||||
}
|
||||
if issue.get("assignee")
|
||||
else None,
|
||||
"comments": [],
|
||||
}
|
||||
|
||||
|
||||
# Extract comments if available
|
||||
if "comments" in issue and "nodes" in issue["comments"]:
|
||||
for comment in issue["comments"]["nodes"]:
|
||||
|
|
@ -337,85 +357,93 @@ class LinearConnector:
|
|||
"created_at": comment.get("createdAt", ""),
|
||||
"updated_at": comment.get("updatedAt", ""),
|
||||
"user": {
|
||||
"id": comment.get("user", {}).get("id", "") if comment.get("user") else "",
|
||||
"name": comment.get("user", {}).get("name", "Unknown") if comment.get("user") else "Unknown",
|
||||
"email": comment.get("user", {}).get("email", "") if comment.get("user") else ""
|
||||
} if comment.get("user") else {"id": "", "name": "Unknown", "email": ""}
|
||||
"id": comment.get("user", {}).get("id", "")
|
||||
if comment.get("user")
|
||||
else "",
|
||||
"name": comment.get("user", {}).get("name", "Unknown")
|
||||
if comment.get("user")
|
||||
else "Unknown",
|
||||
"email": comment.get("user", {}).get("email", "")
|
||||
if comment.get("user")
|
||||
else "",
|
||||
}
|
||||
if comment.get("user")
|
||||
else {"id": "", "name": "Unknown", "email": ""},
|
||||
}
|
||||
formatted["comments"].append(formatted_comment)
|
||||
|
||||
|
||||
return formatted
|
||||
|
||||
def format_issue_to_markdown(self, issue: Dict[str, Any]) -> str:
|
||||
|
||||
def format_issue_to_markdown(self, issue: dict[str, Any]) -> str:
|
||||
"""
|
||||
Convert an issue to markdown format.
|
||||
|
||||
|
||||
Args:
|
||||
issue: The issue object (either raw or formatted)
|
||||
|
||||
|
||||
Returns:
|
||||
Markdown string representation of the issue
|
||||
"""
|
||||
# Format the issue if it's not already formatted
|
||||
if "identifier" not in issue:
|
||||
issue = self.format_issue(issue)
|
||||
|
||||
|
||||
# Build the markdown content
|
||||
markdown = f"# {issue.get('identifier', 'No ID')}: {issue.get('title', 'No Title')}\n\n"
|
||||
|
||||
if issue.get('state'):
|
||||
|
||||
if issue.get("state"):
|
||||
markdown += f"**Status:** {issue['state']}\n\n"
|
||||
|
||||
if issue.get('assignee') and issue['assignee'].get('name'):
|
||||
|
||||
if issue.get("assignee") and issue["assignee"].get("name"):
|
||||
markdown += f"**Assignee:** {issue['assignee']['name']}\n"
|
||||
|
||||
if issue.get('creator') and issue['creator'].get('name'):
|
||||
|
||||
if issue.get("creator") and issue["creator"].get("name"):
|
||||
markdown += f"**Created by:** {issue['creator']['name']}\n"
|
||||
|
||||
if issue.get('created_at'):
|
||||
created_date = self.format_date(issue['created_at'])
|
||||
|
||||
if issue.get("created_at"):
|
||||
created_date = self.format_date(issue["created_at"])
|
||||
markdown += f"**Created:** {created_date}\n"
|
||||
|
||||
if issue.get('updated_at'):
|
||||
updated_date = self.format_date(issue['updated_at'])
|
||||
|
||||
if issue.get("updated_at"):
|
||||
updated_date = self.format_date(issue["updated_at"])
|
||||
markdown += f"**Updated:** {updated_date}\n\n"
|
||||
|
||||
if issue.get('description'):
|
||||
|
||||
if issue.get("description"):
|
||||
markdown += f"## Description\n\n{issue['description']}\n\n"
|
||||
|
||||
if issue.get('comments'):
|
||||
|
||||
if issue.get("comments"):
|
||||
markdown += f"## Comments ({len(issue['comments'])})\n\n"
|
||||
|
||||
for comment in issue['comments']:
|
||||
|
||||
for comment in issue["comments"]:
|
||||
user_name = "Unknown"
|
||||
if comment.get('user') and comment['user'].get('name'):
|
||||
user_name = comment['user']['name']
|
||||
|
||||
if comment.get("user") and comment["user"].get("name"):
|
||||
user_name = comment["user"]["name"]
|
||||
|
||||
comment_date = "Unknown date"
|
||||
if comment.get('created_at'):
|
||||
comment_date = self.format_date(comment['created_at'])
|
||||
|
||||
if comment.get("created_at"):
|
||||
comment_date = self.format_date(comment["created_at"])
|
||||
|
||||
markdown += f"### {user_name} ({comment_date})\n\n{comment.get('body', '')}\n\n---\n\n"
|
||||
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
@staticmethod
|
||||
def format_date(iso_date: str) -> str:
|
||||
"""
|
||||
Format an ISO date string to a more readable format.
|
||||
|
||||
|
||||
Args:
|
||||
iso_date: ISO format date string
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted date string
|
||||
"""
|
||||
if not iso_date or not isinstance(iso_date, str):
|
||||
return "Unknown date"
|
||||
|
||||
|
||||
try:
|
||||
dt = datetime.fromisoformat(iso_date.replace('Z', '+00:00'))
|
||||
return dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
dt = datetime.fromisoformat(iso_date.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
return iso_date
|
||||
|
||||
|
|
|
|||
|
|
@ -1,176 +1,182 @@
|
|||
from notion_client import Client
|
||||
|
||||
|
||||
class NotionHistoryConnector:
|
||||
def __init__(self, token):
|
||||
"""
|
||||
Initialize the NotionPageFetcher with a token.
|
||||
|
||||
|
||||
Args:
|
||||
token (str): Notion integration token
|
||||
"""
|
||||
self.notion = Client(auth=token)
|
||||
|
||||
|
||||
def get_all_pages(self, start_date=None, end_date=None):
|
||||
"""
|
||||
Fetches all pages shared with your integration and their content.
|
||||
|
||||
|
||||
Args:
|
||||
start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
|
||||
end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing page data
|
||||
"""
|
||||
# Build the filter for the search
|
||||
# Note: Notion API requires specific filter structure
|
||||
search_params = {}
|
||||
|
||||
|
||||
# Filter for pages only (not databases)
|
||||
search_params["filter"] = {
|
||||
"value": "page",
|
||||
"property": "object"
|
||||
}
|
||||
|
||||
search_params["filter"] = {"value": "page", "property": "object"}
|
||||
|
||||
# Add date filters if provided
|
||||
if start_date or end_date:
|
||||
date_filter = {}
|
||||
|
||||
|
||||
if start_date:
|
||||
date_filter["on_or_after"] = start_date
|
||||
|
||||
|
||||
if end_date:
|
||||
date_filter["on_or_before"] = end_date
|
||||
|
||||
|
||||
# Add the date filter to the search params
|
||||
if date_filter:
|
||||
search_params["sort"] = {
|
||||
"direction": "descending",
|
||||
"timestamp": "last_edited_time"
|
||||
"timestamp": "last_edited_time",
|
||||
}
|
||||
|
||||
|
||||
# First, get a list of all pages the integration has access to
|
||||
search_results = self.notion.search(**search_params)
|
||||
|
||||
|
||||
pages = search_results["results"]
|
||||
all_page_data = []
|
||||
|
||||
|
||||
for page in pages:
|
||||
page_id = page["id"]
|
||||
|
||||
|
||||
# Get detailed page information
|
||||
page_content = self.get_page_content(page_id)
|
||||
|
||||
all_page_data.append({
|
||||
"page_id": page_id,
|
||||
"title": self.get_page_title(page),
|
||||
"content": page_content
|
||||
})
|
||||
|
||||
|
||||
all_page_data.append(
|
||||
{
|
||||
"page_id": page_id,
|
||||
"title": self.get_page_title(page),
|
||||
"content": page_content,
|
||||
}
|
||||
)
|
||||
|
||||
return all_page_data
|
||||
|
||||
|
||||
def get_page_title(self, page):
|
||||
"""
|
||||
Extracts the title from a page object.
|
||||
|
||||
|
||||
Args:
|
||||
page (dict): Notion page object
|
||||
|
||||
|
||||
Returns:
|
||||
str: Page title or a fallback string
|
||||
"""
|
||||
# Title can be in different properties depending on the page type
|
||||
if "properties" in page:
|
||||
# Try to find a title property
|
||||
for prop_name, prop_data in page["properties"].items():
|
||||
for _prop_name, prop_data in page["properties"].items():
|
||||
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
|
||||
return " ".join([text_obj["plain_text"] for text_obj in prop_data["title"]])
|
||||
|
||||
return " ".join(
|
||||
[text_obj["plain_text"] for text_obj in prop_data["title"]]
|
||||
)
|
||||
|
||||
# If no title found, return the page ID as fallback
|
||||
return f"Untitled page ({page['id']})"
|
||||
|
||||
|
||||
def get_page_content(self, page_id):
|
||||
"""
|
||||
Fetches the content (blocks) of a specific page.
|
||||
|
||||
|
||||
Args:
|
||||
page_id (str): The ID of the page to fetch
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of processed blocks from the page
|
||||
"""
|
||||
blocks = []
|
||||
has_more = True
|
||||
cursor = None
|
||||
|
||||
|
||||
# Paginate through all blocks
|
||||
while has_more:
|
||||
if cursor:
|
||||
response = self.notion.blocks.children.list(block_id=page_id, start_cursor=cursor)
|
||||
response = self.notion.blocks.children.list(
|
||||
block_id=page_id, start_cursor=cursor
|
||||
)
|
||||
else:
|
||||
response = self.notion.blocks.children.list(block_id=page_id)
|
||||
|
||||
|
||||
blocks.extend(response["results"])
|
||||
has_more = response["has_more"]
|
||||
|
||||
|
||||
if has_more:
|
||||
cursor = response["next_cursor"]
|
||||
|
||||
|
||||
# Process nested blocks recursively
|
||||
processed_blocks = []
|
||||
for block in blocks:
|
||||
processed_block = self.process_block(block)
|
||||
processed_blocks.append(processed_block)
|
||||
|
||||
|
||||
return processed_blocks
|
||||
|
||||
|
||||
def process_block(self, block):
|
||||
"""
|
||||
Processes a block and recursively fetches any child blocks.
|
||||
|
||||
|
||||
Args:
|
||||
block (dict): The block to process
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Processed block with content and children
|
||||
"""
|
||||
block_id = block["id"]
|
||||
block_type = block["type"]
|
||||
|
||||
|
||||
# Extract block content based on its type
|
||||
content = self.extract_block_content(block)
|
||||
|
||||
|
||||
# Check if block has children
|
||||
has_children = block.get("has_children", False)
|
||||
child_blocks = []
|
||||
|
||||
|
||||
if has_children:
|
||||
# Fetch and process child blocks
|
||||
children_response = self.notion.blocks.children.list(block_id=block_id)
|
||||
for child_block in children_response["results"]:
|
||||
child_blocks.append(self.process_block(child_block))
|
||||
|
||||
|
||||
return {
|
||||
"id": block_id,
|
||||
"type": block_type,
|
||||
"content": content,
|
||||
"children": child_blocks
|
||||
"children": child_blocks,
|
||||
}
|
||||
|
||||
|
||||
def extract_block_content(self, block):
|
||||
"""
|
||||
Extracts the content from a block based on its type.
|
||||
|
||||
|
||||
Args:
|
||||
block (dict): The block to extract content from
|
||||
|
||||
|
||||
Returns:
|
||||
str: Extracted content as a string
|
||||
"""
|
||||
block_type = block["type"]
|
||||
|
||||
|
||||
# Different block types have different structures
|
||||
if block_type in block and "rich_text" in block[block_type]:
|
||||
return "".join([text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]])
|
||||
return "".join(
|
||||
[text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]
|
||||
)
|
||||
elif block_type == "image":
|
||||
# Instead of returning the raw URL which may contain sensitive AWS credentials,
|
||||
# return a placeholder or reference to the image
|
||||
|
|
@ -183,18 +189,21 @@ class NotionHistoryConnector:
|
|||
# Only return the domain part of external URLs to avoid potential sensitive parameters
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
return f"[External Image from {parsed_url.netloc}]"
|
||||
except:
|
||||
except Exception:
|
||||
return "[External Image]"
|
||||
elif block_type == "code":
|
||||
language = block["code"]["language"]
|
||||
code_text = "".join([text_obj["plain_text"] for text_obj in block["code"]["rich_text"]])
|
||||
code_text = "".join(
|
||||
[text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]
|
||||
)
|
||||
return f"```{language}\n{code_text}\n```"
|
||||
elif block_type == "equation":
|
||||
return block["equation"]["expression"]
|
||||
# Add more block types as needed
|
||||
|
||||
|
||||
# Return empty string for unsupported block types
|
||||
return ""
|
||||
|
||||
|
|
@ -203,23 +212,23 @@ class NotionHistoryConnector:
|
|||
# if __name__ == "__main__":
|
||||
# # Simple example of how to use this module
|
||||
# import argparse
|
||||
|
||||
|
||||
# parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token")
|
||||
# parser.add_argument("--token", help="Your Notion integration token")
|
||||
# parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)")
|
||||
# parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)")
|
||||
# args = parser.parse_args()
|
||||
|
||||
|
||||
# token = args.token
|
||||
# if not token:
|
||||
# token = input("Enter your Notion integration token: ")
|
||||
|
||||
|
||||
# fetcher = NotionPageFetcher(token)
|
||||
|
||||
|
||||
# try:
|
||||
# pages = fetcher.get_all_pages(args.start_date, args.end_date)
|
||||
# print(f"Fetched {len(pages)} pages from Notion")
|
||||
# for page in pages:
|
||||
# print(f"- {page['title']}")
|
||||
# except Exception as e:
|
||||
# print(f"Error: {str(e)}")
|
||||
# print(f"Error: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -5,47 +5,48 @@ A module for retrieving conversation history from Slack channels.
|
|||
Allows fetching channel lists and message history with date range filtering.
|
||||
"""
|
||||
|
||||
import time # Added import
|
||||
import logging # Added import
|
||||
import logging # Added import
|
||||
import time # Added import
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
|
||||
logger = logging.getLogger(__name__) # Added logger
|
||||
logger = logging.getLogger(__name__) # Added logger
|
||||
|
||||
|
||||
class SlackHistory:
|
||||
"""Class for retrieving conversation history from Slack channels."""
|
||||
|
||||
def __init__(self, token: str = None):
|
||||
|
||||
def __init__(self, token: str | None = None):
|
||||
"""
|
||||
Initialize the SlackHistory class.
|
||||
|
||||
|
||||
Args:
|
||||
token: Slack API token (optional, can be set later with set_token)
|
||||
"""
|
||||
self.client = WebClient(token=token) if token else None
|
||||
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""
|
||||
Set the Slack API token.
|
||||
|
||||
|
||||
Args:
|
||||
token: Slack API token
|
||||
"""
|
||||
self.client = WebClient(token=token)
|
||||
|
||||
def get_all_channels(self, include_private: bool = True) -> List[Dict[str, Any]]:
|
||||
|
||||
def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all channels that the bot has access to, with rate limit handling.
|
||||
|
||||
|
||||
Args:
|
||||
include_private: Whether to include private channels
|
||||
|
||||
|
||||
Returns:
|
||||
List of dictionaries, each representing a channel with id, name, is_private, is_member.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an unrecoverable error calling the Slack API
|
||||
|
|
@ -53,8 +54,8 @@ class SlackHistory:
|
|||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
channels_list = [] # Changed from dict to list
|
||||
|
||||
channels_list = [] # Changed from dict to list
|
||||
types = "public_channel"
|
||||
if include_private:
|
||||
types += ",private_channel"
|
||||
|
|
@ -65,16 +66,16 @@ class SlackHistory:
|
|||
while is_first_request or next_cursor:
|
||||
try:
|
||||
if not is_first_request: # Add delay only for paginated requests
|
||||
logger.info(f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}")
|
||||
logger.info(
|
||||
f"Paginating for channels, waiting 3 seconds before next call. Cursor: {next_cursor}"
|
||||
)
|
||||
time.sleep(3)
|
||||
|
||||
current_limit = 1000 # Max limit
|
||||
api_result = self.client.conversations_list(
|
||||
types=types,
|
||||
cursor=next_cursor,
|
||||
limit=current_limit
|
||||
types=types, cursor=next_cursor, limit=current_limit
|
||||
)
|
||||
|
||||
|
||||
channels_on_page = api_result["channels"]
|
||||
for channel in channels_on_page:
|
||||
if "name" in channel and "id" in channel:
|
||||
|
|
@ -86,12 +87,13 @@ class SlackHistory:
|
|||
# It indicates if the authenticated user (bot) is a member.
|
||||
# For public channels, this might be true or the API might not focus on it
|
||||
# if the bot can read it anyway. For private, it's crucial.
|
||||
"is_member": channel.get("is_member", False)
|
||||
"is_member": channel.get("is_member", False),
|
||||
}
|
||||
channels_list.append(channel_data)
|
||||
else:
|
||||
logger.warning(f"Channel found with missing name or id. Data: {channel}")
|
||||
|
||||
logger.warning(
|
||||
f"Channel found with missing name or id. Data: {channel}"
|
||||
)
|
||||
|
||||
next_cursor = api_result.get("response_metadata", {}).get("next_cursor")
|
||||
is_first_request = False # Subsequent requests are not the first
|
||||
|
|
@ -101,57 +103,65 @@ class SlackHistory:
|
|||
|
||||
except SlackApiError as e:
|
||||
if e.response is not None and e.response.status_code == 429:
|
||||
retry_after_header = e.response.headers.get('Retry-After')
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
wait_duration = 60 # Default wait time
|
||||
if retry_after_header and retry_after_header.isdigit():
|
||||
wait_duration = int(retry_after_header)
|
||||
|
||||
logger.warning(f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}")
|
||||
|
||||
logger.warning(
|
||||
f"Slack API rate limit hit while fetching channels. Waiting for {wait_duration} seconds. Cursor: {next_cursor}"
|
||||
)
|
||||
time.sleep(wait_duration)
|
||||
# The loop will continue, retrying with the same cursor
|
||||
else:
|
||||
# Not a 429 error, or no response object, re-raise
|
||||
raise SlackApiError(f"Error retrieving channels: {e}", e.response)
|
||||
raise SlackApiError(
|
||||
f"Error retrieving channels: {e}", e.response
|
||||
) from e
|
||||
except Exception as general_error:
|
||||
# Handle other potential errors like network issues if necessary, or re-raise
|
||||
logger.error(f"An unexpected error occurred during channel fetching: {general_error}")
|
||||
raise RuntimeError(f"An unexpected error occurred during channel fetching: {general_error}")
|
||||
|
||||
logger.error(
|
||||
f"An unexpected error occurred during channel fetching: {general_error}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"An unexpected error occurred during channel fetching: {general_error}"
|
||||
) from general_error
|
||||
|
||||
return channels_list
|
||||
|
||||
|
||||
def get_conversation_history(
|
||||
self,
|
||||
channel_id: str,
|
||||
limit: int = 1000,
|
||||
oldest: Optional[int] = None,
|
||||
latest: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
self,
|
||||
channel_id: str,
|
||||
limit: int = 1000,
|
||||
oldest: int | None = None,
|
||||
latest: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch conversation history for a channel.
|
||||
|
||||
|
||||
Args:
|
||||
channel_id: The ID of the channel to fetch history for
|
||||
limit: Maximum number of messages to return per request (default 1000)
|
||||
oldest: Start of time range (Unix timestamp)
|
||||
latest: End of time range (Unix timestamp)
|
||||
|
||||
|
||||
Returns:
|
||||
List of message objects
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an error calling the Slack API
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
|
||||
messages = []
|
||||
next_cursor = None
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Proactive delay for conversations.history (Tier 3)
|
||||
time.sleep(1.2) # Wait 1.2 seconds before each history call.
|
||||
time.sleep(1.2) # Wait 1.2 seconds before each history call.
|
||||
|
||||
kwargs = {
|
||||
"channel": channel_id,
|
||||
|
|
@ -163,16 +173,19 @@ class SlackHistory:
|
|||
kwargs["latest"] = latest
|
||||
if next_cursor:
|
||||
kwargs["cursor"] = next_cursor
|
||||
|
||||
|
||||
current_api_call_successful = False
|
||||
result = None # Ensure result is defined
|
||||
result = None # Ensure result is defined
|
||||
try:
|
||||
result = self.client.conversations_history(**kwargs)
|
||||
current_api_call_successful = True
|
||||
except SlackApiError as e_history:
|
||||
if e_history.response is not None and e_history.response.status_code == 429:
|
||||
retry_after_str = e_history.response.headers.get('Retry-After')
|
||||
wait_time = 60 # Default
|
||||
if (
|
||||
e_history.response is not None
|
||||
and e_history.response.status_code == 429
|
||||
):
|
||||
retry_after_str = e_history.response.headers.get("Retry-After")
|
||||
wait_time = 60 # Default
|
||||
if retry_after_str and retry_after_str.isdigit():
|
||||
wait_time = int(retry_after_str)
|
||||
logger.warning(
|
||||
|
|
@ -182,47 +195,54 @@ class SlackHistory:
|
|||
time.sleep(wait_time)
|
||||
# current_api_call_successful remains False, loop will retry this page
|
||||
else:
|
||||
raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors
|
||||
|
||||
raise # Re-raise to outer handler for not_in_channel or other SlackApiErrors
|
||||
|
||||
if not current_api_call_successful:
|
||||
continue # Retry the current page fetch due to handled rate limit
|
||||
continue # Retry the current page fetch due to handled rate limit
|
||||
|
||||
# Process result if successful
|
||||
batch = result["messages"]
|
||||
messages.extend(batch)
|
||||
|
||||
|
||||
if result.get("has_more", False) and len(messages) < limit:
|
||||
next_cursor = result["response_metadata"]["next_cursor"]
|
||||
else:
|
||||
break # Exit pagination loop
|
||||
|
||||
except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
|
||||
if (e.response is not None and
|
||||
hasattr(e.response, 'data') and
|
||||
isinstance(e.response.data, dict) and
|
||||
e.response.data.get('error') == 'not_in_channel'):
|
||||
break # Exit pagination loop
|
||||
|
||||
except SlackApiError as e: # Outer catch for not_in_channel or unhandled SlackApiErrors from inner try
|
||||
if (
|
||||
e.response is not None
|
||||
and hasattr(e.response, "data")
|
||||
and isinstance(e.response.data, dict)
|
||||
and e.response.data.get("error") == "not_in_channel"
|
||||
):
|
||||
logger.warning(
|
||||
f"Bot is not in channel '{channel_id}'. Cannot fetch history. "
|
||||
"Please add the bot to this channel."
|
||||
)
|
||||
return []
|
||||
return []
|
||||
# For other SlackApiErrors from inner block or this level
|
||||
raise SlackApiError(f"Error retrieving history for channel {channel_id}: {e}", e.response)
|
||||
except Exception as general_error: # Catch any other unexpected errors
|
||||
logger.error(f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}")
|
||||
raise SlackApiError(
|
||||
f"Error retrieving history for channel {channel_id}: {e}",
|
||||
e.response,
|
||||
) from e
|
||||
except Exception as general_error: # Catch any other unexpected errors
|
||||
logger.error(
|
||||
f"Unexpected error in get_conversation_history for channel {channel_id}: {general_error}"
|
||||
)
|
||||
# Re-raise the general error to allow higher-level handling or visibility
|
||||
raise
|
||||
|
||||
raise general_error from general_error
|
||||
|
||||
return messages[:limit]
|
||||
|
||||
@staticmethod
|
||||
def convert_date_to_timestamp(date_str: str) -> Optional[int]:
|
||||
def convert_date_to_timestamp(date_str: str) -> int | None:
|
||||
"""
|
||||
Convert a date string in format YYYY-MM-DD to Unix timestamp.
|
||||
|
||||
|
||||
Args:
|
||||
date_str: Date string in YYYY-MM-DD format
|
||||
|
||||
|
||||
Returns:
|
||||
Unix timestamp (seconds since epoch) or None if invalid format
|
||||
"""
|
||||
|
|
@ -231,67 +251,63 @@ class SlackHistory:
|
|||
return int(dt.timestamp())
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_history_by_date_range(
|
||||
self,
|
||||
channel_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
limit: int = 1000
|
||||
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
Fetch conversation history within a date range.
|
||||
|
||||
|
||||
Args:
|
||||
channel_id: The ID of the channel to fetch history for
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format (inclusive)
|
||||
limit: Maximum number of messages to return
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple containing (messages list, error message or None)
|
||||
"""
|
||||
oldest = self.convert_date_to_timestamp(start_date)
|
||||
if not oldest:
|
||||
return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."
|
||||
|
||||
return (
|
||||
[],
|
||||
f"Invalid start date format: {start_date}. Please use YYYY-MM-DD.",
|
||||
)
|
||||
|
||||
latest = self.convert_date_to_timestamp(end_date)
|
||||
if not latest:
|
||||
return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD."
|
||||
|
||||
|
||||
# Add one day to end date to make it inclusive
|
||||
latest += 86400 # seconds in a day
|
||||
|
||||
|
||||
try:
|
||||
messages = self.get_conversation_history(
|
||||
channel_id=channel_id,
|
||||
limit=limit,
|
||||
oldest=oldest,
|
||||
latest=latest
|
||||
channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
|
||||
)
|
||||
return messages, None
|
||||
except SlackApiError as e:
|
||||
return [], f"Slack API error: {str(e)}"
|
||||
return [], f"Slack API error: {e!s}"
|
||||
except ValueError as e:
|
||||
return [], str(e)
|
||||
|
||||
def get_user_info(self, user_id: str) -> Dict[str, Any]:
|
||||
|
||||
def get_user_info(self, user_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about a user.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to get info for
|
||||
|
||||
|
||||
Returns:
|
||||
User information dictionary
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no Slack client has been initialized
|
||||
SlackApiError: If there's an error calling the Slack API
|
||||
"""
|
||||
if not self.client:
|
||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Proactive delay for users.info (Tier 4) - generally not needed unless called extremely rapidly.
|
||||
|
|
@ -299,46 +315,60 @@ class SlackHistory:
|
|||
# time.sleep(0.6) # Optional: ~100 req/min if ever needed.
|
||||
|
||||
result = self.client.users_info(user=user_id)
|
||||
return result["user"] # Success, return and exit loop implicitly
|
||||
return result["user"] # Success, return and exit loop implicitly
|
||||
|
||||
except SlackApiError as e_user_info:
|
||||
if e_user_info.response is not None and e_user_info.response.status_code == 429:
|
||||
retry_after_str = e_user_info.response.headers.get('Retry-After')
|
||||
if (
|
||||
e_user_info.response is not None
|
||||
and e_user_info.response.status_code == 429
|
||||
):
|
||||
retry_after_str = e_user_info.response.headers.get("Retry-After")
|
||||
wait_time = 30 # Default for Tier 4, can be adjusted
|
||||
if retry_after_str and retry_after_str.isdigit():
|
||||
wait_time = int(retry_after_str)
|
||||
logger.warning(f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds.")
|
||||
logger.warning(
|
||||
f"Rate limited by Slack on users.info for user {user_id}. Retrying after {wait_time} seconds."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
continue # Retry the API call
|
||||
else:
|
||||
# Not a 429 error, or no response object, re-raise
|
||||
raise SlackApiError(f"Error retrieving user info for {user_id}: {e_user_info}", e_user_info.response)
|
||||
except Exception as general_error: # Catch any other unexpected errors
|
||||
logger.error(f"Unexpected error in get_user_info for user {user_id}: {general_error}")
|
||||
raise # Re-raise unexpected errors
|
||||
|
||||
def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
|
||||
raise SlackApiError(
|
||||
f"Error retrieving user info for {user_id}: {e_user_info}",
|
||||
e_user_info.response,
|
||||
) from e_user_info
|
||||
except Exception as general_error: # Catch any other unexpected errors
|
||||
logger.error(
|
||||
f"Unexpected error in get_user_info for user {user_id}: {general_error}"
|
||||
)
|
||||
raise general_error from general_error # Re-raise unexpected errors
|
||||
|
||||
def format_message(
|
||||
self, msg: dict[str, Any], include_user_info: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Format a message for easier consumption.
|
||||
|
||||
|
||||
Args:
|
||||
msg: The message object from Slack API
|
||||
include_user_info: Whether to fetch and include user info
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted message dictionary
|
||||
"""
|
||||
formatted = {
|
||||
"text": msg.get("text", ""),
|
||||
"timestamp": msg.get("ts"),
|
||||
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"datetime": datetime.fromtimestamp(float(msg.get("ts", 0))).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
"user_id": msg.get("user", "UNKNOWN"),
|
||||
"has_attachments": bool(msg.get("attachments")),
|
||||
"has_files": bool(msg.get("files")),
|
||||
"thread_ts": msg.get("thread_ts"),
|
||||
"is_thread": "thread_ts" in msg,
|
||||
}
|
||||
|
||||
|
||||
if include_user_info and "user" in msg and self.client:
|
||||
try:
|
||||
user_info = self.get_user_info(msg["user"])
|
||||
|
|
@ -347,7 +377,7 @@ class SlackHistory:
|
|||
except Exception:
|
||||
# If we can't get user info, just continue without it
|
||||
formatted["user_name"] = "Unknown"
|
||||
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
|
|
@ -388,4 +418,4 @@ if __name__ == "__main__":
|
|||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,23 +1,24 @@
|
|||
import unittest
|
||||
from unittest.mock import patch, Mock
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from github3.exceptions import ForbiddenError # Import the specific exception
|
||||
|
||||
# Adjust the import path based on the actual location if test_github_connector.py
|
||||
# is not in the same directory as github_connector.py or if paths are set up differently.
|
||||
# Assuming surfsend_backend/app/connectors/test_github_connector.py
|
||||
from surfsense_backend.app.connectors.github_connector import GitHubConnector
|
||||
from github3.exceptions import ForbiddenError # Import the specific exception
|
||||
|
||||
|
||||
class TestGitHubConnector(unittest.TestCase):
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_get_user_repositories_uses_type_all(self, mock_github_login):
|
||||
# Mock the GitHub client object and its methods
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
||||
# Mock the self.gh.me() call in __init__ to prevent an actual API call
|
||||
mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization
|
||||
mock_gh_instance.me.return_value = Mock() # Simple mock to pass initialization
|
||||
|
||||
# Prepare mock repository data
|
||||
mock_repo1_data = Mock()
|
||||
|
|
@ -27,7 +28,9 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
mock_repo1_data.private = False
|
||||
mock_repo1_data.html_url = "http://example.com/user/repo1"
|
||||
mock_repo1_data.description = "Test repo 1"
|
||||
mock_repo1_data.updated_at = datetime(2023, 1, 1, 10, 30, 0) # Added time component
|
||||
mock_repo1_data.updated_at = datetime(
|
||||
2023, 1, 1, 10, 30, 0
|
||||
) # Added time component
|
||||
|
||||
mock_repo2_data = Mock()
|
||||
mock_repo2_data.id = 2
|
||||
|
|
@ -36,8 +39,10 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
mock_repo2_data.private = True
|
||||
mock_repo2_data.html_url = "http://example.com/org/org-repo"
|
||||
mock_repo2_data.description = "Org repo"
|
||||
mock_repo2_data.updated_at = datetime(2023, 1, 2, 12, 0, 0) # Added time component
|
||||
|
||||
mock_repo2_data.updated_at = datetime(
|
||||
2023, 1, 2, 12, 0, 0
|
||||
) # Added time component
|
||||
|
||||
# Configure the mock for gh.repositories() call
|
||||
# This method is an iterator, so it should return an iterable (e.g., a list)
|
||||
mock_gh_instance.repositories.return_value = [mock_repo1_data, mock_repo2_data]
|
||||
|
|
@ -46,26 +51,38 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
repositories = connector.get_user_repositories()
|
||||
|
||||
# Assert that gh.repositories was called correctly
|
||||
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
|
||||
mock_gh_instance.repositories.assert_called_once_with(
|
||||
type="all", sort="updated"
|
||||
)
|
||||
|
||||
# Assert the structure and content of the returned data
|
||||
expected_repositories = [
|
||||
{
|
||||
"id": 1, "name": "repo1", "full_name": "user/repo1", "private": False,
|
||||
"url": "http://example.com/user/repo1", "description": "Test repo 1",
|
||||
"last_updated": datetime(2023, 1, 1, 10, 30, 0)
|
||||
"id": 1,
|
||||
"name": "repo1",
|
||||
"full_name": "user/repo1",
|
||||
"private": False,
|
||||
"url": "http://example.com/user/repo1",
|
||||
"description": "Test repo 1",
|
||||
"last_updated": datetime(2023, 1, 1, 10, 30, 0),
|
||||
},
|
||||
{
|
||||
"id": 2, "name": "org-repo", "full_name": "org/org-repo", "private": True,
|
||||
"url": "http://example.com/org/org-repo", "description": "Org repo",
|
||||
"last_updated": datetime(2023, 1, 2, 12, 0, 0)
|
||||
}
|
||||
"id": 2,
|
||||
"name": "org-repo",
|
||||
"full_name": "org/org-repo",
|
||||
"private": True,
|
||||
"url": "http://example.com/org/org-repo",
|
||||
"description": "Org repo",
|
||||
"last_updated": datetime(2023, 1, 2, 12, 0, 0),
|
||||
},
|
||||
]
|
||||
self.assertEqual(repositories, expected_repositories)
|
||||
self.assertEqual(len(repositories), 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
def test_get_user_repositories_handles_empty_description_and_none_updated_at(self, mock_github_login):
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_get_user_repositories_handles_empty_description_and_none_updated_at(
|
||||
self, mock_github_login
|
||||
):
|
||||
# Mock the GitHub client object and its methods
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
|
@ -77,61 +94,73 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
mock_repo_data.full_name = "user/repo_no_desc"
|
||||
mock_repo_data.private = False
|
||||
mock_repo_data.html_url = "http://example.com/user/repo_no_desc"
|
||||
mock_repo_data.description = None # Test None description
|
||||
mock_repo_data.updated_at = None # Test None updated_at
|
||||
mock_repo_data.description = None # Test None description
|
||||
mock_repo_data.updated_at = None # Test None updated_at
|
||||
|
||||
mock_gh_instance.repositories.return_value = [mock_repo_data]
|
||||
connector = GitHubConnector(token="fake_token")
|
||||
repositories = connector.get_user_repositories()
|
||||
|
||||
mock_gh_instance.repositories.assert_called_once_with(type='all', sort='updated')
|
||||
mock_gh_instance.repositories.assert_called_once_with(
|
||||
type="all", sort="updated"
|
||||
)
|
||||
expected_repositories = [
|
||||
{
|
||||
"id": 1, "name": "repo_no_desc", "full_name": "user/repo_no_desc", "private": False,
|
||||
"url": "http://example.com/user/repo_no_desc", "description": "", # Expect empty string
|
||||
"last_updated": None # Expect None
|
||||
"id": 1,
|
||||
"name": "repo_no_desc",
|
||||
"full_name": "user/repo_no_desc",
|
||||
"private": False,
|
||||
"url": "http://example.com/user/repo_no_desc",
|
||||
"description": "", # Expect empty string
|
||||
"last_updated": None, # Expect None
|
||||
}
|
||||
]
|
||||
self.assertEqual(repositories, expected_repositories)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_github_connector_initialization_failure_forbidden(self, mock_github_login):
|
||||
# Test that __init__ raises ValueError on auth failure (ForbiddenError)
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
||||
|
||||
# Create a mock response object for the ForbiddenError
|
||||
# The actual response structure might vary, but github3.py's ForbiddenError
|
||||
# can be instantiated with just a response object that has a status_code.
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 403 # Typically Forbidden
|
||||
|
||||
mock_response.status_code = 403 # Typically Forbidden
|
||||
|
||||
# Setup the side_effect for self.gh.me()
|
||||
mock_gh_instance.me.side_effect = ForbiddenError(mock_response)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
GitHubConnector(token="invalid_token_forbidden")
|
||||
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
|
||||
self.assertIn(
|
||||
"Invalid GitHub token or insufficient permissions.", str(context.exception)
|
||||
)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
def test_github_connector_initialization_failure_authentication_failed(self, mock_github_login):
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_github_connector_initialization_failure_authentication_failed(
|
||||
self, mock_github_login
|
||||
):
|
||||
# Test that __init__ raises ValueError on auth failure (AuthenticationFailed, which is a subclass of ForbiddenError)
|
||||
# For github3.py, AuthenticationFailed is more specific for token issues.
|
||||
from github3.exceptions import AuthenticationFailed
|
||||
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401 # Typically Unauthorized
|
||||
|
||||
mock_response.status_code = 401 # Typically Unauthorized
|
||||
|
||||
mock_gh_instance.me.side_effect = AuthenticationFailed(mock_response)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
GitHubConnector(token="invalid_token_authfailed")
|
||||
self.assertIn("Invalid GitHub token or insufficient permissions.", str(context.exception))
|
||||
|
||||
@patch('surfsense_backend.app.connectors.github_connector.github_login')
|
||||
self.assertIn(
|
||||
"Invalid GitHub token or insufficient permissions.", str(context.exception)
|
||||
)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.github_connector.github_login")
|
||||
def test_get_user_repositories_handles_api_exception(self, mock_github_login):
|
||||
mock_gh_instance = Mock()
|
||||
mock_github_login.return_value = mock_gh_instance
|
||||
|
|
@ -142,13 +171,18 @@ class TestGitHubConnector(unittest.TestCase):
|
|||
|
||||
connector = GitHubConnector(token="fake_token")
|
||||
# We expect it to log an error and return an empty list
|
||||
with patch('surfsense_backend.app.connectors.github_connector.logger') as mock_logger:
|
||||
with patch(
|
||||
"surfsense_backend.app.connectors.github_connector.logger"
|
||||
) as mock_logger:
|
||||
repositories = connector.get_user_repositories()
|
||||
|
||||
|
||||
self.assertEqual(repositories, [])
|
||||
mock_logger.error.assert_called_once()
|
||||
self.assertIn("Failed to fetch GitHub repositories: API Error", mock_logger.error.call_args[0][0])
|
||||
self.assertIn(
|
||||
"Failed to fetch GitHub repositories: API Error",
|
||||
mock_logger.error.call_args[0][0],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,373 +1,448 @@
|
|||
import unittest
|
||||
import time # Imported to be available for patching target module
|
||||
from unittest.mock import patch, Mock, call
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
# Since test_slack_history.py is in the same directory as slack_history.py
|
||||
from .slack_history import SlackHistory
|
||||
|
||||
class TestSlackHistoryGetAllChannels(unittest.TestCase):
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_pagination_with_delay(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
class TestSlackHistoryGetAllChannels(unittest.TestCase):
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_pagination_with_delay(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
# Mock API responses now include is_private and is_member
|
||||
page1_response = {
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
|
||||
{"name": "dev", "id": "C0", "is_private": False, "is_member": True}
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True},
|
||||
{"name": "dev", "id": "C0", "is_private": False, "is_member": True},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor123"}
|
||||
"response_metadata": {"next_cursor": "cursor123"},
|
||||
}
|
||||
page2_response = {
|
||||
"channels": [{"name": "random", "id": "C2", "is_private": True, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"channels": [
|
||||
{"name": "random", "id": "C2", "is_private": True, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
page1_response,
|
||||
page2_response
|
||||
page2_response,
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True},
|
||||
{"id": "C0", "name": "dev", "is_private": False, "is_member": True},
|
||||
{"id": "C2", "name": "random", "is_private": True, "is_member": True}
|
||||
{"id": "C2", "name": "random", "is_private": True, "is_member": True},
|
||||
]
|
||||
|
||||
|
||||
self.assertEqual(len(channels_list), 3)
|
||||
self.assertListEqual(channels_list, expected_channels_list) # Assert list equality
|
||||
|
||||
self.assertListEqual(
|
||||
channels_list, expected_channels_list
|
||||
) # Assert list equality
|
||||
|
||||
expected_calls = [
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
call(types="public_channel,private_channel", cursor="cursor123", limit=1000)
|
||||
call(
|
||||
types="public_channel,private_channel", cursor="cursor123", limit=1000
|
||||
),
|
||||
]
|
||||
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
mock_sleep.assert_called_once_with(3)
|
||||
mock_logger.info.assert_called_once_with("Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123")
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_rate_limit_with_retry_after(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_sleep.assert_called_once_with(3)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Paginating for channels, waiting 3 seconds before next call. Cursor: cursor123"
|
||||
)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_rate_limit_with_retry_after(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': '5'}
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "5"}
|
||||
|
||||
successful_response = {
|
||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response
|
||||
successful_response,
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertEqual(len(channels_list), 1)
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None")
|
||||
|
||||
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Slack API rate limit hit while fetching channels. Waiting for 5 seconds. Cursor: None"
|
||||
)
|
||||
|
||||
expected_calls = [
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000)
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
call(types="public_channel,private_channel", cursor=None, limit=1000),
|
||||
]
|
||||
mock_client_instance.conversations_list.assert_has_calls(expected_calls)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_rate_limit_no_retry_after_valid_header(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_rate_limit_no_retry_after_valid_header(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': 'invalid_value'}
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "invalid_value"}
|
||||
|
||||
successful_response = {
|
||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response
|
||||
successful_response,
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
|
||||
)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_rate_limit_no_retry_after_header(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_rate_limit_no_retry_after_header(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {}
|
||||
|
||||
successful_response = {
|
||||
"channels": [{"name": "general", "id": "C1", "is_private": False, "is_member": True}],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
}
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response
|
||||
]
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [{"id": "C1", "name": "general", "is_private": False, "is_member": True}]
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with("Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None")
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_other_slack_api_error(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500
|
||||
mock_error_response.headers = {}
|
||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||
|
||||
original_error = SlackApiError(message="server error", response=mock_error_response)
|
||||
mock_client_instance.conversations_list.side_effect = original_error
|
||||
|
||||
|
||||
successful_response = {
|
||||
"channels": [
|
||||
{"name": "general", "id": "C1", "is_private": False, "is_member": True}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
mock_client_instance.conversations_list.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
successful_response,
|
||||
]
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C1", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
mock_sleep.assert_called_once_with(60) # Default fallback
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Slack API rate limit hit while fetching channels. Waiting for 60 seconds. Cursor: None"
|
||||
)
|
||||
self.assertEqual(mock_client_instance.conversations_list.call_count, 2)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_other_slack_api_error(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500
|
||||
mock_error_response.headers = {}
|
||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||
|
||||
original_error = SlackApiError(
|
||||
message="server error", response=mock_error_response
|
||||
)
|
||||
mock_client_instance.conversations_list.side_effect = original_error
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
with self.assertRaises(SlackApiError) as context:
|
||||
slack_history.get_all_channels(include_private=True)
|
||||
|
||||
|
||||
self.assertEqual(context.exception.response.status_code, 500)
|
||||
self.assertIn("server error", str(context.exception))
|
||||
mock_sleep.assert_not_called()
|
||||
mock_logger.warning.assert_not_called() # Ensure no rate limit log
|
||||
mock_logger.warning.assert_not_called() # Ensure no rate limit log
|
||||
mock_client_instance.conversations_list.assert_called_once_with(
|
||||
types="public_channel,private_channel", cursor=None, limit=1000
|
||||
)
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_get_all_channels_handles_missing_name_id_gracefully(self, MockWebClient, mock_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_get_all_channels_handles_missing_name_id_gracefully(
|
||||
self, mock_web_client, mock_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
response_with_malformed_data = {
|
||||
"channels": [
|
||||
{"id": "C1_missing_name", "is_private": False, "is_member": True},
|
||||
{"id": "C1_missing_name", "is_private": False, "is_member": True},
|
||||
{"name": "channel_missing_id", "is_private": False, "is_member": True},
|
||||
{"name": "general", "id": "C2_valid", "is_private": False, "is_member": True}
|
||||
{
|
||||
"name": "general",
|
||||
"id": "C2_valid",
|
||||
"is_private": False,
|
||||
"is_member": True,
|
||||
},
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""}
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
|
||||
mock_client_instance.conversations_list.return_value = response_with_malformed_data
|
||||
|
||||
|
||||
mock_client_instance.conversations_list.return_value = (
|
||||
response_with_malformed_data
|
||||
)
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
channels_list = slack_history.get_all_channels(include_private=True)
|
||||
|
||||
expected_channels_list = [
|
||||
{"id": "C2_valid", "name": "general", "is_private": False, "is_member": True}
|
||||
]
|
||||
self.assertEqual(len(channels_list), 1)
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
|
||||
self.assertEqual(mock_logger.warning.call_count, 2)
|
||||
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}")
|
||||
mock_logger.warning.assert_any_call("Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}")
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
expected_channels_list = [
|
||||
{
|
||||
"id": "C2_valid",
|
||||
"name": "general",
|
||||
"is_private": False,
|
||||
"is_member": True,
|
||||
}
|
||||
]
|
||||
self.assertEqual(len(channels_list), 1)
|
||||
self.assertListEqual(channels_list, expected_channels_list)
|
||||
|
||||
self.assertEqual(mock_logger.warning.call_count, 2)
|
||||
mock_logger.warning.assert_any_call(
|
||||
"Channel found with missing name or id. Data: {'id': 'C1_missing_name', 'is_private': False, 'is_member': True}"
|
||||
)
|
||||
mock_logger.warning.assert_any_call(
|
||||
"Channel found with missing name or id. Data: {'name': 'channel_missing_id', 'is_private': False, 'is_member': True}"
|
||||
)
|
||||
|
||||
mock_sleep.assert_not_called()
|
||||
mock_client_instance.conversations_list.assert_called_once_with(
|
||||
types="public_channel,private_channel", cursor=None, limit=1000
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_proactive_delay_single_page(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
class TestSlackHistoryGetConversationHistory(unittest.TestCase):
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_proactive_delay_single_page(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
mock_client_instance.conversations_history.return_value = {
|
||||
"messages": [{"text": "msg1"}],
|
||||
"has_more": False
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_proactive_delay_multiple_pages(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_proactive_delay_multiple_pages(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
mock_client_instance.conversations_history.side_effect = [
|
||||
{
|
||||
"messages": [{"text": "msg1"}],
|
||||
"has_more": True,
|
||||
"response_metadata": {"next_cursor": "cursor1"}
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"messages": [{"text": "msg2"}],
|
||||
"has_more": False
|
||||
}
|
||||
{"messages": [{"text": "msg2"}], "has_more": False},
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
|
||||
# Expected calls: 1.2 (page1), 1.2 (page2)
|
||||
self.assertEqual(mock_time_sleep.call_count, 2)
|
||||
mock_time_sleep.assert_has_calls([call(1.2), call(1.2)])
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': '5'}
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "5"}
|
||||
|
||||
mock_client_instance.conversations_history.side_effect = [
|
||||
SlackApiError(message="ratelimited", response=mock_error_response),
|
||||
{"messages": [{"text": "msg1"}], "has_more": False}
|
||||
{"messages": [{"text": "msg1"}], "has_more": False},
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
messages = slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
|
||||
self.assertEqual(len(messages), 1)
|
||||
self.assertEqual(messages[0]["text"], "msg1")
|
||||
|
||||
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
|
||||
mock_time_sleep.assert_has_calls([call(1.2), call(5), call(1.2)], any_order=False)
|
||||
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_not_in_channel_error(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
# Expected sleep calls: 1.2 (proactive for 1st attempt), 5 (rate limit), 1.2 (proactive for 2nd attempt)
|
||||
mock_time_sleep.assert_has_calls(
|
||||
[call(1.2), call(5), call(1.2)], any_order=False
|
||||
)
|
||||
mock_logger.warning.assert_called_once() # Check that a warning was logged for rate limiting
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_not_in_channel_error(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 403 # Typical for not_in_channel, but data matters more
|
||||
mock_error_response.data = {'ok': False, 'error': 'not_in_channel'}
|
||||
|
||||
mock_error_response.status_code = (
|
||||
403 # Typical for not_in_channel, but data matters more
|
||||
)
|
||||
mock_error_response.data = {"ok": False, "error": "not_in_channel"}
|
||||
|
||||
# This error is now raised by the inner try-except, then caught by the outer one
|
||||
mock_client_instance.conversations_history.side_effect = SlackApiError(
|
||||
message="not_in_channel error",
|
||||
response=mock_error_response
|
||||
message="not_in_channel error", response=mock_error_response
|
||||
)
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
messages = slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
|
||||
self.assertEqual(messages, [])
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Bot is not in channel 'C123'. Cannot fetch history. Please add the bot to this channel."
|
||||
)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay before the API call
|
||||
mock_time_sleep.assert_called_once_with(
|
||||
1.2
|
||||
) # Proactive delay before the API call
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_other_slack_api_error_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500
|
||||
mock_error_response.data = {'ok': False, 'error': 'internal_error'}
|
||||
original_error = SlackApiError(message="server error", response=mock_error_response)
|
||||
mock_error_response.data = {"ok": False, "error": "internal_error"}
|
||||
original_error = SlackApiError(
|
||||
message="server error", response=mock_error_response
|
||||
)
|
||||
|
||||
mock_client_instance.conversations_history.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
|
||||
with self.assertRaises(SlackApiError) as context:
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
self.assertIn("Error retrieving history for channel C123", str(context.exception))
|
||||
self.assertIs(context.exception.response, mock_error_response)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
self.assertIn(
|
||||
"Error retrieving history for channel C123", str(context.exception)
|
||||
)
|
||||
self.assertIs(context.exception.response, mock_error_response)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_general_exception_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
original_error = Exception("Something broke")
|
||||
mock_client_instance.conversations_history.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
with self.assertRaises(Exception) as context: # Check for generic Exception
|
||||
|
||||
with self.assertRaises(Exception) as context: # Check for generic Exception
|
||||
slack_history.get_conversation_history(channel_id="C123")
|
||||
|
||||
self.assertIs(context.exception, original_error) # Should re-raise the original error
|
||||
mock_logger.error.assert_called_once_with("Unexpected error in get_conversation_history for channel C123: Something broke")
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
self.assertIs(
|
||||
context.exception, original_error
|
||||
) # Should re-raise the original error
|
||||
mock_logger.error.assert_called_once_with(
|
||||
"Unexpected error in get_conversation_history for channel C123: Something broke"
|
||||
)
|
||||
mock_time_sleep.assert_called_once_with(1.2) # Proactive delay
|
||||
|
||||
|
||||
class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_retry_after_logic(self, mock_web_client, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_retry_after_logic(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 429
|
||||
mock_error_response.headers = {'Retry-After': '3'} # Using 3 seconds for test
|
||||
|
||||
mock_error_response.headers = {"Retry-After": "3"} # Using 3 seconds for test
|
||||
|
||||
successful_user_data = {"id": "U123", "name": "testuser"}
|
||||
|
||||
|
||||
mock_client_instance.users_info.side_effect = [
|
||||
SlackApiError(message="ratelimited_userinfo", response=mock_error_response),
|
||||
{"user": successful_user_data}
|
||||
{"user": successful_user_data},
|
||||
]
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
user_info = slack_history.get_user_info(user_id="U123")
|
||||
|
||||
|
||||
self.assertEqual(user_info, successful_user_data)
|
||||
|
||||
|
||||
# Assert that time.sleep was called for the rate limit
|
||||
mock_time_sleep.assert_called_once_with(3)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
|
|
@ -375,46 +450,58 @@ class TestSlackHistoryGetUserInfo(unittest.TestCase):
|
|||
)
|
||||
# Assert users_info was called twice (original + retry)
|
||||
self.assertEqual(mock_client_instance.users_info.call_count, 2)
|
||||
mock_client_instance.users_info.assert_has_calls([call(user="U123"), call(user="U123")])
|
||||
mock_client_instance.users_info.assert_has_calls(
|
||||
[call(user="U123"), call(user="U123")]
|
||||
)
|
||||
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch(
|
||||
"surfsense_backend.app.connectors.slack_history.time.sleep"
|
||||
) # time.sleep might be called by other logic, but not expected here
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_other_slack_api_error_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep') # time.sleep might be called by other logic, but not expected here
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_other_slack_api_error_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
|
||||
mock_error_response = Mock()
|
||||
mock_error_response.status_code = 500 # Some other error
|
||||
mock_error_response.data = {'ok': False, 'error': 'internal_server_error'}
|
||||
original_error = SlackApiError(message="internal server error", response=mock_error_response)
|
||||
mock_error_response.status_code = 500 # Some other error
|
||||
mock_error_response.data = {"ok": False, "error": "internal_server_error"}
|
||||
original_error = SlackApiError(
|
||||
message="internal server error", response=mock_error_response
|
||||
)
|
||||
|
||||
mock_client_instance.users_info.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
|
||||
with self.assertRaises(SlackApiError) as context:
|
||||
slack_history.get_user_info(user_id="U123")
|
||||
|
||||
|
||||
# Check that the raised error is the one we expect
|
||||
self.assertIn("Error retrieving user info for U123", str(context.exception))
|
||||
self.assertIs(context.exception.response, mock_error_response)
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
|
||||
@patch('surfsense_backend.app.connectors.slack_history.logger')
|
||||
@patch('surfsense_backend.app.connectors.slack_history.time.sleep')
|
||||
@patch('slack_sdk.WebClient')
|
||||
def test_general_exception_propagates(self, MockWebClient, mock_time_sleep, mock_logger):
|
||||
mock_client_instance = MockWebClient.return_value
|
||||
@patch("surfsense_backend.app.connectors.slack_history.logger")
|
||||
@patch("surfsense_backend.app.connectors.slack_history.time.sleep")
|
||||
@patch("slack_sdk.WebClient")
|
||||
def test_general_exception_propagates(
|
||||
self, mock_web_client, mock_time_sleep, mock_logger
|
||||
):
|
||||
mock_client_instance = mock_web_client.return_value
|
||||
original_error = Exception("A very generic problem")
|
||||
mock_client_instance.users_info.side_effect = original_error
|
||||
|
||||
|
||||
slack_history = SlackHistory(token="fake_token")
|
||||
|
||||
|
||||
with self.assertRaises(Exception) as context:
|
||||
slack_history.get_user_info(user_id="U123")
|
||||
|
||||
self.assertIs(context.exception, original_error) # Check it's the exact same exception
|
||||
|
||||
self.assertIs(
|
||||
context.exception, original_error
|
||||
) # Check it's the exact same exception
|
||||
mock_logger.error.assert_called_once_with(
|
||||
"Unexpected error in get_user_info for user U123: A very generic problem"
|
||||
)
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
mock_time_sleep.assert_not_called() # No rate limit sleep
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
JSON,
|
||||
TIMESTAMP,
|
||||
Boolean,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
TIMESTAMP
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
|
@ -27,16 +27,7 @@ from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
|||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from fastapi_users.db import (
|
||||
SQLAlchemyBaseOAuthAccountTableUUID,
|
||||
SQLAlchemyBaseUserTableUUID,
|
||||
SQLAlchemyUserDatabase,
|
||||
)
|
||||
else:
|
||||
from fastapi_users.db import (
|
||||
SQLAlchemyBaseUserTableUUID,
|
||||
SQLAlchemyUserDatabase,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
|
||||
DATABASE_URL = config.DATABASE_URL
|
||||
|
||||
|
|
@ -51,9 +42,11 @@ class DocumentType(str, Enum):
|
|||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||||
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
||||
JIRA_CONNECTOR = "JIRA_CONNECTOR"
|
||||
|
||||
|
||||
class SearchSourceConnectorType(str, Enum):
|
||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||
TAVILY_API = "TAVILY_API"
|
||||
LINKUP_API = "LINKUP_API"
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
|
|
@ -61,13 +54,16 @@ class SearchSourceConnectorType(str, Enum):
|
|||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||||
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
||||
|
||||
JIRA_CONNECTOR = "JIRA_CONNECTOR"
|
||||
|
||||
|
||||
class ChatType(str, Enum):
|
||||
QNA = "QNA"
|
||||
REPORT_GENERAL = "REPORT_GENERAL"
|
||||
REPORT_DEEP = "REPORT_DEEP"
|
||||
REPORT_DEEPER = "REPORT_DEEPER"
|
||||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
OPENAI = "OPENAI"
|
||||
ANTHROPIC = "ANTHROPIC"
|
||||
|
|
@ -92,6 +88,7 @@ class LiteLLMProvider(str, Enum):
|
|||
PETALS = "PETALS"
|
||||
CUSTOM = "CUSTOM"
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
DEBUG = "DEBUG"
|
||||
INFO = "INFO"
|
||||
|
|
@ -99,18 +96,27 @@ class LogLevel(str, Enum):
|
|||
ERROR = "ERROR"
|
||||
CRITICAL = "CRITICAL"
|
||||
|
||||
|
||||
class LogStatus(str, Enum):
|
||||
IN_PROGRESS = "IN_PROGRESS"
|
||||
SUCCESS = "SUCCESS"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
@declared_attr
|
||||
def created_at(cls):
|
||||
return Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
|
||||
def created_at(cls): # noqa: N805
|
||||
return Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
__abstract__ = True
|
||||
|
|
@ -118,6 +124,7 @@ class BaseModel(Base):
|
|||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class Chat(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chats"
|
||||
|
||||
|
|
@ -125,73 +132,115 @@ class Chat(BaseModel, TimestampMixin):
|
|||
title = Column(String, nullable=False, index=True)
|
||||
initial_connectors = Column(ARRAY(String), nullable=True)
|
||||
messages = Column(JSON, nullable=False)
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey('searchspaces.id', ondelete='CASCADE'), nullable=False)
|
||||
search_space = relationship('SearchSpace', back_populates='chats')
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="chats")
|
||||
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
__tablename__ = "documents"
|
||||
|
||||
|
||||
title = Column(String, nullable=False, index=True)
|
||||
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
|
||||
document_metadata = Column(JSON, nullable=True)
|
||||
|
||||
|
||||
content = Column(Text, nullable=False)
|
||||
content_hash = Column(String, nullable=False, index=True, unique=True)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="documents")
|
||||
chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan")
|
||||
chunks = relationship(
|
||||
"Chunk", back_populates="document", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Chunk(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chunks"
|
||||
|
||||
|
||||
content = Column(Text, nullable=False)
|
||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||
|
||||
document_id = Column(Integer, ForeignKey("documents.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
document_id = Column(
|
||||
Integer, ForeignKey("documents.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
document = relationship("Document", back_populates="chunks")
|
||||
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
|
||||
title = Column(String, nullable=False, index=True)
|
||||
podcast_transcript = Column(JSON, nullable=False, default={})
|
||||
file_location = Column(String(500), nullable=False, default="")
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
|
||||
|
||||
class SearchSpace(BaseModel, TimestampMixin):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="search_spaces")
|
||||
|
||||
documents = relationship("Document", back_populates="search_space", order_by="Document.id", cascade="all, delete-orphan")
|
||||
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id", cascade="all, delete-orphan")
|
||||
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id', cascade="all, delete-orphan")
|
||||
logs = relationship("Log", back_populates="search_space", order_by="Log.id", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
documents = relationship(
|
||||
"Document",
|
||||
back_populates="search_space",
|
||||
order_by="Document.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
podcasts = relationship(
|
||||
"Podcast",
|
||||
back_populates="search_space",
|
||||
order_by="Podcast.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
chats = relationship(
|
||||
"Chat",
|
||||
back_populates="search_space",
|
||||
order_by="Chat.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
logs = relationship(
|
||||
"Log",
|
||||
back_populates="search_space",
|
||||
order_by="Log.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
__tablename__ = "search_source_connectors"
|
||||
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True)
|
||||
connector_type = Column(
|
||||
SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True
|
||||
)
|
||||
is_indexable = Column(Boolean, nullable=False, default=False)
|
||||
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
config = Column(JSON, nullable=False)
|
||||
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="search_source_connectors")
|
||||
|
||||
|
||||
class LLMConfig(BaseModel, TimestampMixin):
|
||||
__tablename__ = "llm_configs"
|
||||
|
||||
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
# Provider from the enum
|
||||
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
||||
|
|
@ -202,78 +251,142 @@ class LLMConfig(BaseModel, TimestampMixin):
|
|||
# API Key should be encrypted before storing
|
||||
api_key = Column(String, nullable=False)
|
||||
api_base = Column(String(500), nullable=True)
|
||||
|
||||
|
||||
# For any other parameters that litellm supports
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
__tablename__ = "logs"
|
||||
|
||||
|
||||
level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True)
|
||||
status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True)
|
||||
message = Column(Text, nullable=False)
|
||||
source = Column(String(200), nullable=True, index=True) # Service/component that generated the log
|
||||
source = Column(
|
||||
String(200), nullable=True, index=True
|
||||
) # Service/component that generated the log
|
||||
log_metadata = Column(JSON, nullable=True, default={}) # Additional context data
|
||||
|
||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="logs")
|
||||
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
pass
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined"
|
||||
)
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
|
||||
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
|
||||
search_source_connectors = relationship(
|
||||
"SearchSourceConnector", back_populates="user"
|
||||
)
|
||||
llm_configs = relationship(
|
||||
"LLMConfig",
|
||||
back_populates="user",
|
||||
foreign_keys="LLMConfig.user_id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
||||
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
||||
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
||||
long_context_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
fast_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
strategic_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
long_context_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||
)
|
||||
fast_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
|
||||
)
|
||||
strategic_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||
)
|
||||
|
||||
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
|
||||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
||||
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
|
||||
else:
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
|
||||
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
|
||||
search_source_connectors = relationship(
|
||||
"SearchSourceConnector", back_populates="user"
|
||||
)
|
||||
llm_configs = relationship(
|
||||
"LLMConfig",
|
||||
back_populates="user",
|
||||
foreign_keys="LLMConfig.user_id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
||||
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
||||
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
||||
long_context_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
fast_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
strategic_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
|
||||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
||||
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
|
||||
long_context_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||
)
|
||||
fast_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
|
||||
)
|
||||
strategic_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||
)
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
|
||||
async def setup_indexes():
|
||||
async with engine.begin() as conn:
|
||||
# Create indexes
|
||||
# Create indexes
|
||||
# Document Summary Indexes
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)'))
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector(\'english\', content))'))
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))"
|
||||
)
|
||||
)
|
||||
# Document Chuck Indexes
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)'))
|
||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector(\'english\', content))'))
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def create_db_and_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
||||
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await setup_indexes()
|
||||
|
||||
|
|
@ -284,14 +397,23 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||
|
||||
else:
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User)
|
||||
|
||||
async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
|
||||
|
||||
|
||||
async def get_chucks_hybrid_search_retriever(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
return ChucksHybridSearchRetriever(session)
|
||||
|
||||
async def get_documents_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
|
||||
|
||||
async def get_documents_hybrid_search_retriever(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
return DocumentHybridSearchRetriever(session)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from datetime import datetime, timezone
|
||||
|
||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
||||
DATE_TODAY = "Today's date is " + datetime.now(UTC).astimezone().isoformat() + "\n"
|
||||
|
||||
SUMMARY_PROMPT = DATE_TODAY + """
|
||||
SUMMARY_PROMPT = (
|
||||
DATE_TODAY
|
||||
+ """
|
||||
<INSTRUCTIONS>
|
||||
<context>
|
||||
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
|
||||
|
|
@ -96,8 +99,8 @@ SUMMARY_PROMPT = DATE_TODAY + """
|
|||
</document_to_summarize>
|
||||
</INSTRUCTIONS>
|
||||
"""
|
||||
)
|
||||
|
||||
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
|
||||
input_variables=["document"],
|
||||
template=SUMMARY_PROMPT
|
||||
)
|
||||
input_variables=["document"], template=SUMMARY_PROMPT
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,34 +2,41 @@ class ChucksHybridSearchRetriever:
|
|||
def __init__(self, db_session):
|
||||
"""
|
||||
Initialize the hybrid search retriever with a database session.
|
||||
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
||||
"""
|
||||
self.db_session = db_session
|
||||
|
||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def vector_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform vector similarity search on chunks.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by vector similarity
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
|
|
@ -38,45 +45,48 @@ class ChucksHybridSearchRetriever:
|
|||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add vector similarity ordering
|
||||
query = (
|
||||
query
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
|
||||
return chunks
|
||||
|
||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def full_text_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform full-text keyword search on chunks.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by text relevance
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
|
|
@ -84,64 +94,70 @@ class ChucksHybridSearchRetriever:
|
|||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
||||
.where(
|
||||
tsvector.op("@@")(tsquery)
|
||||
) # Only include results that match the query
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add text search ranking
|
||||
query = (
|
||||
query
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
|
||||
return chunks
|
||||
|
||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
document_type: str | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and relevance scores
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace, DocumentType
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Chunk, Document, DocumentType, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Base conditions for document filtering
|
||||
base_conditions = [SearchSpace.user_id == user_id]
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
base_conditions.append(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add document type filter if provided
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
|
|
@ -154,90 +170,97 @@ class ChucksHybridSearchRetriever:
|
|||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
|
||||
# CTE for semantic search with user ownership check
|
||||
semantic_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=Chunk.embedding.op("<=>")(query_embedding))
|
||||
.label("rank"),
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
|
||||
|
||||
semantic_search_cte = (
|
||||
semantic_search_cte
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(n_results)
|
||||
.cte("semantic_search")
|
||||
)
|
||||
|
||||
|
||||
# CTE for keyword search with user ownership check
|
||||
keyword_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.label("rank"),
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
.where(tsvector.op("@@")(tsquery))
|
||||
)
|
||||
|
||||
|
||||
keyword_search_cte = (
|
||||
keyword_search_cte
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(n_results)
|
||||
.cte("keyword_search")
|
||||
)
|
||||
|
||||
|
||||
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
||||
final_query = (
|
||||
select(
|
||||
Chunk,
|
||||
(
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score")
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
||||
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score"),
|
||||
)
|
||||
.select_from(
|
||||
semantic_search_cte.outerjoin(
|
||||
keyword_search_cte,
|
||||
keyword_search_cte,
|
||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||
full=True
|
||||
full=True,
|
||||
)
|
||||
)
|
||||
.join(
|
||||
Chunk,
|
||||
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
||||
Chunk.id
|
||||
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
||||
)
|
||||
.options(joinedload(Chunk.document))
|
||||
.order_by(text("score DESC"))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(final_query)
|
||||
chunks_with_scores = result.all()
|
||||
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not chunks_with_scores:
|
||||
return []
|
||||
|
||||
|
||||
# Convert to serializable dictionaries if no reranker is available or if reranking failed
|
||||
serialized_results = []
|
||||
for chunk, score in chunks_with_scores:
|
||||
serialized_results.append({
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"document": {
|
||||
"id": chunk.document.id,
|
||||
"title": chunk.document.title,
|
||||
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
|
||||
"metadata": chunk.document.document_metadata
|
||||
serialized_results.append(
|
||||
{
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"document": {
|
||||
"id": chunk.document.id,
|
||||
"title": chunk.document.title,
|
||||
"document_type": chunk.document.document_type.value
|
||||
if hasattr(chunk.document, "document_type")
|
||||
else None,
|
||||
"metadata": chunk.document.document_metadata,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
)
|
||||
|
||||
return serialized_results
|
||||
|
|
|
|||
|
|
@ -2,34 +2,41 @@ class DocumentHybridSearchRetriever:
|
|||
def __init__(self, db_session):
|
||||
"""
|
||||
Initialize the hybrid search retriever with a database session.
|
||||
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
||||
"""
|
||||
self.db_session = db_session
|
||||
|
||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def vector_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform vector similarity search on documents.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of documents sorted by vector similarity
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Document)
|
||||
|
|
@ -37,107 +44,118 @@ class DocumentHybridSearchRetriever:
|
|||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add vector similarity ordering
|
||||
query = (
|
||||
query
|
||||
.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||
.limit(top_k)
|
||||
query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit(
|
||||
top_k
|
||||
)
|
||||
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
|
||||
return documents
|
||||
|
||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
async def full_text_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Perform full-text keyword search on documents.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
|
||||
Returns:
|
||||
List of documents sorted by text relevance
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Document.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Document)
|
||||
.options(joinedload(Document.search_space))
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
||||
.where(
|
||||
tsvector.op("@@")(tsquery)
|
||||
) # Only include results that match the query
|
||||
)
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add text search ranking
|
||||
query = (
|
||||
query
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
|
||||
return documents
|
||||
|
||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int,
|
||||
user_id: str,
|
||||
search_space_id: int | None = None,
|
||||
document_type: str | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Document, SearchSpace, DocumentType
|
||||
|
||||
from app.config import config
|
||||
|
||||
from app.db import Document, DocumentType, SearchSpace
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Document.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
||||
# Base conditions for document filtering
|
||||
base_conditions = [SearchSpace.user_id == user_id]
|
||||
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
base_conditions.append(Document.search_space_id == search_space_id)
|
||||
|
||||
|
||||
# Add document type filter if provided
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
|
|
@ -150,98 +168,112 @@ class DocumentHybridSearchRetriever:
|
|||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
|
||||
# CTE for semantic search with user ownership check
|
||||
semantic_search_cte = (
|
||||
select(
|
||||
Document.id,
|
||||
func.rank().over(order_by=Document.embedding.op("<=>")(query_embedding)).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=Document.embedding.op("<=>")(query_embedding))
|
||||
.label("rank"),
|
||||
)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
|
||||
|
||||
semantic_search_cte = (
|
||||
semantic_search_cte
|
||||
.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||
semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding))
|
||||
.limit(n_results)
|
||||
.cte("semantic_search")
|
||||
)
|
||||
|
||||
|
||||
# CTE for keyword search with user ownership check
|
||||
keyword_search_cte = (
|
||||
select(
|
||||
Document.id,
|
||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
||||
func.rank()
|
||||
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.label("rank"),
|
||||
)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
.where(tsvector.op("@@")(tsquery))
|
||||
)
|
||||
|
||||
|
||||
keyword_search_cte = (
|
||||
keyword_search_cte
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(n_results)
|
||||
.cte("keyword_search")
|
||||
)
|
||||
|
||||
|
||||
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
||||
final_query = (
|
||||
select(
|
||||
Document,
|
||||
(
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score")
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
||||
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score"),
|
||||
)
|
||||
.select_from(
|
||||
semantic_search_cte.outerjoin(
|
||||
keyword_search_cte,
|
||||
keyword_search_cte,
|
||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||
full=True
|
||||
full=True,
|
||||
)
|
||||
)
|
||||
.join(
|
||||
Document,
|
||||
Document.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
||||
Document.id
|
||||
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
||||
)
|
||||
.options(joinedload(Document.search_space))
|
||||
.order_by(text("score DESC"))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(final_query)
|
||||
documents_with_scores = result.all()
|
||||
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not documents_with_scores:
|
||||
return []
|
||||
|
||||
|
||||
# Convert to serializable dictionaries
|
||||
serialized_results = []
|
||||
for document, score in documents_with_scores:
|
||||
# Fetch associated chunks for this document
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import Chunk
|
||||
|
||||
chunks_query = select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||
|
||||
chunks_query = (
|
||||
select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
|
||||
# Concatenate chunks content
|
||||
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else document.content
|
||||
|
||||
serialized_results.append({
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"content": document.content,
|
||||
"chunks_content": concatenated_chunks_content,
|
||||
"document_type": document.document_type.value if hasattr(document, 'document_type') else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"search_space_id": document.search_space_id
|
||||
})
|
||||
|
||||
return serialized_results
|
||||
concatenated_chunks_content = (
|
||||
" ".join([chunk.content for chunk in chunks])
|
||||
if chunks
|
||||
else document.content
|
||||
)
|
||||
|
||||
serialized_results.append(
|
||||
{
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"content": document.content,
|
||||
"chunks_content": concatenated_chunks_content,
|
||||
"document_type": document.document_type.value
|
||||
if hasattr(document, "document_type")
|
||||
else None,
|
||||
"metadata": document.document_metadata,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"search_space_id": document.search_space_id,
|
||||
}
|
||||
)
|
||||
|
||||
return serialized_results
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from fastapi import APIRouter
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .documents_routes import router as documents_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
|
||||
from .chats_routes import router as chats_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .documents_routes import router as documents_router
|
||||
from .llm_config_routes import router as llm_config_router
|
||||
from .logs_routes import router as logs_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,38 +1,40 @@
|
|||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Chat, SearchSpace, User, get_async_session
|
||||
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
|
||||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from langchain.schema import HumanMessage, AIMessage
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def handle_chat_data(
|
||||
request: AISDKChatRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
messages = request.messages
|
||||
if messages[-1]['role'] != "user":
|
||||
if messages[-1]["role"] != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message must be a user message")
|
||||
status_code=400, detail="Last message must be a user message"
|
||||
)
|
||||
|
||||
user_query = messages[-1]['content']
|
||||
search_space_id = request.data.get('search_space_id')
|
||||
research_mode: str = request.data.get('research_mode')
|
||||
selected_connectors: List[str] = request.data.get('selected_connectors')
|
||||
document_ids_to_add_in_context: List[int] = request.data.get('document_ids_to_add_in_context')
|
||||
|
||||
search_mode_str = request.data.get('search_mode', "CHUNKS")
|
||||
user_query = messages[-1]["content"]
|
||||
search_space_id = request.data.get("search_space_id")
|
||||
research_mode: str = request.data.get("research_mode")
|
||||
selected_connectors: list[str] = request.data.get("selected_connectors")
|
||||
document_ids_to_add_in_context: list[int] = request.data.get(
|
||||
"document_ids_to_add_in_context"
|
||||
)
|
||||
|
||||
search_mode_str = request.data.get("search_mode", "CHUNKS")
|
||||
|
||||
# Convert search_space_id to integer if it's a string
|
||||
if search_space_id and isinstance(search_space_id, str):
|
||||
|
|
@ -40,21 +42,23 @@ async def handle_chat_data(
|
|||
search_space_id = int(search_space_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid search_space_id format")
|
||||
status_code=400, detail="Invalid search_space_id format"
|
||||
) from None
|
||||
|
||||
# Check if the search space belongs to the current user
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
except HTTPException:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have access to this search space")
|
||||
|
||||
status_code=403, detail="You don't have access to this search space"
|
||||
) from None
|
||||
|
||||
langchain_chat_history = []
|
||||
for message in messages[:-1]:
|
||||
if message['role'] == "user":
|
||||
langchain_chat_history.append(HumanMessage(content=message['content']))
|
||||
elif message['role'] == "assistant":
|
||||
langchain_chat_history.append(AIMessage(content=message['content']))
|
||||
if message["role"] == "user":
|
||||
langchain_chat_history.append(HumanMessage(content=message["content"]))
|
||||
elif message["role"] == "assistant":
|
||||
langchain_chat_history.append(AIMessage(content=message["content"]))
|
||||
|
||||
response = StreamingResponse(
|
||||
stream_connector_search_results(
|
||||
|
|
@ -69,7 +73,7 @@ async def handle_chat_data(
|
|||
document_ids_to_add_in_context,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
response.headers["x-vercel-ai-data-stream"] = "v1"
|
||||
return response
|
||||
|
||||
|
|
@ -78,7 +82,7 @@ async def handle_chat_data(
|
|||
async def create_chat(
|
||||
chat: ChatCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, chat.search_space_id, user)
|
||||
|
|
@ -89,52 +93,57 @@ async def create_chat(
|
|||
return db_chat
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
||||
except OperationalError as e:
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
except Exception as e:
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while creating the chat.")
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while creating the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats/", response_model=List[ChatRead])
|
||||
@router.get("/chats/", response_model=list[ChatRead])
|
||||
async def read_chats(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int = None,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
query = select(Chat).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||
|
||||
|
||||
# Filter by search_space_id if provided
|
||||
if search_space_id is not None:
|
||||
query = query.filter(Chat.search_space_id == search_space_id)
|
||||
|
||||
result = await session.execute(
|
||||
query.offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
return result.scalars().all()
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching chats.")
|
||||
status_code=500, detail="An unexpected error occurred while fetching chats."
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
|
||||
async def read_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -145,14 +154,19 @@ async def read_chat(
|
|||
chat = result.scalars().first()
|
||||
if not chat:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat not found or you don't have permission to access it")
|
||||
status_code=404,
|
||||
detail="Chat not found or you don't have permission to access it",
|
||||
)
|
||||
return chat
|
||||
except OperationalError:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while fetching the chat.")
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while fetching the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
|
||||
|
|
@ -160,7 +174,7 @@ async def update_chat(
|
|||
chat_id: int,
|
||||
chat_update: ChatUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_chat = await read_chat(chat_id, session, user)
|
||||
|
|
@ -175,22 +189,27 @@ async def update_chat(
|
|||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Database constraint violation. Please check your input data.")
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while updating the chat.")
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while updating the chat.",
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/chats/{chat_id}", response_model=dict)
|
||||
async def delete_chat(
|
||||
chat_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_chat = await read_chat(chat_id, session, user)
|
||||
|
|
@ -202,81 +221,16 @@ async def delete_chat(
|
|||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete chat due to existing dependencies.")
|
||||
status_code=400, detail="Cannot delete chat due to existing dependencies."
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later.")
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred while deleting the chat.")
|
||||
|
||||
|
||||
# test_data = [
|
||||
# {
|
||||
# "type": "TERMINAL_INFO",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "text": "Starting to search for crawled URLs...",
|
||||
# "type": "info"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "text": "Found 2 relevant crawled URLs",
|
||||
# "type": "success"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "SOURCES",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "name": "Crawled URLs",
|
||||
# "type": "CRAWLED_URL",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "name": "Files",
|
||||
# "type": "FILE",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 3,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 4,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "ANSWER",
|
||||
# "content": [
|
||||
# "## SurfSense Introduction",
|
||||
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while deleting the chat.",
|
||||
) from None
|
||||
|
|
|
|||
|
|
@ -1,23 +1,35 @@
|
|||
from litellm import atranscription
|
||||
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import Log, get_async_session, User, SearchSpace, Document, DocumentType
|
||||
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.tasks.background_tasks import add_received_markdown_file_document, add_extension_received_document, add_received_file_document_using_unstructured, add_crawled_url_document, add_youtube_video_document, add_received_file_document_using_llamacloud, add_received_file_document_using_docling
|
||||
from app.config import config as app_config
|
||||
# Force asyncio to use standard event loop before unstructured imports
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Form, HTTPException, UploadFile
|
||||
from litellm import atranscription
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, DocumentType, Log, SearchSpace, User, get_async_session
|
||||
from app.schemas import DocumentRead, DocumentsCreate, DocumentUpdate
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.background_tasks import (
|
||||
add_crawled_url_document,
|
||||
add_extension_received_document,
|
||||
add_received_file_document_using_docling,
|
||||
add_received_file_document_using_llamacloud,
|
||||
add_received_file_document_using_unstructured,
|
||||
add_received_markdown_file_document,
|
||||
add_youtube_video_document,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
try:
|
||||
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
||||
except RuntimeError:
|
||||
except RuntimeError as e:
|
||||
print("Error setting event loop policy", e)
|
||||
pass
|
||||
|
||||
import os
|
||||
|
||||
os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
||||
|
||||
|
||||
|
|
@ -29,7 +41,7 @@ async def create_documents(
|
|||
request: DocumentsCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
try:
|
||||
# Check if the user owns the search space
|
||||
|
|
@ -41,7 +53,7 @@ async def create_documents(
|
|||
process_extension_document_with_new_session,
|
||||
individual_document,
|
||||
request.search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
elif request.document_type == DocumentType.CRAWLED_URL:
|
||||
for url in request.content:
|
||||
|
|
@ -49,7 +61,7 @@ async def create_documents(
|
|||
process_crawled_url_with_new_session,
|
||||
url,
|
||||
request.search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
elif request.document_type == DocumentType.YOUTUBE_VIDEO:
|
||||
for url in request.content:
|
||||
|
|
@ -57,13 +69,10 @@ async def create_documents(
|
|||
process_youtube_video_with_new_session,
|
||||
url,
|
||||
request.search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid document type"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid document type")
|
||||
|
||||
await session.commit()
|
||||
return {"message": "Documents processed successfully"}
|
||||
|
|
@ -72,18 +81,17 @@ async def create_documents(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to process documents: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to process documents: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/documents/fileupload")
|
||||
async def create_documents(
|
||||
async def create_documents_file_upload(
|
||||
files: list[UploadFile],
|
||||
search_space_id: int = Form(...),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
|
@ -94,31 +102,32 @@ async def create_documents(
|
|||
for file in files:
|
||||
try:
|
||||
# Save file to a temporary location to avoid stream issues
|
||||
import tempfile
|
||||
import aiofiles
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Create temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=os.path.splitext(file.filename)[1]
|
||||
) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Write uploaded file to temp file
|
||||
content = await file.read()
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
fastapi_background_tasks.add_task(
|
||||
process_file_in_background_with_new_session,
|
||||
temp_path,
|
||||
file.filename,
|
||||
search_space_id,
|
||||
str(user.id)
|
||||
str(user.id),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to process file {file.filename}: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to process file {file.filename}: {e!s}",
|
||||
) from e
|
||||
|
||||
await session.commit()
|
||||
return {"message": "Files uploaded for processing"}
|
||||
|
|
@ -127,9 +136,8 @@ async def create_documents(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to upload files: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to upload files: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def process_file_in_background(
|
||||
|
|
@ -139,64 +147,71 @@ async def process_file_in_background(
|
|||
user_id: str,
|
||||
session: AsyncSession,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry: Log
|
||||
log_entry: Log,
|
||||
):
|
||||
try:
|
||||
# Check if the file is a markdown or text file
|
||||
if filename.lower().endswith(('.md', '.markdown', '.txt')):
|
||||
if filename.lower().endswith((".md", ".markdown", ".txt")):
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing markdown/text file: {filename}",
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"}
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||
)
|
||||
|
||||
|
||||
# For markdown files, read the content directly
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
markdown_content = f.read()
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Creating document from markdown content: {filename}",
|
||||
{"processing_stage": "creating_document", "content_length": len(markdown_content)}
|
||||
{
|
||||
"processing_stage": "creating_document",
|
||||
"content_length": len(markdown_content),
|
||||
},
|
||||
)
|
||||
|
||||
# Process markdown directly through specialized function
|
||||
result = await add_received_markdown_file_document(
|
||||
session,
|
||||
filename,
|
||||
markdown_content,
|
||||
search_space_id,
|
||||
user_id
|
||||
session, filename, markdown_content, search_space_id, user_id
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed markdown file: {filename}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "markdown"}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "markdown",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Markdown file already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "markdown"}
|
||||
{"duplicate_detected": True, "file_type": "markdown"},
|
||||
)
|
||||
|
||||
|
||||
# Check if the file is an audio file
|
||||
elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
|
||||
elif filename.lower().endswith(
|
||||
(".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
|
||||
):
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing audio file for transcription: {filename}",
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"}
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||
)
|
||||
|
||||
|
||||
# Open the audio file for transcription
|
||||
with open(file_path, "rb") as audio_file:
|
||||
# Use LiteLLM for audio transcription
|
||||
|
|
@ -205,65 +220,76 @@ async def process_file_in_background(
|
|||
model=app_config.STT_SERVICE,
|
||||
file=audio_file,
|
||||
api_base=app_config.STT_SERVICE_API_BASE,
|
||||
api_key=app_config.STT_SERVICE_API_KEY
|
||||
api_key=app_config.STT_SERVICE_API_KEY,
|
||||
)
|
||||
else:
|
||||
transcription_response = await atranscription(
|
||||
model=app_config.STT_SERVICE,
|
||||
api_key=app_config.STT_SERVICE_API_KEY,
|
||||
file=audio_file
|
||||
file=audio_file,
|
||||
)
|
||||
|
||||
# Extract the transcribed text
|
||||
transcribed_text = transcription_response.get("text", "")
|
||||
|
||||
# Add metadata about the transcription
|
||||
transcribed_text = f"# Transcription of {filename}\n\n{transcribed_text}"
|
||||
transcribed_text = (
|
||||
f"# Transcription of {filename}\n\n{transcribed_text}"
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Transcription completed, creating document: {filename}",
|
||||
{"processing_stage": "transcription_complete", "transcript_length": len(transcribed_text)}
|
||||
{
|
||||
"processing_stage": "transcription_complete",
|
||||
"transcript_length": len(transcribed_text),
|
||||
},
|
||||
)
|
||||
|
||||
# Clean up the temp file
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
# Process transcription as markdown document
|
||||
result = await add_received_markdown_file_document(
|
||||
session,
|
||||
filename,
|
||||
transcribed_text,
|
||||
search_space_id,
|
||||
user_id
|
||||
session, filename, transcribed_text, search_space_id, user_id
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully transcribed and processed audio file: {filename}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "audio", "transcript_length": len(transcribed_text)}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "audio",
|
||||
"transcript_length": len(transcribed_text),
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Audio file transcript already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "audio"}
|
||||
{"duplicate_detected": True, "file_type": "audio"},
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with Unstructured ETL: {filename}",
|
||||
{"file_type": "document", "etl_service": "UNSTRUCTURED", "processing_stage": "loading"}
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
"processing_stage": "loading",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
|
||||
# Process the file
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
|
|
@ -280,212 +306,257 @@ async def process_file_in_background(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Unstructured ETL completed, creating document: {filename}",
|
||||
{"processing_stage": "etl_complete", "elements_count": len(docs)}
|
||||
{"processing_stage": "etl_complete", "elements_count": len(docs)},
|
||||
)
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
# Pass the documents to the existing background task
|
||||
result = await add_received_file_document_using_unstructured(
|
||||
session,
|
||||
filename,
|
||||
docs,
|
||||
search_space_id,
|
||||
user_id
|
||||
session, filename, docs, search_space_id, user_id
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file with Unstructured: {filename}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash, "file_type": "document", "etl_service": "UNSTRUCTURED"}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "UNSTRUCTURED"}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "UNSTRUCTURED",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
elif app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with LlamaCloud ETL: {filename}",
|
||||
{"file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing"}
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
"processing_stage": "parsing",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
|
||||
# Create LlamaParse parser instance
|
||||
parser = LlamaParse(
|
||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||
num_workers=1, # Use single worker for file processing
|
||||
verbose=True,
|
||||
language="en",
|
||||
result_type=ResultType.MD
|
||||
result_type=ResultType.MD,
|
||||
)
|
||||
|
||||
|
||||
# Parse the file asynchronously
|
||||
result = await parser.aparse(file_path)
|
||||
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
|
||||
# Get markdown documents from the result
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
|
||||
markdown_documents = await result.aget_markdown_documents(
|
||||
split_by_page=False
|
||||
)
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"LlamaCloud parsing completed, creating documents: {filename}",
|
||||
{"processing_stage": "parsing_complete", "documents_count": len(markdown_documents)}
|
||||
{
|
||||
"processing_stage": "parsing_complete",
|
||||
"documents_count": len(markdown_documents),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
for doc in markdown_documents:
|
||||
# Extract text content from the markdown documents
|
||||
markdown_content = doc.text
|
||||
|
||||
|
||||
# Process the documents using our LlamaCloud background task
|
||||
doc_result = await add_received_file_document_using_llamacloud(
|
||||
session,
|
||||
filename,
|
||||
llamacloud_markdown_document=markdown_content,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
if doc_result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file with LlamaCloud: {filename}",
|
||||
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "LLAMACLOUD"}
|
||||
{
|
||||
"document_id": doc_result.id,
|
||||
"content_hash": doc_result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "LLAMACLOUD"}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "LLAMACLOUD",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
elif app_config.ETL_SERVICE == "DOCLING":
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing file with Docling ETL: {filename}",
|
||||
{"file_type": "document", "etl_service": "DOCLING", "processing_stage": "parsing"}
|
||||
{
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
"processing_stage": "parsing",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Use Docling service for document processing
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
|
||||
# Create Docling service
|
||||
docling_service = create_docling_service()
|
||||
|
||||
|
||||
# Process the document
|
||||
result = await docling_service.process_document(file_path, filename)
|
||||
|
||||
|
||||
# Clean up the temp file
|
||||
import os
|
||||
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except:
|
||||
except Exception as e:
|
||||
print("Error deleting temp file", e)
|
||||
pass
|
||||
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Docling parsing completed, creating document: {filename}",
|
||||
{"processing_stage": "parsing_complete", "content_length": len(result['content'])}
|
||||
{
|
||||
"processing_stage": "parsing_complete",
|
||||
"content_length": len(result["content"]),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Process the document using our Docling background task
|
||||
doc_result = await add_received_file_document_using_docling(
|
||||
session,
|
||||
filename,
|
||||
docling_markdown_document=result['content'],
|
||||
docling_markdown_document=result["content"],
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
if doc_result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file with Docling: {filename}",
|
||||
{"document_id": doc_result.id, "content_hash": doc_result.content_hash, "file_type": "document", "etl_service": "DOCLING"}
|
||||
{
|
||||
"document_id": doc_result.id,
|
||||
"content_hash": doc_result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists (duplicate): {filename}",
|
||||
{"duplicate_detected": True, "file_type": "document", "etl_service": "DOCLING"}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": "DOCLING",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process file: {filename}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__, "filename": filename}
|
||||
{"error_type": type(e).__name__, "filename": filename},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing file in background: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing file in background: {e!s}")
|
||||
raise # Re-raise so the wrapper can also handle it
|
||||
|
||||
|
||||
@router.get("/documents/", response_model=List[DocumentRead])
|
||||
@router.get("/documents/", response_model=list[DocumentRead])
|
||||
async def read_documents(
|
||||
skip: int = 0,
|
||||
limit: int = 300,
|
||||
search_space_id: int = None,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
query = select(Document).join(SearchSpace).filter(
|
||||
SearchSpace.user_id == user.id)
|
||||
query = (
|
||||
select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||
)
|
||||
|
||||
# Filter by search_space_id if provided
|
||||
if search_space_id is not None:
|
||||
query = query.filter(Document.search_space_id == search_space_id)
|
||||
|
||||
result = await session.execute(
|
||||
query.offset(skip).limit(limit)
|
||||
)
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
db_documents = result.scalars().all()
|
||||
|
||||
# Convert database objects to API-friendly format
|
||||
api_documents = []
|
||||
for doc in db_documents:
|
||||
api_documents.append(DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
created_at=doc.created_at,
|
||||
search_space_id=doc.search_space_id
|
||||
))
|
||||
api_documents.append(
|
||||
DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
created_at=doc.created_at,
|
||||
search_space_id=doc.search_space_id,
|
||||
)
|
||||
)
|
||||
|
||||
return api_documents
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch documents: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch documents: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=DocumentRead)
|
||||
async def read_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -497,8 +568,7 @@ async def read_document(
|
|||
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
status_code=404, detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
# Convert database object to API-friendly format
|
||||
|
|
@ -509,13 +579,12 @@ async def read_document(
|
|||
document_metadata=document.document_metadata,
|
||||
content=document.content,
|
||||
created_at=document.created_at,
|
||||
search_space_id=document.search_space_id
|
||||
search_space_id=document.search_space_id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch document: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/documents/{document_id}", response_model=DocumentRead)
|
||||
|
|
@ -523,7 +592,7 @@ async def update_document(
|
|||
document_id: int,
|
||||
document_update: DocumentUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
# Query the document directly instead of using read_document function
|
||||
|
|
@ -536,8 +605,7 @@ async def update_document(
|
|||
|
||||
if not db_document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
status_code=404, detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
update_data = document_update.model_dump(exclude_unset=True)
|
||||
|
|
@ -554,23 +622,22 @@ async def update_document(
|
|||
document_metadata=db_document.document_metadata,
|
||||
content=db_document.content,
|
||||
created_at=db_document.created_at,
|
||||
search_space_id=db_document.search_space_id
|
||||
search_space_id=db_document.search_space_id,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update document: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/documents/{document_id}", response_model=dict)
|
||||
async def delete_document(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
# Query the document directly instead of using read_document function
|
||||
|
|
@ -583,8 +650,7 @@ async def delete_document(
|
|||
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Document with id {document_id} not found"
|
||||
status_code=404, detail=f"Document with id {document_id} not found"
|
||||
)
|
||||
|
||||
await session.delete(document)
|
||||
|
|
@ -595,15 +661,12 @@ async def delete_document(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete document: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def process_extension_document_with_new_session(
|
||||
individual_document,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
individual_document, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process extension document."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@ -612,7 +675,7 @@ async def process_extension_document_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_extension_document",
|
||||
|
|
@ -622,40 +685,41 @@ async def process_extension_document_with_new_session(
|
|||
"document_type": "EXTENSION",
|
||||
"url": individual_document.metadata.VisitedWebPageURL,
|
||||
"title": individual_document.metadata.VisitedWebPageTitle,
|
||||
"user_id": user_id
|
||||
}
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await add_extension_received_document(session, individual_document, search_space_id, user_id)
|
||||
|
||||
result = await add_extension_received_document(
|
||||
session, individual_document, search_space_id, user_id
|
||||
)
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||
{"document_id": result.id, "content_hash": result.content_hash}
|
||||
{"document_id": result.id, "content_hash": result.content_hash},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Extension document already exists (duplicate): {individual_document.metadata.VisitedWebPageTitle}",
|
||||
{"duplicate_detected": True}
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process extension document: {individual_document.metadata.VisitedWebPageTitle}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing extension document: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing extension document: {e!s}")
|
||||
|
||||
|
||||
async def process_crawled_url_with_new_session(
|
||||
url: str,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
url: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process crawled URL."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@ -664,50 +728,50 @@ async def process_crawled_url_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_crawled_url",
|
||||
source="document_processor",
|
||||
message=f"Starting URL crawling and processing for: {url}",
|
||||
metadata={
|
||||
"document_type": "CRAWLED_URL",
|
||||
"url": url,
|
||||
"user_id": user_id
|
||||
}
|
||||
metadata={"document_type": "CRAWLED_URL", "url": url, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await add_crawled_url_document(session, url, search_space_id, user_id)
|
||||
|
||||
result = await add_crawled_url_document(
|
||||
session, url, search_space_id, user_id
|
||||
)
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully crawled and processed URL: {url}",
|
||||
{"document_id": result.id, "title": result.title, "content_hash": result.content_hash}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"title": result.title,
|
||||
"content_hash": result.content_hash,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"URL document already exists (duplicate): {url}",
|
||||
{"duplicate_detected": True}
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to crawl URL: {url}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing crawled URL: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing crawled URL: {e!s}")
|
||||
|
||||
|
||||
async def process_file_in_background_with_new_session(
|
||||
file_path: str,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
file_path: str, filename: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process file."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@ -716,7 +780,7 @@ async def process_file_in_background_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload",
|
||||
|
|
@ -726,29 +790,36 @@ async def process_file_in_background_with_new_session(
|
|||
"document_type": "FILE",
|
||||
"filename": filename,
|
||||
"file_path": file_path,
|
||||
"user_id": user_id
|
||||
}
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
await process_file_in_background(file_path, filename, search_space_id, user_id, session, task_logger, log_entry)
|
||||
|
||||
await process_file_in_background(
|
||||
file_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
user_id,
|
||||
session,
|
||||
task_logger,
|
||||
log_entry,
|
||||
)
|
||||
|
||||
# Note: success/failure logging is handled within process_file_in_background
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process file: {filename}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing file: {str(e)}")
|
||||
|
||||
logging.error(f"Error processing file: {e!s}")
|
||||
|
||||
|
||||
async def process_youtube_video_with_new_session(
|
||||
url: str,
|
||||
search_space_id: int,
|
||||
user_id: str
|
||||
url: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Create a new session and process YouTube video."""
|
||||
from app.db import async_session_maker
|
||||
|
|
@ -757,42 +828,43 @@ async def process_youtube_video_with_new_session(
|
|||
async with async_session_maker() as session:
|
||||
# Initialize task logging service
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_youtube_video",
|
||||
source="document_processor",
|
||||
message=f"Starting YouTube video processing for: {url}",
|
||||
metadata={
|
||||
"document_type": "YOUTUBE_VIDEO",
|
||||
"url": url,
|
||||
"user_id": user_id
|
||||
}
|
||||
metadata={"document_type": "YOUTUBE_VIDEO", "url": url, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await add_youtube_video_document(session, url, search_space_id, user_id)
|
||||
|
||||
result = await add_youtube_video_document(
|
||||
session, url, search_space_id, user_id
|
||||
)
|
||||
|
||||
if result:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed YouTube video: {result.title}",
|
||||
{"document_id": result.id, "video_id": result.document_metadata.get("video_id"), "content_hash": result.content_hash}
|
||||
{
|
||||
"document_id": result.id,
|
||||
"video_id": result.document_metadata.get("video_id"),
|
||||
"content_hash": result.content_hash,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"YouTube video document already exists (duplicate): {url}",
|
||||
{"duplicate_detected": True}
|
||||
{"duplicate_detected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to process YouTube video: {url}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
import logging
|
||||
logging.error(f"Error processing YouTube video: {str(e)}")
|
||||
|
||||
|
||||
logging.error(f"Error processing YouTube video: {e!s}")
|
||||
|
|
|
|||
|
|
@ -1,35 +1,40 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from app.db import get_async_session, User, LLMConfig
|
||||
from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
|
||||
|
||||
from app.db import LLMConfig, User, get_async_session
|
||||
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LLMPreferencesUpdate(BaseModel):
|
||||
"""Schema for updating user LLM preferences"""
|
||||
long_context_llm_id: Optional[int] = None
|
||||
fast_llm_id: Optional[int] = None
|
||||
strategic_llm_id: Optional[int] = None
|
||||
|
||||
long_context_llm_id: int | None = None
|
||||
fast_llm_id: int | None = None
|
||||
strategic_llm_id: int | None = None
|
||||
|
||||
|
||||
class LLMPreferencesRead(BaseModel):
|
||||
"""Schema for reading user LLM preferences"""
|
||||
long_context_llm_id: Optional[int] = None
|
||||
fast_llm_id: Optional[int] = None
|
||||
strategic_llm_id: Optional[int] = None
|
||||
long_context_llm: Optional[LLMConfigRead] = None
|
||||
fast_llm: Optional[LLMConfigRead] = None
|
||||
strategic_llm: Optional[LLMConfigRead] = None
|
||||
|
||||
long_context_llm_id: int | None = None
|
||||
fast_llm_id: int | None = None
|
||||
strategic_llm_id: int | None = None
|
||||
long_context_llm: LLMConfigRead | None = None
|
||||
fast_llm: LLMConfigRead | None = None
|
||||
strategic_llm: LLMConfigRead | None = None
|
||||
|
||||
|
||||
@router.post("/llm-configs/", response_model=LLMConfigRead)
|
||||
async def create_llm_config(
|
||||
llm_config: LLMConfigCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new LLM configuration for the authenticated user"""
|
||||
try:
|
||||
|
|
@ -43,16 +48,16 @@ async def create_llm_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.get("/llm-configs/", response_model=List[LLMConfigRead])
|
||||
|
||||
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
|
||||
async def read_llm_configs(
|
||||
skip: int = 0,
|
||||
limit: int = 200,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all LLM configurations for the authenticated user"""
|
||||
try:
|
||||
|
|
@ -65,15 +70,15 @@ async def read_llm_configs(
|
|||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch LLM configurations: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
||||
async def read_llm_config(
|
||||
llm_config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific LLM configuration by ID"""
|
||||
try:
|
||||
|
|
@ -83,25 +88,25 @@ async def read_llm_config(
|
|||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
|
||||
async def update_llm_config(
|
||||
llm_config_id: int,
|
||||
llm_config_update: LLMConfigUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update an existing LLM configuration"""
|
||||
try:
|
||||
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
|
||||
update_data = llm_config_update.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(db_llm_config, key, value)
|
||||
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_llm_config)
|
||||
return db_llm_config
|
||||
|
|
@ -110,15 +115,15 @@ async def update_llm_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
|
||||
async def delete_llm_config(
|
||||
llm_config_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete an LLM configuration"""
|
||||
try:
|
||||
|
|
@ -131,22 +136,23 @@ async def delete_llm_config(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete LLM configuration: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# User LLM Preferences endpoints
|
||||
|
||||
|
||||
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||
async def get_user_llm_preferences(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get the current user's LLM preferences"""
|
||||
try:
|
||||
# Refresh user to get latest relationships
|
||||
await session.refresh(user)
|
||||
|
||||
|
||||
result = {
|
||||
"long_context_llm_id": user.long_context_llm_id,
|
||||
"fast_llm_id": user.fast_llm_id,
|
||||
|
|
@ -155,82 +161,79 @@ async def get_user_llm_preferences(
|
|||
"fast_llm": None,
|
||||
"strategic_llm": None,
|
||||
}
|
||||
|
||||
|
||||
# Fetch the actual LLM configs if they exist
|
||||
if user.long_context_llm_id:
|
||||
long_context_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.long_context_llm_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.user_id == user.id,
|
||||
)
|
||||
)
|
||||
llm_config = long_context_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["long_context_llm"] = llm_config
|
||||
|
||||
|
||||
if user.fast_llm_id:
|
||||
fast_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.fast_llm_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = fast_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["fast_llm"] = llm_config
|
||||
|
||||
|
||||
if user.strategic_llm_id:
|
||||
strategic_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.strategic_llm_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = strategic_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["strategic_llm"] = llm_config
|
||||
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch LLM preferences: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||
async def update_user_llm_preferences(
|
||||
preferences: LLMPreferencesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update the current user's LLM preferences"""
|
||||
try:
|
||||
# Validate that all provided LLM config IDs belong to the user
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
|
||||
for key, llm_config_id in update_data.items():
|
||||
|
||||
for _key, llm_config_id in update_data.items():
|
||||
if llm_config_id is not None:
|
||||
# Verify ownership of the LLM config
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.user_id == user.id
|
||||
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
if not llm_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it"
|
||||
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
|
||||
)
|
||||
|
||||
|
||||
# Update user preferences
|
||||
for key, value in update_data.items():
|
||||
setattr(user, key, value)
|
||||
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
|
||||
# Return updated preferences
|
||||
return await get_user_llm_preferences(session, user)
|
||||
except HTTPException:
|
||||
|
|
@ -238,6 +241,5 @@ async def update_user_llm_preferences(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update LLM preferences: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import and_, desc
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.db import get_async_session, User, SearchSpace, Log, LogLevel, LogStatus
|
||||
from app.schemas import LogCreate, LogUpdate, LogRead, LogFilter
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session
|
||||
from app.schemas import LogCreate, LogRead, LogUpdate
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/logs/", response_model=LogRead)
|
||||
async def create_log(
|
||||
log: LogCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new log entry."""
|
||||
try:
|
||||
|
|
@ -33,22 +34,22 @@ async def create_log(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to create log: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.get("/logs/", response_model=List[LogRead])
|
||||
|
||||
@router.get("/logs/", response_model=list[LogRead])
|
||||
async def read_logs(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: Optional[int] = None,
|
||||
level: Optional[LogLevel] = None,
|
||||
status: Optional[LogStatus] = None,
|
||||
source: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
search_space_id: int | None = None,
|
||||
level: LogLevel | None = None,
|
||||
status: LogStatus | None = None,
|
||||
source: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get logs with optional filtering."""
|
||||
try:
|
||||
|
|
@ -62,23 +63,23 @@ async def read_logs(
|
|||
|
||||
# Apply filters
|
||||
filters = []
|
||||
|
||||
|
||||
if search_space_id is not None:
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
filters.append(Log.search_space_id == search_space_id)
|
||||
|
||||
|
||||
if level is not None:
|
||||
filters.append(Log.level == level)
|
||||
|
||||
|
||||
if status is not None:
|
||||
filters.append(Log.status == status)
|
||||
|
||||
|
||||
if source is not None:
|
||||
filters.append(Log.source.ilike(f"%{source}%"))
|
||||
|
||||
|
||||
if start_date is not None:
|
||||
filters.append(Log.created_at >= start_date)
|
||||
|
||||
|
||||
if end_date is not None:
|
||||
filters.append(Log.created_at <= end_date)
|
||||
|
||||
|
|
@ -93,15 +94,15 @@ async def read_logs(
|
|||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch logs: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch logs: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/logs/{log_id}", response_model=LogRead)
|
||||
async def read_log(
|
||||
log_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific log by ID."""
|
||||
try:
|
||||
|
|
@ -112,25 +113,25 @@ async def read_log(
|
|||
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
log = result.scalars().first()
|
||||
|
||||
|
||||
if not log:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
|
||||
|
||||
return log
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch log: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/logs/{log_id}", response_model=LogRead)
|
||||
async def update_log(
|
||||
log_id: int,
|
||||
log_update: LogUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update a log entry."""
|
||||
try:
|
||||
|
|
@ -141,7 +142,7 @@ async def update_log(
|
|||
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
db_log = result.scalars().first()
|
||||
|
||||
|
||||
if not db_log:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
|
||||
|
|
@ -158,15 +159,15 @@ async def update_log(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update log: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/logs/{log_id}")
|
||||
async def delete_log(
|
||||
log_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete a log entry."""
|
||||
try:
|
||||
|
|
@ -177,7 +178,7 @@ async def delete_log(
|
|||
.filter(Log.id == log_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
db_log = result.scalars().first()
|
||||
|
||||
|
||||
if not db_log:
|
||||
raise HTTPException(status_code=404, detail="Log not found")
|
||||
|
||||
|
|
@ -189,38 +190,35 @@ async def delete_log(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete log: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete log: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/logs/search-space/{search_space_id}/summary")
|
||||
async def get_logs_summary(
|
||||
search_space_id: int,
|
||||
hours: int = 24,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a summary of logs for a search space in the last X hours."""
|
||||
try:
|
||||
# Check ownership
|
||||
await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
||||
|
||||
# Calculate time window
|
||||
since = datetime.utcnow().replace(microsecond=0) - timedelta(hours=hours)
|
||||
|
||||
|
||||
# Get logs from the time window
|
||||
result = await session.execute(
|
||||
select(Log)
|
||||
.filter(
|
||||
and_(
|
||||
Log.search_space_id == search_space_id,
|
||||
Log.created_at >= since
|
||||
)
|
||||
and_(Log.search_space_id == search_space_id, Log.created_at >= since)
|
||||
)
|
||||
.order_by(desc(Log.created_at))
|
||||
)
|
||||
logs = result.scalars().all()
|
||||
|
||||
|
||||
# Create summary
|
||||
summary = {
|
||||
"total_logs": len(logs),
|
||||
|
|
@ -229,52 +227,69 @@ async def get_logs_summary(
|
|||
"by_level": {},
|
||||
"by_source": {},
|
||||
"active_tasks": [],
|
||||
"recent_failures": []
|
||||
"recent_failures": [],
|
||||
}
|
||||
|
||||
|
||||
# Count by status and level
|
||||
for log in logs:
|
||||
# Status counts
|
||||
status_str = log.status.value
|
||||
summary["by_status"][status_str] = summary["by_status"].get(status_str, 0) + 1
|
||||
|
||||
summary["by_status"][status_str] = (
|
||||
summary["by_status"].get(status_str, 0) + 1
|
||||
)
|
||||
|
||||
# Level counts
|
||||
level_str = log.level.value
|
||||
summary["by_level"][level_str] = summary["by_level"].get(level_str, 0) + 1
|
||||
|
||||
|
||||
# Source counts
|
||||
if log.source:
|
||||
summary["by_source"][log.source] = summary["by_source"].get(log.source, 0) + 1
|
||||
|
||||
summary["by_source"][log.source] = (
|
||||
summary["by_source"].get(log.source, 0) + 1
|
||||
)
|
||||
|
||||
# Active tasks (IN_PROGRESS)
|
||||
if log.status == LogStatus.IN_PROGRESS:
|
||||
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
|
||||
summary["active_tasks"].append({
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"started_at": log.created_at,
|
||||
"source": log.source
|
||||
})
|
||||
|
||||
task_name = (
|
||||
log.log_metadata.get("task_name", "Unknown")
|
||||
if log.log_metadata
|
||||
else "Unknown"
|
||||
)
|
||||
summary["active_tasks"].append(
|
||||
{
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"started_at": log.created_at,
|
||||
"source": log.source,
|
||||
}
|
||||
)
|
||||
|
||||
# Recent failures
|
||||
if log.status == LogStatus.FAILED and len(summary["recent_failures"]) < 10:
|
||||
task_name = log.log_metadata.get("task_name", "Unknown") if log.log_metadata else "Unknown"
|
||||
summary["recent_failures"].append({
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"failed_at": log.created_at,
|
||||
"source": log.source,
|
||||
"error_details": log.log_metadata.get("error_details") if log.log_metadata else None
|
||||
})
|
||||
|
||||
task_name = (
|
||||
log.log_metadata.get("task_name", "Unknown")
|
||||
if log.log_metadata
|
||||
else "Unknown"
|
||||
)
|
||||
summary["recent_failures"].append(
|
||||
{
|
||||
"id": log.id,
|
||||
"task_name": task_name,
|
||||
"message": log.message,
|
||||
"failed_at": log.created_at,
|
||||
"source": log.source,
|
||||
"error_details": log.log_metadata.get("error_details")
|
||||
if log.log_metadata
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate logs summary: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to generate logs summary: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -1,24 +1,31 @@
|
|||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Podcast, Chat
|
||||
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
from fastapi.responses import StreamingResponse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import Chat, Podcast, SearchSpace, User, get_async_session
|
||||
from app.schemas import (
|
||||
PodcastCreate,
|
||||
PodcastGenerateRequest,
|
||||
PodcastRead,
|
||||
PodcastUpdate,
|
||||
)
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/podcasts/", response_model=PodcastRead)
|
||||
async def create_podcast(
|
||||
podcast: PodcastCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
await check_ownership(session, SearchSpace, podcast.search_space_id, user)
|
||||
|
|
@ -29,22 +36,30 @@ async def create_podcast(
|
|||
return db_podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
|
||||
except SQLAlchemyError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Podcast creation failed due to constraint violation",
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while creating podcast"
|
||||
) from None
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An unexpected error occurred"
|
||||
) from None
|
||||
|
||||
@router.get("/podcasts/", response_model=List[PodcastRead])
|
||||
|
||||
@router.get("/podcasts/", response_model=list[PodcastRead])
|
||||
async def read_podcasts(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
|
|
@ -58,13 +73,16 @@ async def read_podcasts(
|
|||
)
|
||||
return result.scalars().all()
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcasts"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def read_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -76,20 +94,23 @@ async def read_podcast(
|
|||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found or you don't have permission to access it"
|
||||
detail="Podcast not found or you don't have permission to access it",
|
||||
)
|
||||
return podcast
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def update_podcast(
|
||||
podcast_id: int,
|
||||
podcast_update: PodcastUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_podcast = await read_podcast(podcast_id, session, user)
|
||||
|
|
@ -103,16 +124,21 @@ async def update_podcast(
|
|||
raise he
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Update failed due to constraint violation"
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while updating podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_podcast = await read_podcast(podcast_id, session, user)
|
||||
|
|
@ -123,83 +149,100 @@ async def delete_podcast(
|
|||
raise he
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while deleting podcast"
|
||||
) from None
|
||||
|
||||
|
||||
async def generate_chat_podcast_with_new_session(
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
podcast_title: str,
|
||||
user_id: int
|
||||
chat_id: int, search_space_id: int, podcast_title: str, user_id: int
|
||||
):
|
||||
"""Create a new session and process chat podcast generation."""
|
||||
from app.db import async_session_maker
|
||||
|
||||
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id)
|
||||
await generate_chat_podcast(
|
||||
session, chat_id, search_space_id, podcast_title, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.error(f"Error generating podcast from chat: {str(e)}")
|
||||
|
||||
logging.error(f"Error generating podcast from chat: {e!s}")
|
||||
|
||||
|
||||
@router.post("/podcasts/generate/")
|
||||
async def generate_podcast(
|
||||
request: PodcastGenerateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
|
||||
fastapi_background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
try:
|
||||
# Check if the user owns the search space
|
||||
await check_ownership(session, SearchSpace, request.search_space_id, user)
|
||||
|
||||
|
||||
if request.type == "CHAT":
|
||||
# Verify that all chat IDs belong to this user and search space
|
||||
query = select(Chat).filter(
|
||||
Chat.id.in_(request.ids),
|
||||
Chat.search_space_id == request.search_space_id
|
||||
).join(SearchSpace).filter(SearchSpace.user_id == user.id)
|
||||
|
||||
query = (
|
||||
select(Chat)
|
||||
.filter(
|
||||
Chat.id.in_(request.ids),
|
||||
Chat.search_space_id == request.search_space_id,
|
||||
)
|
||||
.join(SearchSpace)
|
||||
.filter(SearchSpace.user_id == user.id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
valid_chats = result.scalars().all()
|
||||
valid_chat_ids = [chat.id for chat in valid_chats]
|
||||
|
||||
|
||||
# If any requested ID is not in valid IDs, raise error immediately
|
||||
if len(valid_chat_ids) != len(request.ids):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="One or more chat IDs do not belong to this user or search space"
|
||||
status_code=403,
|
||||
detail="One or more chat IDs do not belong to this user or search space",
|
||||
)
|
||||
|
||||
|
||||
# Only add a single task with the first chat ID
|
||||
for chat_id in valid_chat_ids:
|
||||
fastapi_background_tasks.add_task(
|
||||
generate_chat_podcast_with_new_session,
|
||||
chat_id,
|
||||
generate_chat_podcast_with_new_session,
|
||||
chat_id,
|
||||
request.search_space_id,
|
||||
request.podcast_title,
|
||||
user.id
|
||||
user.id,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"message": "Podcast generation started",
|
||||
}
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except IntegrityError as e:
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation")
|
||||
except SQLAlchemyError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Podcast generation failed due to constraint violation",
|
||||
) from None
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Database error occurred while generating podcast")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while generating podcast"
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"An unexpected error occurred: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}/stream")
|
||||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Stream a podcast audio file."""
|
||||
try:
|
||||
|
|
@ -210,36 +253,38 @@ async def stream_podcast(
|
|||
.filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
|
||||
|
||||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found or you don't have permission to access it"
|
||||
detail="Podcast not found or you don't have permission to access it",
|
||||
)
|
||||
|
||||
|
||||
# Get the file path
|
||||
file_path = podcast.file_location
|
||||
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Podcast audio file not found")
|
||||
|
||||
|
||||
# Define a generator function to stream the file
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
|
||||
|
||||
# Return a streaming response with appropriate headers
|
||||
return StreamingResponse(
|
||||
iterfile(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Disposition": f"inline; filename={Path(file_path).name}"
|
||||
}
|
||||
"Content-Disposition": f"inline; filename={Path(file_path).name}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error streaming podcast: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -9,35 +9,59 @@ POST /search-source-connectors/{connector_id}/index - Index content from a conne
|
|||
|
||||
Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR, GITHUB_CONNECTOR, LINEAR_CONNECTOR, DISCORD_CONNECTOR).
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks, Body
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import List, Dict, Any
|
||||
from app.db import get_async_session, User, SearchSourceConnector, SearchSourceConnectorType, SearchSpace, async_session_maker
|
||||
from app.schemas import SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead, SearchSourceConnectorBase
|
||||
|
||||
from app.connectors.github_connector import GitHubConnector
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
SearchSpace,
|
||||
User,
|
||||
async_session_maker,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
SearchSourceConnectorBase,
|
||||
SearchSourceConnectorCreate,
|
||||
SearchSourceConnectorRead,
|
||||
SearchSourceConnectorUpdate,
|
||||
)
|
||||
from app.tasks.connectors_indexing_tasks import (
|
||||
index_discord_messages,
|
||||
index_github_repos,
|
||||
index_jira_issues,
|
||||
index_linear_issues,
|
||||
index_notion_pages,
|
||||
index_slack_messages,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from app.tasks.connectors_indexing_tasks import index_slack_messages, index_notion_pages, index_github_repos, index_linear_issues, index_discord_messages
|
||||
from app.connectors.github_connector import GitHubConnector
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Use Pydantic's BaseModel here
|
||||
class GitHubPATRequest(BaseModel):
|
||||
github_pat: str = Field(..., description="GitHub Personal Access Token")
|
||||
|
||||
|
||||
# --- New Endpoint to list GitHub Repositories ---
|
||||
@router.post("/github/repositories/", response_model=List[Dict[str, Any]])
|
||||
@router.post("/github/repositories/", response_model=list[dict[str, Any]])
|
||||
async def list_github_repositories(
|
||||
pat_request: GitHubPATRequest,
|
||||
user: User = Depends(current_active_user) # Ensure the user is logged in
|
||||
user: User = Depends(current_active_user), # Ensure the user is logged in
|
||||
):
|
||||
"""
|
||||
Fetches a list of repositories accessible by the provided GitHub PAT.
|
||||
|
|
@ -51,38 +75,40 @@ async def list_github_repositories(
|
|||
return repositories
|
||||
except ValueError as e:
|
||||
# Handle invalid token error specifically
|
||||
logger.error(f"GitHub PAT validation failed for user {user.id}: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {str(e)}")
|
||||
logger.error(f"GitHub PAT validation failed for user {user.id}: {e!s}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid GitHub PAT: {e!s}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch GitHub repositories.")
|
||||
logger.error(f"Failed to fetch GitHub repositories for user {user.id}: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to fetch GitHub repositories."
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
|
||||
async def create_search_source_connector(
|
||||
connector: SearchSourceConnectorCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new search source connector.
|
||||
|
||||
|
||||
Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, etc.).
|
||||
The config must contain the appropriate keys for the connector type.
|
||||
"""
|
||||
try:
|
||||
# Check if a connector with the same type already exists for this user
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type == connector.connector_type
|
||||
SearchSourceConnector.connector_type == connector.connector_type,
|
||||
)
|
||||
)
|
||||
existing_connector = result.scalars().first()
|
||||
if existing_connector:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type."
|
||||
detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type.",
|
||||
)
|
||||
db_connector = SearchSourceConnector(**connector.model_dump(), user_id=user.id)
|
||||
session.add(db_connector)
|
||||
|
|
@ -91,56 +117,59 @@ async def create_search_source_connector(
|
|||
return db_connector
|
||||
except ValidationError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Validation error: {str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=422, detail=f"Validation error: {e!s}") from e
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Integrity error: A connector with this type already exists. {str(e)}"
|
||||
)
|
||||
detail=f"Integrity error: A connector with this type already exists. {e!s}",
|
||||
) from e
|
||||
except HTTPException:
|
||||
await session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create search source connector: {str(e)}")
|
||||
logger.error(f"Failed to create search source connector: {e!s}")
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create search source connector: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to create search source connector: {e!s}",
|
||||
) from e
|
||||
|
||||
@router.get("/search-source-connectors/", response_model=List[SearchSourceConnectorRead])
|
||||
|
||||
@router.get(
|
||||
"/search-source-connectors/", response_model=list[SearchSourceConnectorRead]
|
||||
)
|
||||
async def read_search_source_connectors(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int = None,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List all search source connectors for the current user."""
|
||||
try:
|
||||
query = select(SearchSourceConnector).filter(SearchSourceConnector.user_id == user.id)
|
||||
|
||||
# No need to filter by search_space_id as connectors are user-owned, not search space specific
|
||||
|
||||
result = await session.execute(
|
||||
query.offset(skip).limit(limit)
|
||||
query = select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == user.id
|
||||
)
|
||||
|
||||
# No need to filter by search_space_id as connectors are user-owned, not search space specific
|
||||
|
||||
result = await session.execute(query.offset(skip).limit(limit))
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search source connectors: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to fetch search source connectors: {e!s}",
|
||||
) from e
|
||||
|
||||
@router.get("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
|
||||
|
||||
@router.get(
|
||||
"/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead
|
||||
)
|
||||
async def read_search_source_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get a specific search source connector by ID."""
|
||||
try:
|
||||
|
|
@ -149,31 +178,37 @@ async def read_search_source_connector(
|
|||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search source connector: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch search source connector: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.put("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
|
||||
|
||||
@router.put(
|
||||
"/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead
|
||||
)
|
||||
async def update_search_source_connector(
|
||||
connector_id: int,
|
||||
connector_update: SearchSourceConnectorUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update a search source connector.
|
||||
Handles partial updates, including merging changes into the 'config' field.
|
||||
"""
|
||||
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
|
||||
db_connector = await check_ownership(
|
||||
session, SearchSourceConnector, connector_id, user
|
||||
)
|
||||
|
||||
# Convert the sparse update data (only fields present in request) to a dict
|
||||
update_data = connector_update.model_dump(exclude_unset=True)
|
||||
|
||||
# Special handling for 'config' field
|
||||
if "config" in update_data:
|
||||
incoming_config = update_data["config"] # Config data from the request
|
||||
existing_config = db_connector.config if db_connector.config else {} # Current config from DB
|
||||
|
||||
incoming_config = update_data["config"] # Config data from the request
|
||||
existing_config = (
|
||||
db_connector.config if db_connector.config else {}
|
||||
) # Current config from DB
|
||||
|
||||
# Merge incoming config into existing config
|
||||
# This preserves existing keys (like GITHUB_PAT) if they are not in the incoming data
|
||||
merged_config = existing_config.copy()
|
||||
|
|
@ -182,26 +217,29 @@ async def update_search_source_connector(
|
|||
# -- Validation after merging --
|
||||
# Validate the *merged* config based on the connector type
|
||||
# We need the connector type - use the one from the update if provided, else the existing one
|
||||
current_connector_type = connector_update.connector_type if connector_update.connector_type is not None else db_connector.connector_type
|
||||
|
||||
current_connector_type = (
|
||||
connector_update.connector_type
|
||||
if connector_update.connector_type is not None
|
||||
else db_connector.connector_type
|
||||
)
|
||||
|
||||
try:
|
||||
# We can reuse the base validator by creating a temporary base model instance
|
||||
# Note: This assumes 'name' and 'is_indexable' are not crucial for config validation itself
|
||||
temp_data_for_validation = {
|
||||
"name": db_connector.name, # Use existing name
|
||||
"name": db_connector.name, # Use existing name
|
||||
"connector_type": current_connector_type,
|
||||
"is_indexable": db_connector.is_indexable, # Use existing value
|
||||
"last_indexed_at": db_connector.last_indexed_at, # Not used by validator
|
||||
"config": merged_config
|
||||
"is_indexable": db_connector.is_indexable, # Use existing value
|
||||
"last_indexed_at": db_connector.last_indexed_at, # Not used by validator
|
||||
"config": merged_config,
|
||||
}
|
||||
SearchSourceConnectorBase.model_validate(temp_data_for_validation)
|
||||
except ValidationError as e:
|
||||
# Raise specific validation error for the merged config
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Validation error for merged config: {str(e)}"
|
||||
)
|
||||
|
||||
status_code=422, detail=f"Validation error for merged config: {e!s}"
|
||||
) from e
|
||||
|
||||
# If validation passes, update the main update_data dict with the merged config
|
||||
update_data["config"] = merged_config
|
||||
|
||||
|
|
@ -210,20 +248,19 @@ async def update_search_source_connector(
|
|||
# Prevent changing connector_type if it causes a duplicate (check moved here)
|
||||
if key == "connector_type" and value != db_connector.connector_type:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type == value,
|
||||
SearchSourceConnector.id != connector_id
|
||||
SearchSourceConnector.id != connector_id,
|
||||
)
|
||||
)
|
||||
existing_connector = result.scalars().first()
|
||||
if existing_connector:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A connector with type {value} already exists. Each user can have only one connector of each type."
|
||||
detail=f"A connector with type {value} already exists. Each user can have only one connector of each type.",
|
||||
)
|
||||
|
||||
|
||||
setattr(db_connector, key, value)
|
||||
|
||||
try:
|
||||
|
|
@ -234,26 +271,31 @@ async def update_search_source_connector(
|
|||
await session.rollback()
|
||||
# This might occur if connector_type constraint is violated somehow after the check
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Database integrity error during update: {str(e)}"
|
||||
)
|
||||
status_code=409, detail=f"Database integrity error during update: {e!s}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to update search source connector {connector_id}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to update search source connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update search source connector: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to update search source connector: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
|
||||
async def delete_search_source_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete a search source connector."""
|
||||
try:
|
||||
db_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
db_connector = await check_ownership(
|
||||
session, SearchSourceConnector, connector_id, user
|
||||
)
|
||||
await session.delete(db_connector)
|
||||
await session.commit()
|
||||
return {"message": "Search source connector deleted successfully"}
|
||||
|
|
@ -263,48 +305,64 @@ async def delete_search_source_connector(
|
|||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete search source connector: {str(e)}"
|
||||
)
|
||||
detail=f"Failed to delete search source connector: {e!s}",
|
||||
) from e
|
||||
|
||||
@router.post("/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any])
|
||||
|
||||
@router.post(
|
||||
"/search-source-connectors/{connector_id}/index", response_model=dict[str, Any]
|
||||
)
|
||||
async def index_connector_content(
|
||||
connector_id: int,
|
||||
search_space_id: int = Query(..., description="ID of the search space to store indexed content"),
|
||||
start_date: str = Query(None, description="Start date for indexing (YYYY-MM-DD format). If not provided, uses last_indexed_at or defaults to 365 days ago"),
|
||||
end_date: str = Query(None, description="End date for indexing (YYYY-MM-DD format). If not provided, uses today's date"),
|
||||
search_space_id: int = Query(
|
||||
..., description="ID of the search space to store indexed content"
|
||||
),
|
||||
start_date: str = Query(
|
||||
None,
|
||||
description="Start date for indexing (YYYY-MM-DD format). If not provided, uses last_indexed_at or defaults to 365 days ago",
|
||||
),
|
||||
end_date: str = Query(
|
||||
None,
|
||||
description="End date for indexing (YYYY-MM-DD format). If not provided, uses today's date",
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
background_tasks: BackgroundTasks = None
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""
|
||||
Index content from a connector to a search space.
|
||||
|
||||
|
||||
Currently supports:
|
||||
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
|
||||
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
|
||||
- GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
|
||||
- LINEAR_CONNECTOR: Indexes issues and comments from Linear
|
||||
- JIRA_CONNECTOR: Indexes issues and comments from Jira
|
||||
- DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
|
||||
|
||||
|
||||
Args:
|
||||
connector_id: ID of the connector to use
|
||||
search_space_id: ID of the search space to store indexed content
|
||||
background_tasks: FastAPI background tasks
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with indexing status
|
||||
"""
|
||||
try:
|
||||
# Check if the connector belongs to the user
|
||||
connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
|
||||
|
||||
connector = await check_ownership(
|
||||
session, SearchSourceConnector, connector_id, user
|
||||
)
|
||||
|
||||
# Check if the search space belongs to the user
|
||||
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
|
||||
_search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
|
||||
# Handle different connector types
|
||||
response_message = ""
|
||||
today_str = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
# Determine the actual date range to use
|
||||
if start_date is None:
|
||||
# Use last_indexed_at or default to 365 days ago
|
||||
|
|
@ -316,110 +374,172 @@ async def index_connector_content(
|
|||
else:
|
||||
indexing_from = connector.last_indexed_at.strftime("%Y-%m-%d")
|
||||
else:
|
||||
indexing_from = (datetime.now() - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
indexing_from = (datetime.now() - timedelta(days=365)).strftime(
|
||||
"%Y-%m-%d"
|
||||
)
|
||||
else:
|
||||
indexing_from = start_date
|
||||
|
||||
if end_date is None:
|
||||
indexing_to = today_str
|
||||
else:
|
||||
indexing_to = end_date
|
||||
|
||||
indexing_to = end_date if end_date else today_str
|
||||
|
||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
# Run indexing in background
|
||||
logger.info(f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
|
||||
background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
|
||||
logger.info(
|
||||
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
background_tasks.add_task(
|
||||
run_slack_indexing_with_new_session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
indexing_from,
|
||||
indexing_to,
|
||||
)
|
||||
response_message = "Slack indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
# Run indexing in background
|
||||
logger.info(f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
|
||||
background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
|
||||
logger.info(
|
||||
f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
background_tasks.add_task(
|
||||
run_notion_indexing_with_new_session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
indexing_from,
|
||||
indexing_to,
|
||||
)
|
||||
response_message = "Notion indexing started in the background."
|
||||
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR:
|
||||
# Run indexing in background
|
||||
logger.info(f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
|
||||
background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
|
||||
logger.info(
|
||||
f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
background_tasks.add_task(
|
||||
run_github_indexing_with_new_session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
indexing_from,
|
||||
indexing_to,
|
||||
)
|
||||
response_message = "GitHub indexing started in the background."
|
||||
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
|
||||
# Run indexing in background
|
||||
logger.info(f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
|
||||
background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
|
||||
logger.info(
|
||||
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
background_tasks.add_task(
|
||||
run_linear_indexing_with_new_session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
indexing_from,
|
||||
indexing_to,
|
||||
)
|
||||
response_message = "Linear indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
|
||||
# Run indexing in background
|
||||
logger.info(
|
||||
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
background_tasks.add_task(
|
||||
run_jira_indexing_with_new_session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
indexing_from,
|
||||
indexing_to,
|
||||
)
|
||||
response_message = "Jira indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
||||
# Run indexing in background
|
||||
logger.info(
|
||||
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
background_tasks.add_task(
|
||||
run_discord_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
run_discord_indexing_with_new_session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
indexing_from,
|
||||
indexing_to,
|
||||
)
|
||||
response_message = "Discord indexing started in the background."
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Indexing not supported for connector type: {connector.connector_type}"
|
||||
detail=f"Indexing not supported for connector type: {connector.connector_type}",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": response_message,
|
||||
"connector_id": connector_id,
|
||||
"message": response_message,
|
||||
"connector_id": connector_id,
|
||||
"search_space_id": search_space_id,
|
||||
"indexing_from": indexing_from,
|
||||
"indexing_to": indexing_to
|
||||
"indexing_to": indexing_to,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate indexing for connector {connector_id}: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to initiate indexing: {str(e)}"
|
||||
logger.error(
|
||||
f"Failed to initiate indexing for connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def update_connector_last_indexed(
|
||||
session: AsyncSession,
|
||||
connector_id: int
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate indexing: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def update_connector_last_indexed(session: AsyncSession, connector_id: int):
|
||||
"""
|
||||
Update the last_indexed_at timestamp for a connector.
|
||||
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the connector to update
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(SearchSourceConnector.id == connector_id)
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
|
||||
if connector:
|
||||
connector.last_indexed_at = datetime.now()
|
||||
await session.commit()
|
||||
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update last_indexed_at for connector {connector_id}: {str(e)}")
|
||||
logger.error(
|
||||
f"Failed to update last_indexed_at for connector {connector_id}: {e!s}"
|
||||
)
|
||||
await session.rollback()
|
||||
|
||||
|
||||
async def run_slack_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Slack indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_slack_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
|
||||
await run_slack_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_slack_indexing(
|
||||
session: AsyncSession,
|
||||
|
|
@ -427,11 +547,11 @@ async def run_slack_indexing(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Slack indexing.
|
||||
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
|
|
@ -449,31 +569,39 @@ async def run_slack_indexing(
|
|||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
update_last_indexed=False # Don't update timestamp in the indexing function
|
||||
update_last_indexed=False, # Don't update timestamp in the indexing function
|
||||
)
|
||||
|
||||
|
||||
# Only update last_indexed_at if indexing was successful (either new docs or updated docs)
|
||||
if documents_processed > 0:
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
logger.info(f"Slack indexing completed successfully: {documents_processed} documents processed")
|
||||
logger.info(
|
||||
f"Slack indexing completed successfully: {documents_processed} documents processed"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Slack indexing failed or no documents processed: {error_or_warning}")
|
||||
logger.error(
|
||||
f"Slack indexing failed or no documents processed: {error_or_warning}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background Slack indexing task: {str(e)}")
|
||||
logger.error(f"Error in background Slack indexing task: {e!s}")
|
||||
|
||||
|
||||
async def run_notion_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Notion indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_notion_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
|
||||
await run_notion_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_notion_indexing(
|
||||
session: AsyncSession,
|
||||
|
|
@ -481,11 +609,11 @@ async def run_notion_indexing(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Notion indexing.
|
||||
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Notion connector
|
||||
|
|
@ -503,17 +631,22 @@ async def run_notion_indexing(
|
|||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
update_last_indexed=False # Don't update timestamp in the indexing function
|
||||
update_last_indexed=False, # Don't update timestamp in the indexing function
|
||||
)
|
||||
|
||||
|
||||
# Only update last_indexed_at if indexing was successful (either new docs or updated docs)
|
||||
if documents_processed > 0:
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
logger.info(f"Notion indexing completed successfully: {documents_processed} documents processed")
|
||||
logger.info(
|
||||
f"Notion indexing completed successfully: {documents_processed} documents processed"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Notion indexing failed or no documents processed: {error_or_warning}")
|
||||
logger.error(
|
||||
f"Notion indexing failed or no documents processed: {error_or_warning}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background Notion indexing task: {str(e)}")
|
||||
logger.error(f"Error in background Notion indexing task: {e!s}")
|
||||
|
||||
|
||||
# Add new helper functions for GitHub indexing
|
||||
async def run_github_indexing_with_new_session(
|
||||
|
|
@ -521,94 +654,135 @@ async def run_github_indexing_with_new_session(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run GitHub indexing with its own database session."""
|
||||
logger.info(f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
|
||||
logger.info(
|
||||
f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_github_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
|
||||
await run_github_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing GitHub connector {connector_id}")
|
||||
|
||||
|
||||
async def run_github_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""Runs the GitHub indexing task and updates the timestamp."""
|
||||
try:
|
||||
indexed_count, error_message = await index_github_repos(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
start_date,
|
||||
end_date,
|
||||
update_last_indexed=False,
|
||||
)
|
||||
if error_message:
|
||||
logger.error(f"GitHub indexing failed for connector {connector_id}: {error_message}")
|
||||
logger.error(
|
||||
f"GitHub indexing failed for connector {connector_id}: {error_message}"
|
||||
)
|
||||
# Optionally update status in DB to indicate failure
|
||||
else:
|
||||
logger.info(f"GitHub indexing successful for connector {connector_id}. Indexed {indexed_count} documents.")
|
||||
logger.info(
|
||||
f"GitHub indexing successful for connector {connector_id}. Indexed {indexed_count} documents."
|
||||
)
|
||||
# Update the last indexed timestamp only on success
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
await session.commit() # Commit timestamp update
|
||||
await session.commit() # Commit timestamp update
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Critical error in run_github_indexing for connector {connector_id}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Critical error in run_github_indexing for connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Optionally update status in DB to indicate failure
|
||||
|
||||
|
||||
# Add new helper functions for Linear indexing
|
||||
async def run_linear_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Linear indexing with its own database session."""
|
||||
logger.info(f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
|
||||
logger.info(
|
||||
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_linear_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
|
||||
await run_linear_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
|
||||
|
||||
|
||||
async def run_linear_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""Runs the Linear indexing task and updates the timestamp."""
|
||||
try:
|
||||
indexed_count, error_message = await index_linear_issues(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
start_date,
|
||||
end_date,
|
||||
update_last_indexed=False,
|
||||
)
|
||||
if error_message:
|
||||
logger.error(f"Linear indexing failed for connector {connector_id}: {error_message}")
|
||||
logger.error(
|
||||
f"Linear indexing failed for connector {connector_id}: {error_message}"
|
||||
)
|
||||
# Optionally update status in DB to indicate failure
|
||||
else:
|
||||
logger.info(f"Linear indexing successful for connector {connector_id}. Indexed {indexed_count} documents.")
|
||||
logger.info(
|
||||
f"Linear indexing successful for connector {connector_id}. Indexed {indexed_count} documents."
|
||||
)
|
||||
# Update the last indexed timestamp only on success
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
await session.commit() # Commit timestamp update
|
||||
await session.commit() # Commit timestamp update
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Critical error in run_linear_indexing for connector {connector_id}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Critical error in run_linear_indexing for connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Optionally update status in DB to indicate failure
|
||||
|
||||
|
||||
# Add new helper functions for discord indexing
|
||||
async def run_discord_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Discord indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_discord_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
|
||||
await run_discord_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_discord_indexing(
|
||||
session: AsyncSession,
|
||||
|
|
@ -616,7 +790,7 @@ async def run_discord_indexing(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Discord indexing.
|
||||
|
|
@ -637,14 +811,76 @@ async def run_discord_indexing(
|
|||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
update_last_indexed=False # Don't update timestamp in the indexing function
|
||||
update_last_indexed=False, # Don't update timestamp in the indexing function
|
||||
)
|
||||
|
||||
# Only update last_indexed_at if indexing was successful (either new docs or updated docs)
|
||||
if documents_processed > 0:
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
logger.info(f"Discord indexing completed successfully: {documents_processed} documents processed")
|
||||
logger.info(
|
||||
f"Discord indexing completed successfully: {documents_processed} documents processed"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Discord indexing failed or no documents processed: {error_or_warning}")
|
||||
logger.error(
|
||||
f"Discord indexing failed or no documents processed: {error_or_warning}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background Discord indexing task: {str(e)}")
|
||||
logger.error(f"Error in background Discord indexing task: {e!s}")
|
||||
|
||||
|
||||
# Add new helper functions for Jira indexing
|
||||
async def run_jira_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Jira indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_jira_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")
|
||||
|
||||
|
||||
async def run_jira_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Runs the Jira indexing task and updates the timestamp."""
|
||||
try:
|
||||
indexed_count, error_message = await index_jira_issues(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
start_date,
|
||||
end_date,
|
||||
update_last_indexed=False,
|
||||
)
|
||||
if error_message:
|
||||
logger.error(
|
||||
f"Jira indexing failed for connector {connector_id}: {error_message}"
|
||||
)
|
||||
# Optionally update status in DB to indicate failure
|
||||
else:
|
||||
logger.info(
|
||||
f"Jira indexing successful for connector {connector_id}. Indexed {indexed_count} documents."
|
||||
)
|
||||
# Update the last indexed timestamp only on success
|
||||
await update_connector_last_indexed(session, connector_id)
|
||||
await session.commit() # Commit timestamp update
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Critical error in run_jira_indexing for connector {connector_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Optionally update status in DB to indicate failure
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace
|
||||
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
|
||||
from app.db import SearchSpace, User, get_async_session
|
||||
from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from fastapi import HTTPException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/searchspaces/", response_model=SearchSpaceRead)
|
||||
async def create_search_space(
|
||||
search_space: SearchSpaceCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_search_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
|
||||
|
|
@ -27,16 +27,16 @@ async def create_search_space(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to create search space: {e!s}"
|
||||
) from e
|
||||
|
||||
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
|
||||
|
||||
@router.get("/searchspaces/", response_model=list[SearchSpaceRead])
|
||||
async def read_search_spaces(
|
||||
skip: int = 0,
|
||||
limit: int = 200,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -48,37 +48,41 @@ async def read_search_spaces(
|
|||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search spaces: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch search spaces: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||
async def read_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
return search_space
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to fetch search space: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
|
||||
async def update_search_space(
|
||||
search_space_id: int,
|
||||
search_space_update: SearchSpaceUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
db_search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
update_data = search_space_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_search_space, key, value)
|
||||
|
|
@ -90,18 +94,20 @@ async def update_search_space(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to update search space: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
|
||||
async def delete_search_space(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user)
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
try:
|
||||
db_search_space = await check_ownership(session, SearchSpace, search_space_id, user)
|
||||
db_search_space = await check_ownership(
|
||||
session, SearchSpace, search_space_id, user
|
||||
)
|
||||
await session.delete(db_search_space)
|
||||
await session.commit()
|
||||
return {"message": "Search space deleted successfully"}
|
||||
|
|
@ -110,6 +116,5 @@ async def delete_search_space(
|
|||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete search space: {str(e)}"
|
||||
)
|
||||
status_code=500, detail=f"Failed to delete search space: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -1,62 +1,78 @@
|
|||
from .base import TimestampModel, IDModel
|
||||
from .users import UserRead, UserCreate, UserUpdate
|
||||
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
from .base import IDModel, TimestampModel
|
||||
from .chats import AISDKChatRequest, ChatBase, ChatCreate, ChatRead, ChatUpdate
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||
from .documents import (
|
||||
ExtensionDocumentMetadata,
|
||||
ExtensionDocumentContent,
|
||||
DocumentBase,
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentUpdate,
|
||||
DocumentRead,
|
||||
ExtensionDocumentContent,
|
||||
ExtensionDocumentMetadata,
|
||||
)
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
|
||||
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
|
||||
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
||||
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
|
||||
from .logs import LogBase, LogCreate, LogUpdate, LogRead, LogFilter
|
||||
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
|
||||
from .podcasts import (
|
||||
PodcastBase,
|
||||
PodcastCreate,
|
||||
PodcastGenerateRequest,
|
||||
PodcastRead,
|
||||
PodcastUpdate,
|
||||
)
|
||||
from .search_source_connector import (
|
||||
SearchSourceConnectorBase,
|
||||
SearchSourceConnectorCreate,
|
||||
SearchSourceConnectorRead,
|
||||
SearchSourceConnectorUpdate,
|
||||
)
|
||||
from .search_space import (
|
||||
SearchSpaceBase,
|
||||
SearchSpaceCreate,
|
||||
SearchSpaceRead,
|
||||
SearchSpaceUpdate,
|
||||
)
|
||||
from .users import UserCreate, UserRead, UserUpdate
|
||||
|
||||
__all__ = [
|
||||
"AISDKChatRequest",
|
||||
"TimestampModel",
|
||||
"IDModel",
|
||||
"UserRead",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceRead",
|
||||
"ExtensionDocumentMetadata",
|
||||
"ExtensionDocumentContent",
|
||||
"DocumentBase",
|
||||
"DocumentsCreate",
|
||||
"DocumentUpdate",
|
||||
"DocumentRead",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkUpdate",
|
||||
"ChunkRead",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastUpdate",
|
||||
"PodcastRead",
|
||||
"PodcastGenerateRequest",
|
||||
"ChatBase",
|
||||
"ChatCreate",
|
||||
"ChatUpdate",
|
||||
"ChatRead",
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSourceConnectorRead",
|
||||
"ChatUpdate",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
"DocumentBase",
|
||||
"DocumentRead",
|
||||
"DocumentUpdate",
|
||||
"DocumentsCreate",
|
||||
"ExtensionDocumentContent",
|
||||
"ExtensionDocumentMetadata",
|
||||
"IDModel",
|
||||
"LLMConfigBase",
|
||||
"LLMConfigCreate",
|
||||
"LLMConfigUpdate",
|
||||
"LLMConfigRead",
|
||||
"LLMConfigUpdate",
|
||||
"LogBase",
|
||||
"LogCreate",
|
||||
"LogUpdate",
|
||||
"LogRead",
|
||||
"LogFilter",
|
||||
]
|
||||
"LogRead",
|
||||
"LogUpdate",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastGenerateRequest",
|
||||
"PodcastRead",
|
||||
"PodcastUpdate",
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorRead",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceRead",
|
||||
"SearchSpaceUpdate",
|
||||
"TimestampModel",
|
||||
"UserCreate",
|
||||
"UserRead",
|
||||
"UserUpdate",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class TimestampModel(BaseModel):
|
||||
created_at: datetime
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IDModel(BaseModel):
|
||||
id: int
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.db import ChatType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
|
@ -9,39 +10,43 @@ from .base import IDModel, TimestampModel
|
|||
class ChatBase(BaseModel):
|
||||
type: ChatType
|
||||
title: str
|
||||
initial_connectors: Optional[List[str]] = None
|
||||
messages: List[Any]
|
||||
initial_connectors: list[str] | None = None
|
||||
messages: list[Any]
|
||||
search_space_id: int
|
||||
|
||||
|
||||
|
||||
class ClientAttachment(BaseModel):
|
||||
name: str
|
||||
contentType: str
|
||||
content_type: str
|
||||
url: str
|
||||
|
||||
|
||||
class ToolInvocation(BaseModel):
|
||||
toolCallId: str
|
||||
toolName: str
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
args: dict
|
||||
result: dict
|
||||
|
||||
|
||||
|
||||
|
||||
# class ClientMessage(BaseModel):
|
||||
# role: str
|
||||
# content: str
|
||||
# experimental_attachments: Optional[List[ClientAttachment]] = None
|
||||
# toolInvocations: Optional[List[ToolInvocation]] = None
|
||||
|
||||
|
||||
|
||||
class AISDKChatRequest(BaseModel):
|
||||
messages: List[Any]
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
messages: list[Any]
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatUpdate(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatRead(ChatBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class ChunkBase(BaseModel):
|
||||
content: str
|
||||
document_id: int
|
||||
|
||||
|
||||
class ChunkCreate(ChunkBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChunkUpdate(ChunkBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChunkRead(ChunkBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from app.db import DocumentType
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.db import DocumentType
|
||||
|
||||
|
||||
class ExtensionDocumentMetadata(BaseModel):
|
||||
BrowsingSessionId: str
|
||||
VisitedWebPageURL: str
|
||||
|
|
@ -11,21 +13,28 @@ class ExtensionDocumentMetadata(BaseModel):
|
|||
VisitedWebPageReffererURL: str
|
||||
VisitedWebPageVisitDurationInMilliseconds: str
|
||||
|
||||
|
||||
class ExtensionDocumentContent(BaseModel):
|
||||
metadata: ExtensionDocumentMetadata
|
||||
pageContent: str
|
||||
pageContent: str # noqa: N815
|
||||
|
||||
|
||||
class DocumentBase(BaseModel):
|
||||
document_type: DocumentType
|
||||
content: List[ExtensionDocumentContent] | List[str] | str # Updated to allow string content
|
||||
content: (
|
||||
list[ExtensionDocumentContent] | list[str] | str
|
||||
) # Updated to allow string content
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class DocumentsCreate(DocumentBase):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentUpdate(DocumentBase):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentRead(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
|
|
@ -34,6 +43,5 @@ class DocumentRead(BaseModel):
|
|||
content: str # Changed to string to match frontend
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,34 +1,61 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class LLMConfigBase(BaseModel):
|
||||
name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration")
|
||||
name: str = Field(
|
||||
..., max_length=100, description="User-friendly name for the LLM configuration"
|
||||
)
|
||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
|
||||
model_name: str = Field(..., max_length=100, description="Model name without provider prefix")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., max_length=100, description="Model name without provider prefix"
|
||||
)
|
||||
api_key: str = Field(..., description="API key for the provider")
|
||||
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
|
||||
litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class LLMConfigCreate(LLMConfigBase):
|
||||
pass
|
||||
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration")
|
||||
provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type")
|
||||
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
|
||||
model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix")
|
||||
api_key: Optional[str] = Field(None, description="API key for the provider")
|
||||
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
|
||||
litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters")
|
||||
name: str | None = Field(
|
||||
None, max_length=100, description="User-friendly name for the LLM configuration"
|
||||
)
|
||||
provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
None, max_length=100, description="Model name without provider prefix"
|
||||
)
|
||||
api_key: str | None = Field(None, description="API key for the provider")
|
||||
api_base: str | None = Field(
|
||||
None, max_length=500, description="Optional API base URL"
|
||||
)
|
||||
litellm_params: dict[str, Any] | None = Field(
|
||||
None, description="Additional LiteLLM parameters"
|
||||
)
|
||||
|
||||
|
||||
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,37 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
from app.db import LogLevel, LogStatus
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class LogBase(BaseModel):
|
||||
level: LogLevel
|
||||
status: LogStatus
|
||||
message: str
|
||||
source: Optional[str] = None
|
||||
log_metadata: Optional[Dict[str, Any]] = None
|
||||
source: str | None = None
|
||||
log_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LogCreate(BaseModel):
|
||||
level: LogLevel
|
||||
status: LogStatus
|
||||
message: str
|
||||
source: Optional[str] = None
|
||||
log_metadata: Optional[Dict[str, Any]] = None
|
||||
source: str | None = None
|
||||
log_metadata: dict[str, Any] | None = None
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class LogUpdate(BaseModel):
|
||||
level: Optional[LogLevel] = None
|
||||
status: Optional[LogStatus] = None
|
||||
message: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
log_metadata: Optional[Dict[str, Any]] = None
|
||||
level: LogLevel | None = None
|
||||
status: LogStatus | None = None
|
||||
message: str | None = None
|
||||
source: str | None = None
|
||||
log_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LogRead(LogBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
|
|
@ -33,12 +40,13 @@ class LogRead(LogBase, IDModel, TimestampModel):
|
|||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class LogFilter(BaseModel):
|
||||
level: Optional[LogLevel] = None
|
||||
status: Optional[LogStatus] = None
|
||||
source: Optional[str] = None
|
||||
search_space_id: Optional[int] = None
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
class LogFilter(BaseModel):
|
||||
level: LogLevel | None = None
|
||||
status: LogStatus | None = None
|
||||
source: str | None = None
|
||||
search_space_id: int | None = None
|
||||
start_date: datetime | None = None
|
||||
end_date: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,31 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class PodcastBase(BaseModel):
|
||||
title: str
|
||||
podcast_transcript: List[Any]
|
||||
podcast_transcript: list[Any]
|
||||
file_location: str = ""
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class PodcastCreate(PodcastBase):
|
||||
pass
|
||||
|
||||
|
||||
class PodcastUpdate(PodcastBase):
|
||||
pass
|
||||
|
||||
|
||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PodcastGenerateRequest(BaseModel):
|
||||
type: Literal["DOCUMENT", "CHAT"]
|
||||
ids: List[int]
|
||||
ids: list[int]
|
||||
search_space_id: int
|
||||
podcast_title: str = "SurfSense Podcast"
|
||||
podcast_title: str = "SurfSense Podcast"
|
||||
|
|
|
|||
|
|
@ -1,120 +1,164 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, field_validator, ConfigDict
|
||||
from .base import IDModel, TimestampModel
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class SearchSourceConnectorBase(BaseModel):
|
||||
name: str
|
||||
connector_type: SearchSourceConnectorType
|
||||
is_indexable: bool
|
||||
last_indexed_at: Optional[datetime] = None
|
||||
config: Dict[str, Any]
|
||||
|
||||
@field_validator('config')
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any]
|
||||
|
||||
@field_validator("config")
|
||||
@classmethod
|
||||
def validate_config_for_connector_type(cls, config: Dict[str, Any], values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
connector_type = values.data.get('connector_type')
|
||||
|
||||
def validate_config_for_connector_type(
|
||||
cls, config: dict[str, Any], values: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
connector_type = values.data.get("connector_type")
|
||||
|
||||
if connector_type == SearchSourceConnectorType.SERPER_API:
|
||||
# For SERPER_API, only allow SERPER_API_KEY
|
||||
allowed_keys = ["SERPER_API_KEY"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"For SERPER_API connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the API key is not empty
|
||||
if not config.get("SERPER_API_KEY"):
|
||||
raise ValueError("SERPER_API_KEY cannot be empty")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.TAVILY_API:
|
||||
# For TAVILY_API, only allow TAVILY_API_KEY
|
||||
allowed_keys = ["TAVILY_API_KEY"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"For TAVILY_API connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the API key is not empty
|
||||
if not config.get("TAVILY_API_KEY"):
|
||||
raise ValueError("TAVILY_API_KEY cannot be empty")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.LINKUP_API:
|
||||
# For LINKUP_API, only allow LINKUP_API_KEY
|
||||
allowed_keys = ["LINKUP_API_KEY"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For LINKUP_API connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"For LINKUP_API connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the API key is not empty
|
||||
if not config.get("LINKUP_API_KEY"):
|
||||
raise ValueError("LINKUP_API_KEY cannot be empty")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
# For SLACK_CONNECTOR, only allow SLACK_BOT_TOKEN
|
||||
allowed_keys = ["SLACK_BOT_TOKEN"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
raise ValueError(
|
||||
f"For SLACK_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the bot token is not empty
|
||||
if not config.get("SLACK_BOT_TOKEN"):
|
||||
raise ValueError("SLACK_BOT_TOKEN cannot be empty")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
# For NOTION_CONNECTOR, only allow NOTION_INTEGRATION_TOKEN
|
||||
allowed_keys = ["NOTION_INTEGRATION_TOKEN"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"For NOTION_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the integration token is not empty
|
||||
if not config.get("NOTION_INTEGRATION_TOKEN"):
|
||||
raise ValueError("NOTION_INTEGRATION_TOKEN cannot be empty")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR:
|
||||
# For GITHUB_CONNECTOR, only allow GITHUB_PAT and repo_full_names
|
||||
allowed_keys = ["GITHUB_PAT", "repo_full_names"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For GITHUB_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"For GITHUB_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the token is not empty
|
||||
if not config.get("GITHUB_PAT"):
|
||||
raise ValueError("GITHUB_PAT cannot be empty")
|
||||
|
||||
|
||||
# Ensure the repo_full_names is present and is a non-empty list
|
||||
repo_full_names = config.get("repo_full_names")
|
||||
if not isinstance(repo_full_names, list) or not repo_full_names:
|
||||
raise ValueError("repo_full_names must be a non-empty list of strings")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
|
||||
# For LINEAR_CONNECTOR, only allow LINEAR_API_KEY
|
||||
allowed_keys = ["LINEAR_API_KEY"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For LINEAR_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"For LINEAR_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the token is not empty
|
||||
if not config.get("LINEAR_API_KEY"):
|
||||
raise ValueError("LINEAR_API_KEY cannot be empty")
|
||||
|
||||
|
||||
elif connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
||||
# For DISCORD_CONNECTOR, only allow DISCORD_BOT_TOKEN
|
||||
allowed_keys = ["DISCORD_BOT_TOKEN"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(f"For DISCORD_CONNECTOR connector type, config must only contain these keys: {allowed_keys}")
|
||||
raise ValueError(
|
||||
f"For DISCORD_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the bot token is not empty
|
||||
if not config.get("DISCORD_BOT_TOKEN"):
|
||||
raise ValueError("DISCORD_BOT_TOKEN cannot be empty")
|
||||
elif connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
|
||||
# For JIRA_CONNECTOR, require JIRA_EMAIL, JIRA_API_TOKEN and JIRA_BASE_URL
|
||||
allowed_keys = ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"]
|
||||
if set(config.keys()) != set(allowed_keys):
|
||||
raise ValueError(
|
||||
f"For JIRA_CONNECTOR connector type, config must only contain these keys: {allowed_keys}"
|
||||
)
|
||||
|
||||
# Ensure the email is not empty
|
||||
if not config.get("JIRA_EMAIL"):
|
||||
raise ValueError("JIRA_EMAIL cannot be empty")
|
||||
|
||||
# Ensure the API token is not empty
|
||||
if not config.get("JIRA_API_TOKEN"):
|
||||
raise ValueError("JIRA_API_TOKEN cannot be empty")
|
||||
|
||||
# Ensure the base URL is not empty
|
||||
if not config.get("JIRA_BASE_URL"):
|
||||
raise ValueError("JIRA_BASE_URL cannot be empty")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class SearchSourceConnectorCreate(SearchSourceConnectorBase):
|
||||
pass
|
||||
|
||||
|
||||
class SearchSourceConnectorUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
connector_type: Optional[SearchSourceConnectorType] = None
|
||||
is_indexable: Optional[bool] = None
|
||||
last_indexed_at: Optional[datetime] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
name: str | None = None
|
||||
connector_type: SearchSourceConnectorType | None = None
|
||||
is_indexable: bool | None = None
|
||||
last_indexed_at: datetime | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,27 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
||||
class SearchSpaceBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class SearchSpaceCreate(SearchSpaceBase):
|
||||
pass
|
||||
|
||||
|
||||
class SearchSpaceUpdate(SearchSpaceBase):
|
||||
pass
|
||||
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
import uuid
|
||||
|
||||
from fastapi_users import schemas
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
pass
|
||||
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
pass
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# Services package
|
||||
# Services package
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -5,15 +5,16 @@ SSL-safe implementation with pre-downloaded models
|
|||
"""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DoclingService:
|
||||
"""Docling service for enhanced document processing with SSL fixes."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Docling service with SSL, model fixes, and GPU acceleration."""
|
||||
self.converter = None
|
||||
|
|
@ -21,30 +22,32 @@ class DoclingService:
|
|||
self._configure_ssl_environment()
|
||||
self._check_wsl2_gpu_support()
|
||||
self._initialize_docling()
|
||||
|
||||
|
||||
def _configure_ssl_environment(self):
|
||||
"""Configure SSL environment for secure model downloads."""
|
||||
try:
|
||||
# Set SSL context for downloads
|
||||
ssl._create_default_https_context = ssl._create_unverified_context
|
||||
|
||||
|
||||
# Set SSL environment variables if not already set
|
||||
if not os.environ.get('SSL_CERT_FILE'):
|
||||
if not os.environ.get("SSL_CERT_FILE"):
|
||||
try:
|
||||
import certifi
|
||||
os.environ['SSL_CERT_FILE'] = certifi.where()
|
||||
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
|
||||
|
||||
os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
logger.info("🔐 SSL environment configured for model downloads")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ SSL configuration warning: {e}")
|
||||
|
||||
|
||||
def _check_wsl2_gpu_support(self):
|
||||
"""Check and configure GPU support for WSL2 environment."""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
|
||||
|
|
@ -60,34 +63,34 @@ class DoclingService:
|
|||
except Exception as e:
|
||||
logger.warning(f"⚠️ GPU detection failed: {e}, falling back to CPU")
|
||||
self.use_gpu = False
|
||||
|
||||
|
||||
def _initialize_docling(self):
|
||||
"""Initialize Docling with version-safe configuration."""
|
||||
try:
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
||||
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
|
||||
logger.info("🔧 Initializing Docling with version-safe configuration...")
|
||||
|
||||
|
||||
# Create pipeline options with version-safe attribute checking
|
||||
pipeline_options = PdfPipelineOptions()
|
||||
|
||||
|
||||
# Disable OCR (user request)
|
||||
if hasattr(pipeline_options, 'do_ocr'):
|
||||
if hasattr(pipeline_options, "do_ocr"):
|
||||
pipeline_options.do_ocr = False
|
||||
logger.info("⚠️ OCR disabled by user request")
|
||||
else:
|
||||
logger.warning("⚠️ OCR attribute not available in this Docling version")
|
||||
|
||||
|
||||
# Enable table structure if available
|
||||
if hasattr(pipeline_options, 'do_table_structure'):
|
||||
if hasattr(pipeline_options, "do_table_structure"):
|
||||
pipeline_options.do_table_structure = True
|
||||
logger.info("✅ Table structure detection enabled")
|
||||
|
||||
|
||||
# Configure GPU acceleration for WSL2 if available
|
||||
if hasattr(pipeline_options, 'accelerator_device'):
|
||||
if hasattr(pipeline_options, "accelerator_device"):
|
||||
if self.use_gpu:
|
||||
try:
|
||||
pipeline_options.accelerator_device = "cuda"
|
||||
|
|
@ -99,164 +102,180 @@ class DoclingService:
|
|||
pipeline_options.accelerator_device = "cpu"
|
||||
logger.info("🖥️ Using CPU acceleration")
|
||||
else:
|
||||
logger.info("ℹ️ Accelerator device attribute not available in this Docling version")
|
||||
|
||||
logger.info(
|
||||
"⚠️ Accelerator device attribute not available in this Docling version"
|
||||
)
|
||||
|
||||
# Create PDF format option with backend
|
||||
pdf_format_option = PdfFormatOption(
|
||||
pipeline_options=pipeline_options,
|
||||
backend=PyPdfiumDocumentBackend
|
||||
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
|
||||
)
|
||||
|
||||
|
||||
# Initialize DocumentConverter
|
||||
self.converter = DocumentConverter(
|
||||
format_options={
|
||||
InputFormat.PDF: pdf_format_option
|
||||
}
|
||||
format_options={InputFormat.PDF: pdf_format_option}
|
||||
)
|
||||
|
||||
|
||||
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
|
||||
logger.info(f"✅ Docling initialized successfully with {acceleration_type} acceleration")
|
||||
|
||||
logger.info(
|
||||
f"✅ Docling initialized successfully with {acceleration_type} acceleration"
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"❌ Docling not installed: {e}")
|
||||
raise RuntimeError(f"Docling not available: {e}")
|
||||
raise RuntimeError(f"Docling not available: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Docling initialization failed: {e}")
|
||||
raise RuntimeError(f"Docling initialization failed: {e}")
|
||||
|
||||
raise RuntimeError(f"Docling initialization failed: {e}") from e
|
||||
|
||||
def _configure_easyocr_local_models(self):
|
||||
"""Configure EasyOCR to use pre-downloaded local models."""
|
||||
try:
|
||||
import easyocr
|
||||
import os
|
||||
|
||||
|
||||
import easyocr
|
||||
|
||||
# Set SSL environment for EasyOCR downloads
|
||||
os.environ['CURL_CA_BUNDLE'] = ''
|
||||
os.environ['REQUESTS_CA_BUNDLE'] = ''
|
||||
|
||||
os.environ["CURL_CA_BUNDLE"] = ""
|
||||
os.environ["REQUESTS_CA_BUNDLE"] = ""
|
||||
|
||||
# Try to use local models first, fallback to download if needed
|
||||
try:
|
||||
reader = easyocr.Reader(['en'],
|
||||
download_enabled=False,
|
||||
model_storage_directory="/root/.EasyOCR/model")
|
||||
reader = easyocr.Reader(
|
||||
["en"],
|
||||
download_enabled=False,
|
||||
model_storage_directory="/root/.EasyOCR/model",
|
||||
)
|
||||
logger.info("✅ EasyOCR configured for local models")
|
||||
return reader
|
||||
except:
|
||||
except Exception:
|
||||
# If local models fail, allow download with SSL bypass
|
||||
logger.info("🔄 Local models failed, attempting download with SSL bypass...")
|
||||
reader = easyocr.Reader(['en'],
|
||||
download_enabled=True,
|
||||
model_storage_directory="/root/.EasyOCR/model")
|
||||
logger.info(
|
||||
"🔄 Local models failed, attempting download with SSL bypass..."
|
||||
)
|
||||
reader = easyocr.Reader(
|
||||
["en"],
|
||||
download_enabled=True,
|
||||
model_storage_directory="/root/.EasyOCR/model",
|
||||
)
|
||||
logger.info("✅ EasyOCR configured with downloaded models")
|
||||
return reader
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ EasyOCR configuration failed: {e}")
|
||||
return None
|
||||
|
||||
async def process_document(self, file_path: str, filename: str = None) -> Dict[str, Any]:
|
||||
|
||||
async def process_document(
|
||||
self, file_path: str, filename: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Process document with Docling using pre-downloaded models."""
|
||||
|
||||
|
||||
if self.converter is None:
|
||||
raise RuntimeError("Docling converter not initialized")
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"🔄 Processing {filename} with Docling (using local models)...")
|
||||
|
||||
logger.info(
|
||||
f"🔄 Processing {filename} with Docling (using local models)..."
|
||||
)
|
||||
|
||||
# Process document with local models
|
||||
result = self.converter.convert(file_path)
|
||||
|
||||
|
||||
# Extract content using version-safe methods
|
||||
content = None
|
||||
if hasattr(result, 'document') and result.document:
|
||||
if hasattr(result, "document") and result.document:
|
||||
# Try different export methods (version compatibility)
|
||||
if hasattr(result.document, 'export_to_markdown'):
|
||||
if hasattr(result.document, "export_to_markdown"):
|
||||
content = result.document.export_to_markdown()
|
||||
logger.info("📄 Used export_to_markdown method")
|
||||
elif hasattr(result.document, 'to_markdown'):
|
||||
elif hasattr(result.document, "to_markdown"):
|
||||
content = result.document.to_markdown()
|
||||
logger.info("📄 Used to_markdown method")
|
||||
elif hasattr(result.document, 'text'):
|
||||
elif hasattr(result.document, "text"):
|
||||
content = result.document.text
|
||||
logger.info("📄 Used text property")
|
||||
elif hasattr(result.document, '__str__'):
|
||||
elif hasattr(result.document, "__str__"):
|
||||
content = str(result.document)
|
||||
logger.info("📄 Used string conversion")
|
||||
|
||||
|
||||
if content:
|
||||
logger.info(f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)")
|
||||
|
||||
logger.info(
|
||||
f"✅ Docling SUCCESS - {filename}: {len(content)} chars (local models)"
|
||||
)
|
||||
|
||||
return {
|
||||
'content': content,
|
||||
'full_text': content,
|
||||
'service_used': 'docling',
|
||||
'status': 'success',
|
||||
'processing_notes': 'Processed with Docling using pre-downloaded models'
|
||||
"content": content,
|
||||
"full_text": content,
|
||||
"service_used": "docling",
|
||||
"status": "success",
|
||||
"processing_notes": "Processed with Docling using pre-downloaded models",
|
||||
}
|
||||
else:
|
||||
raise ValueError("No content could be extracted from document")
|
||||
else:
|
||||
raise ValueError("No document object returned by Docling")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Docling processing failed for {filename}: {e}")
|
||||
# Log the full error for debugging
|
||||
import traceback
|
||||
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
raise RuntimeError(f"Docling processing failed: {e}")
|
||||
|
||||
raise RuntimeError(f"Docling processing failed: {e}") from e
|
||||
|
||||
async def process_large_document_summary(
|
||||
self,
|
||||
content: str,
|
||||
llm,
|
||||
document_title: str = "Document"
|
||||
self, content: str, llm, document_title: str = "Document"
|
||||
) -> str:
|
||||
"""
|
||||
Process large documents using chunked LLM summarization.
|
||||
|
||||
|
||||
Args:
|
||||
content: The full document content
|
||||
llm: The language model to use for summarization
|
||||
document_title: Title of the document for context
|
||||
|
||||
|
||||
Returns:
|
||||
Final summary of the document
|
||||
"""
|
||||
# Large document threshold (100K characters ≈ 25K tokens)
|
||||
LARGE_DOCUMENT_THRESHOLD = 100_000
|
||||
|
||||
if len(content) <= LARGE_DOCUMENT_THRESHOLD:
|
||||
large_document_threshold = 100_000
|
||||
|
||||
if len(content) <= large_document_threshold:
|
||||
# For smaller documents, use direct processing
|
||||
logger.info(f"📄 Document size: {len(content)} chars - using direct processing")
|
||||
logger.info(
|
||||
f"📄 Document size: {len(content)} chars - using direct processing"
|
||||
)
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
||||
result = await summary_chain.ainvoke({"document": content})
|
||||
return result.content
|
||||
|
||||
logger.info(f"📚 Large document detected: {len(content)} chars - using chunked processing")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"📚 Large document detected: {len(content)} chars - using chunked processing"
|
||||
)
|
||||
|
||||
# Import chunker from config
|
||||
from app.config import config
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
# Create LLM-optimized chunks (8K tokens max for safety)
|
||||
from chonkie import RecursiveChunker, OverlapRefinery
|
||||
from chonkie import OverlapRefinery, RecursiveChunker
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
llm_chunker = RecursiveChunker(
|
||||
chunk_size=8000 # Conservative for most LLMs
|
||||
)
|
||||
|
||||
|
||||
# Apply overlap refinery for context preservation (10% overlap = 800 tokens)
|
||||
overlap_refinery = OverlapRefinery(
|
||||
context_size=0.1, # 10% overlap for context preservation
|
||||
method="suffix" # Add next chunk context to current chunk
|
||||
method="suffix", # Add next chunk context to current chunk
|
||||
)
|
||||
|
||||
|
||||
# First chunk the content, then apply overlap refinery
|
||||
initial_chunks = llm_chunker.chunk(content)
|
||||
chunks = overlap_refinery.refine(initial_chunks)
|
||||
total_chunks = len(chunks)
|
||||
|
||||
|
||||
logger.info(f"📄 Split into {total_chunks} chunks for LLM processing")
|
||||
|
||||
|
||||
# Template for chunk processing
|
||||
chunk_template = PromptTemplate(
|
||||
input_variables=["chunk", "chunk_number", "total_chunks"],
|
||||
|
|
@ -274,34 +293,38 @@ Chunk {chunk_number}/{total_chunks}:
|
|||
<document_chunk>
|
||||
{chunk}
|
||||
</document_chunk>
|
||||
</INSTRUCTIONS>"""
|
||||
</INSTRUCTIONS>""",
|
||||
)
|
||||
|
||||
|
||||
# Process each chunk individually
|
||||
chunk_summaries = []
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
try:
|
||||
logger.info(f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)")
|
||||
|
||||
logger.info(
|
||||
f"🔄 Processing chunk {i}/{total_chunks} ({len(chunk.text)} chars)"
|
||||
)
|
||||
|
||||
chunk_chain = chunk_template | llm
|
||||
chunk_result = await chunk_chain.ainvoke({
|
||||
"chunk": chunk.text,
|
||||
"chunk_number": i,
|
||||
"total_chunks": total_chunks
|
||||
})
|
||||
|
||||
chunk_result = await chunk_chain.ainvoke(
|
||||
{
|
||||
"chunk": chunk.text,
|
||||
"chunk_number": i,
|
||||
"total_chunks": total_chunks,
|
||||
}
|
||||
)
|
||||
|
||||
chunk_summary = chunk_result.content
|
||||
chunk_summaries.append(f"=== Section {i} ===\n{chunk_summary}")
|
||||
|
||||
|
||||
logger.info(f"✅ Completed chunk {i}/{total_chunks}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to process chunk {i}/{total_chunks}: {e}")
|
||||
chunk_summaries.append(f"=== Section {i} ===\n[Processing failed]")
|
||||
|
||||
|
||||
# Combine summaries into final document summary
|
||||
logger.info(f"🔄 Combining {len(chunk_summaries)} chunk summaries")
|
||||
|
||||
|
||||
try:
|
||||
combine_template = PromptTemplate(
|
||||
input_variables=["summaries", "document_title"],
|
||||
|
|
@ -318,22 +341,23 @@ Ensure:
|
|||
<section_summaries>
|
||||
{summaries}
|
||||
</section_summaries>
|
||||
</INSTRUCTIONS>"""
|
||||
</INSTRUCTIONS>""",
|
||||
)
|
||||
|
||||
|
||||
combined_summaries = "\n\n".join(chunk_summaries)
|
||||
combine_chain = combine_template | llm
|
||||
|
||||
final_result = await combine_chain.ainvoke({
|
||||
"summaries": combined_summaries,
|
||||
"document_title": document_title
|
||||
})
|
||||
|
||||
|
||||
final_result = await combine_chain.ainvoke(
|
||||
{"summaries": combined_summaries, "document_title": document_title}
|
||||
)
|
||||
|
||||
final_summary = final_result.content
|
||||
logger.info(f"✅ Large document processing complete: {len(final_summary)} chars summary")
|
||||
|
||||
logger.info(
|
||||
f"✅ Large document processing complete: {len(final_summary)} chars summary"
|
||||
)
|
||||
|
||||
return final_summary
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to combine summaries: {e}")
|
||||
# Fallback: return concatenated chunk summaries
|
||||
|
|
@ -341,6 +365,7 @@ Ensure:
|
|||
logger.warning("⚠️ Using fallback combined summary")
|
||||
return fallback_summary
|
||||
|
||||
|
||||
def create_docling_service() -> DoclingService:
|
||||
"""Create a Docling service instance."""
|
||||
return DoclingService()
|
||||
return DoclingService()
|
||||
|
|
|
|||
|
|
@ -1,45 +1,43 @@
|
|||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
import logging
|
||||
|
||||
from app.db import User, LLMConfig
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import LLMConfig, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMRole:
|
||||
LONG_CONTEXT = "long_context"
|
||||
FAST = "fast"
|
||||
STRATEGIC = "strategic"
|
||||
|
||||
|
||||
async def get_user_llm_instance(
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
role: str
|
||||
) -> Optional[ChatLiteLLM]:
|
||||
session: AsyncSession, user_id: str, role: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""
|
||||
Get a ChatLiteLLM instance for a specific user and role.
|
||||
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
role: LLM role ('long_context', 'fast', or 'strategic')
|
||||
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None if not found
|
||||
"""
|
||||
try:
|
||||
# Get user with their LLM preferences
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
|
||||
|
||||
if not user:
|
||||
logger.error(f"User {user_id} not found")
|
||||
return None
|
||||
|
||||
|
||||
# Get the appropriate LLM config ID based on role
|
||||
llm_config_id = None
|
||||
if role == LLMRole.LONG_CONTEXT:
|
||||
|
|
@ -51,24 +49,23 @@ async def get_user_llm_instance(
|
|||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
||||
|
||||
if not llm_config_id:
|
||||
logger.error(f"No {role} LLM configured for user {user_id}")
|
||||
return None
|
||||
|
||||
|
||||
# Get the LLM configuration
|
||||
result = await session.execute(
|
||||
select(LLMConfig).where(
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.user_id == user_id
|
||||
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
|
||||
|
||||
if not llm_config:
|
||||
logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
|
||||
return None
|
||||
|
||||
|
||||
# Build the model string for litellm
|
||||
if llm_config.custom_provider:
|
||||
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
|
||||
|
|
@ -76,7 +73,7 @@ async def get_user_llm_instance(
|
|||
# Map provider enum to litellm format
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
|
|
@ -84,37 +81,48 @@ async def get_user_llm_instance(
|
|||
"MISTRAL": "mistral",
|
||||
# Add more mappings as needed
|
||||
}
|
||||
provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
|
||||
provider_prefix = provider_map.get(
|
||||
llm_config.provider.value, llm_config.provider.value.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{llm_config.model_name}"
|
||||
|
||||
|
||||
# Create ChatLiteLLM instance
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": llm_config.api_key,
|
||||
}
|
||||
|
||||
|
||||
# Add optional parameters
|
||||
if llm_config.api_base:
|
||||
litellm_kwargs["api_base"] = llm_config.api_base
|
||||
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if llm_config.litellm_params:
|
||||
litellm_kwargs.update(llm_config.litellm_params)
|
||||
|
||||
|
||||
return ChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting LLM instance for user {user_id}, role {role}: {e!s}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
||||
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession, user_id: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's long context LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
|
||||
|
||||
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
||||
|
||||
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
|
||||
"""Get user's fast LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
|
||||
|
||||
async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
|
||||
|
||||
async def get_user_strategic_llm(
|
||||
session: AsyncSession, user_id: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's strategic LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import datetime
|
||||
from langchain.schema import HumanMessage, SystemMessage, AIMessage
|
||||
from app.config import config
|
||||
from app.services.llm_service import get_user_strategic_llm
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from app.services.llm_service import get_user_strategic_llm
|
||||
|
||||
|
||||
class QueryService:
|
||||
|
|
@ -13,13 +14,13 @@ class QueryService:
|
|||
|
||||
@staticmethod
|
||||
async def reformulate_query_with_chat_history(
|
||||
user_query: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
chat_history_str: Optional[str] = None
|
||||
user_query: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
chat_history_str: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Reformulate the user query using the user's strategic LLM to make it more
|
||||
Reformulate the user query using the user's strategic LLM to make it more
|
||||
effective for information retrieval and research purposes.
|
||||
|
||||
Args:
|
||||
|
|
@ -38,7 +39,9 @@ class QueryService:
|
|||
# Get the user's strategic LLM instance
|
||||
llm = await get_user_strategic_llm(session, user_id)
|
||||
if not llm:
|
||||
print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.")
|
||||
print(
|
||||
f"Warning: No strategic LLM configured for user {user_id}. Using original query."
|
||||
)
|
||||
return user_query
|
||||
|
||||
# Create system message with instructions
|
||||
|
|
@ -92,14 +95,13 @@ class QueryService:
|
|||
print(f"Error reformulating query: {e}")
|
||||
return user_query
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def langchain_chat_history_to_str(chat_history: List[Any]) -> str:
|
||||
async def langchain_chat_history_to_str(chat_history: list[Any]) -> str:
|
||||
"""
|
||||
Convert a list of chat history messages to a string.
|
||||
"""
|
||||
chat_history_str = "<chat_history>\n"
|
||||
|
||||
|
||||
for chat_message in chat_history:
|
||||
if isinstance(chat_message, HumanMessage):
|
||||
chat_history_str += f"<user>{chat_message.content}</user>\n"
|
||||
|
|
@ -107,6 +109,6 @@ class QueryService:
|
|||
chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
|
||||
elif isinstance(chat_message, SystemMessage):
|
||||
chat_history_str += f"<system>{chat_message.content}</system>\n"
|
||||
|
||||
|
||||
chat_history_str += "</chat_history>"
|
||||
return chat_history_str
|
||||
|
|
|
|||
|
|
@ -1,35 +1,39 @@
|
|||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from rerankers import Document as RerankerDocument
|
||||
|
||||
|
||||
class RerankerService:
|
||||
"""
|
||||
Service for reranking documents using a configured reranker
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, reranker_instance=None):
|
||||
"""
|
||||
Initialize the reranker service
|
||||
|
||||
|
||||
Args:
|
||||
reranker_instance: The reranker instance to use for reranking
|
||||
"""
|
||||
self.reranker_instance = reranker_instance
|
||||
|
||||
def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
|
||||
def rerank_documents(
|
||||
self, query_text: str, documents: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Rerank documents using the configured reranker
|
||||
|
||||
|
||||
Args:
|
||||
query_text: The query text to use for reranking
|
||||
documents: List of document dictionaries to rerank
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Reranked documents
|
||||
"""
|
||||
if not self.reranker_instance or not documents:
|
||||
return documents
|
||||
|
||||
|
||||
try:
|
||||
# Create Document objects for the rerankers library
|
||||
reranker_docs = []
|
||||
|
|
@ -38,58 +42,63 @@ class RerankerService:
|
|||
content = doc.get("content", "")
|
||||
score = doc.get("score", 0.0)
|
||||
document_info = doc.get("document", {})
|
||||
|
||||
|
||||
reranker_docs.append(
|
||||
RerankerDocument(
|
||||
text=content,
|
||||
doc_id=chunk_id,
|
||||
metadata={
|
||||
'document_id': document_info.get("id", ""),
|
||||
'document_title': document_info.get("title", ""),
|
||||
'document_type': document_info.get("document_type", ""),
|
||||
'rrf_score': score
|
||||
}
|
||||
"document_id": document_info.get("id", ""),
|
||||
"document_title": document_info.get("title", ""),
|
||||
"document_type": document_info.get("document_type", ""),
|
||||
"rrf_score": score,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Rerank using the configured reranker
|
||||
reranking_results = self.reranker_instance.rank(
|
||||
query=query_text,
|
||||
docs=reranker_docs
|
||||
query=query_text, docs=reranker_docs
|
||||
)
|
||||
|
||||
|
||||
# Process the results from the reranker
|
||||
# Convert to serializable dictionaries
|
||||
serialized_results = []
|
||||
for result in reranking_results.results:
|
||||
# Find the original document by id
|
||||
original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
|
||||
original_doc = next(
|
||||
(
|
||||
doc
|
||||
for doc in documents
|
||||
if doc.get("chunk_id") == result.document.doc_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if original_doc:
|
||||
# Create a new document with the reranked score
|
||||
reranked_doc = original_doc.copy()
|
||||
reranked_doc["score"] = float(result.score)
|
||||
reranked_doc["rank"] = result.rank
|
||||
serialized_results.append(reranked_doc)
|
||||
|
||||
|
||||
return serialized_results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Log the error
|
||||
logging.error(f"Error during reranking: {str(e)}")
|
||||
logging.error(f"Error during reranking: {e!s}")
|
||||
# Fall back to original documents without reranking
|
||||
return documents
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_reranker_instance() -> Optional['RerankerService']:
|
||||
def get_reranker_instance() -> Optional["RerankerService"]:
|
||||
"""
|
||||
Get a reranker service instance from the global configuration.
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[RerankerService]: A reranker service instance if configured, None otherwise
|
||||
"""
|
||||
from app.config import config
|
||||
|
||||
if hasattr(config, 'reranker_instance') and config.reranker_instance:
|
||||
|
||||
if hasattr(config, "reranker_instance") and config.reranker_instance:
|
||||
return RerankerService(config.reranker_instance)
|
||||
return None
|
||||
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StreamingService:
|
||||
|
|
@ -46,7 +46,7 @@ class StreamingService:
|
|||
annotation = {"type": "TERMINAL_INFO", "data": message}
|
||||
return f"8:[{json.dumps(annotation)}]\n"
|
||||
|
||||
def format_sources_delta(self, sources: List[Dict[str, Any]]) -> str:
|
||||
def format_sources_delta(self, sources: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format sources as a delta annotation
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ class StreamingService:
|
|||
annotation = {"type": "ANSWER", "content": [answer_chunk]}
|
||||
return f"8:[{json.dumps(annotation)}]\n"
|
||||
|
||||
def format_answer_annotation(self, answer_lines: List[str]) -> str:
|
||||
def format_answer_annotation(self, answer_lines: list[str]) -> str:
|
||||
"""
|
||||
Format the complete answer as a replacement annotation
|
||||
|
||||
|
|
@ -117,7 +117,7 @@ class StreamingService:
|
|||
return f"8:[{json.dumps(annotation)}]\n"
|
||||
|
||||
def format_further_questions_delta(
|
||||
self, further_questions: List[Dict[str, Any]]
|
||||
self, further_questions: list[dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
Format further questions as a delta annotation
|
||||
|
|
|
|||
|
|
@ -1,111 +1,116 @@
|
|||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db import Log, LogLevel, LogStatus
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Log, LogLevel, LogStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskLoggingService:
|
||||
"""Service for logging background tasks using the database Log model"""
|
||||
|
||||
|
||||
def __init__(self, session: AsyncSession, search_space_id: int):
|
||||
self.session = session
|
||||
self.search_space_id = search_space_id
|
||||
|
||||
|
||||
async def log_task_start(
|
||||
self,
|
||||
task_name: str,
|
||||
source: str,
|
||||
message: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Log:
|
||||
"""
|
||||
Log the start of a task with IN_PROGRESS status
|
||||
|
||||
|
||||
Args:
|
||||
task_name: Name/identifier of the task
|
||||
source: Source service/component (e.g., 'document_processor', 'slack_indexer')
|
||||
message: Human-readable message about the task
|
||||
metadata: Additional context data
|
||||
|
||||
|
||||
Returns:
|
||||
Log: The created log entry
|
||||
"""
|
||||
log_metadata = metadata or {}
|
||||
log_metadata.update({
|
||||
"task_name": task_name,
|
||||
"started_at": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
log_metadata.update(
|
||||
{"task_name": task_name, "started_at": datetime.utcnow().isoformat()}
|
||||
)
|
||||
|
||||
log_entry = Log(
|
||||
level=LogLevel.INFO,
|
||||
status=LogStatus.IN_PROGRESS,
|
||||
message=message,
|
||||
source=source,
|
||||
log_metadata=log_metadata,
|
||||
search_space_id=self.search_space_id
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
|
||||
|
||||
self.session.add(log_entry)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log_entry)
|
||||
|
||||
|
||||
logger.info(f"Started task {task_name}: {message}")
|
||||
return log_entry
|
||||
|
||||
|
||||
async def log_task_success(
|
||||
self,
|
||||
log_entry: Log,
|
||||
message: str,
|
||||
additional_metadata: Optional[Dict[str, Any]] = None
|
||||
additional_metadata: dict[str, Any] | None = None,
|
||||
) -> Log:
|
||||
"""
|
||||
Update a log entry to SUCCESS status
|
||||
|
||||
|
||||
Args:
|
||||
log_entry: The original log entry to update
|
||||
message: Success message
|
||||
additional_metadata: Additional metadata to merge
|
||||
|
||||
|
||||
Returns:
|
||||
Log: The updated log entry
|
||||
"""
|
||||
# Update the existing log entry
|
||||
log_entry.status = LogStatus.SUCCESS
|
||||
log_entry.message = message
|
||||
|
||||
|
||||
# Merge additional metadata
|
||||
if additional_metadata:
|
||||
if log_entry.log_metadata is None:
|
||||
log_entry.log_metadata = {}
|
||||
log_entry.log_metadata.update(additional_metadata)
|
||||
log_entry.log_metadata["completed_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log_entry)
|
||||
|
||||
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
|
||||
|
||||
task_name = (
|
||||
log_entry.log_metadata.get("task_name", "unknown")
|
||||
if log_entry.log_metadata
|
||||
else "unknown"
|
||||
)
|
||||
logger.info(f"Completed task {task_name}: {message}")
|
||||
return log_entry
|
||||
|
||||
|
||||
async def log_task_failure(
|
||||
self,
|
||||
log_entry: Log,
|
||||
error_message: str,
|
||||
error_details: Optional[str] = None,
|
||||
additional_metadata: Optional[Dict[str, Any]] = None
|
||||
error_details: str | None = None,
|
||||
additional_metadata: dict[str, Any] | None = None,
|
||||
) -> Log:
|
||||
"""
|
||||
Update a log entry to FAILED status
|
||||
|
||||
|
||||
Args:
|
||||
log_entry: The original log entry to update
|
||||
error_message: Error message
|
||||
error_details: Detailed error information
|
||||
additional_metadata: Additional metadata to merge
|
||||
|
||||
|
||||
Returns:
|
||||
Log: The updated log entry
|
||||
"""
|
||||
|
|
@ -113,77 +118,86 @@ class TaskLoggingService:
|
|||
log_entry.status = LogStatus.FAILED
|
||||
log_entry.level = LogLevel.ERROR
|
||||
log_entry.message = error_message
|
||||
|
||||
|
||||
# Merge additional metadata
|
||||
if log_entry.log_metadata is None:
|
||||
log_entry.log_metadata = {}
|
||||
|
||||
log_entry.log_metadata.update({
|
||||
"failed_at": datetime.utcnow().isoformat(),
|
||||
"error_details": error_details
|
||||
})
|
||||
|
||||
|
||||
log_entry.log_metadata.update(
|
||||
{"failed_at": datetime.utcnow().isoformat(), "error_details": error_details}
|
||||
)
|
||||
|
||||
if additional_metadata:
|
||||
log_entry.log_metadata.update(additional_metadata)
|
||||
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log_entry)
|
||||
|
||||
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
|
||||
|
||||
task_name = (
|
||||
log_entry.log_metadata.get("task_name", "unknown")
|
||||
if log_entry.log_metadata
|
||||
else "unknown"
|
||||
)
|
||||
logger.error(f"Failed task {task_name}: {error_message}")
|
||||
if error_details:
|
||||
logger.error(f"Error details: {error_details}")
|
||||
|
||||
|
||||
return log_entry
|
||||
|
||||
|
||||
async def log_task_progress(
|
||||
self,
|
||||
log_entry: Log,
|
||||
progress_message: str,
|
||||
progress_metadata: Optional[Dict[str, Any]] = None
|
||||
progress_metadata: dict[str, Any] | None = None,
|
||||
) -> Log:
|
||||
"""
|
||||
Update a log entry with progress information while keeping IN_PROGRESS status
|
||||
|
||||
|
||||
Args:
|
||||
log_entry: The log entry to update
|
||||
progress_message: Progress update message
|
||||
progress_metadata: Additional progress metadata
|
||||
|
||||
|
||||
Returns:
|
||||
Log: The updated log entry
|
||||
"""
|
||||
log_entry.message = progress_message
|
||||
|
||||
|
||||
if progress_metadata:
|
||||
if log_entry.log_metadata is None:
|
||||
log_entry.log_metadata = {}
|
||||
log_entry.log_metadata.update(progress_metadata)
|
||||
log_entry.log_metadata["last_progress_update"] = datetime.utcnow().isoformat()
|
||||
|
||||
log_entry.log_metadata["last_progress_update"] = (
|
||||
datetime.utcnow().isoformat()
|
||||
)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log_entry)
|
||||
|
||||
task_name = log_entry.log_metadata.get("task_name", "unknown") if log_entry.log_metadata else "unknown"
|
||||
|
||||
task_name = (
|
||||
log_entry.log_metadata.get("task_name", "unknown")
|
||||
if log_entry.log_metadata
|
||||
else "unknown"
|
||||
)
|
||||
logger.info(f"Progress update for task {task_name}: {progress_message}")
|
||||
return log_entry
|
||||
|
||||
|
||||
async def log_simple_event(
|
||||
self,
|
||||
level: LogLevel,
|
||||
source: str,
|
||||
message: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Log:
|
||||
"""
|
||||
Log a simple event (not a long-running task)
|
||||
|
||||
|
||||
Args:
|
||||
level: Log level
|
||||
source: Source service/component
|
||||
message: Log message
|
||||
metadata: Additional context data
|
||||
|
||||
|
||||
Returns:
|
||||
Log: The created log entry
|
||||
"""
|
||||
|
|
@ -193,12 +207,12 @@ class TaskLoggingService:
|
|||
message=message,
|
||||
source=source,
|
||||
log_metadata=metadata or {},
|
||||
search_space_id=self.search_space_id
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
|
||||
|
||||
self.session.add(log_entry)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log_entry)
|
||||
|
||||
|
||||
logger.info(f"Logged event from {source}: {message}")
|
||||
return log_entry
|
||||
return log_entry
|
||||
|
|
|
|||
|
|
@ -1,46 +1,49 @@
|
|||
from typing import Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import logging
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import aiohttp
|
||||
import validators
|
||||
from langchain_community.document_loaders import AsyncChromiumLoader, FireCrawlLoader
|
||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from app.db import Document, DocumentType, Chunk
|
||||
from app.schemas import ExtensionDocumentContent
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.utils.document_converters import convert_document_to_markdown, generate_content_hash
|
||||
from app.schemas import ExtensionDocumentContent
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
|
||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||
import validators
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import aiohttp
|
||||
import logging
|
||||
from app.utils.document_converters import (
|
||||
convert_document_to_markdown,
|
||||
generate_content_hash,
|
||||
)
|
||||
|
||||
md = MarkdownifyTransformer()
|
||||
|
||||
|
||||
async def add_crawled_url_document(
|
||||
session: AsyncSession, url: str, search_space_id: int, user_id: str
|
||||
) -> Optional[Document]:
|
||||
) -> Document | None:
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="crawl_url_document",
|
||||
source="background_task",
|
||||
message=f"Starting URL crawling process for: {url}",
|
||||
metadata={"url": url, "user_id": str(user_id)}
|
||||
metadata={"url": url, "user_id": str(user_id)},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# URL validation step
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Validating URL: {url}",
|
||||
{"stage": "validation"}
|
||||
log_entry, f"Validating URL: {url}", {"stage": "validation"}
|
||||
)
|
||||
|
||||
|
||||
if not validators.url(url):
|
||||
raise ValueError(f"Url {url} is not a valid URL address")
|
||||
|
||||
|
|
@ -48,7 +51,10 @@ async def add_crawled_url_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Setting up crawler for URL: {url}",
|
||||
{"stage": "crawler_setup", "firecrawl_available": bool(config.FIRECRAWL_API_KEY)}
|
||||
{
|
||||
"stage": "crawler_setup",
|
||||
"firecrawl_available": bool(config.FIRECRAWL_API_KEY),
|
||||
},
|
||||
)
|
||||
|
||||
if config.FIRECRAWL_API_KEY:
|
||||
|
|
@ -68,21 +74,21 @@ async def add_crawled_url_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Crawling URL content: {url}",
|
||||
{"stage": "crawling", "crawler_type": type(crawl_loader).__name__}
|
||||
{"stage": "crawling", "crawler_type": type(crawl_loader).__name__},
|
||||
)
|
||||
|
||||
url_crawled = await crawl_loader.aload()
|
||||
|
||||
if type(crawl_loader) == FireCrawlLoader:
|
||||
if isinstance(crawl_loader, FireCrawlLoader):
|
||||
content_in_markdown = url_crawled[0].page_content
|
||||
elif type(crawl_loader) == AsyncChromiumLoader:
|
||||
elif isinstance(crawl_loader, AsyncChromiumLoader):
|
||||
content_in_markdown = md.transform_documents(url_crawled)[0].page_content
|
||||
|
||||
# Format document
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing crawled content from: {url}",
|
||||
{"stage": "content_processing", "content_length": len(content_in_markdown)}
|
||||
{"stage": "content_processing", "content_length": len(content_in_markdown)},
|
||||
)
|
||||
|
||||
# Format document metadata in a more maintainable way
|
||||
|
|
@ -117,7 +123,7 @@ async def add_crawled_url_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Checking for duplicate content: {url}",
|
||||
{"stage": "duplicate_check", "content_hash": content_hash}
|
||||
{"stage": "duplicate_check", "content_hash": content_hash},
|
||||
)
|
||||
|
||||
# Check if document with this content hash already exists
|
||||
|
|
@ -125,21 +131,26 @@ async def add_crawled_url_document(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Document already exists for URL: {url}",
|
||||
{"duplicate_detected": True, "existing_document_id": existing_document.id}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"existing_document_id": existing_document.id,
|
||||
},
|
||||
)
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
return existing_document
|
||||
|
||||
# Get LLM for summary generation
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Preparing for summary generation: {url}",
|
||||
{"stage": "llm_setup"}
|
||||
{"stage": "llm_setup"},
|
||||
)
|
||||
|
||||
# Get user's long context LLM
|
||||
|
|
@ -151,7 +162,7 @@ async def add_crawled_url_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Generating summary for URL content: {url}",
|
||||
{"stage": "summary_generation"}
|
||||
{"stage": "summary_generation"},
|
||||
)
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
|
||||
|
|
@ -165,7 +176,7 @@ async def add_crawled_url_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing content chunks for URL: {url}",
|
||||
{"stage": "chunk_processing"}
|
||||
{"stage": "chunk_processing"},
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -180,13 +191,13 @@ async def add_crawled_url_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Creating document in database for URL: {url}",
|
||||
{"stage": "document_creation", "chunks_count": len(chunks)}
|
||||
{"stage": "document_creation", "chunks_count": len(chunks)},
|
||||
)
|
||||
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=url_crawled[0].metadata["title"]
|
||||
if type(crawl_loader) == FireCrawlLoader
|
||||
if isinstance(crawl_loader, FireCrawlLoader)
|
||||
else url_crawled[0].metadata["source"],
|
||||
document_type=DocumentType.CRAWLED_URL,
|
||||
document_metadata=url_crawled[0].metadata,
|
||||
|
|
@ -209,8 +220,8 @@ async def add_crawled_url_document(
|
|||
"title": document.title,
|
||||
"content_hash": content_hash,
|
||||
"chunks_count": len(chunks),
|
||||
"summary_length": len(summary_content)
|
||||
}
|
||||
"summary_length": len(summary_content),
|
||||
},
|
||||
)
|
||||
|
||||
return document
|
||||
|
|
@ -221,7 +232,7 @@ async def add_crawled_url_document(
|
|||
log_entry,
|
||||
f"Database error while processing URL: {url}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError"}
|
||||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
|
|
@ -230,14 +241,17 @@ async def add_crawled_url_document(
|
|||
log_entry,
|
||||
f"Failed to crawl URL: {url}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
raise RuntimeError(f"Failed to crawl URL: {str(e)}")
|
||||
raise RuntimeError(f"Failed to crawl URL: {e!s}") from e
|
||||
|
||||
|
||||
async def add_extension_received_document(
|
||||
session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str
|
||||
) -> Optional[Document]:
|
||||
session: AsyncSession,
|
||||
content: ExtensionDocumentContent,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content received from the SurfSense Extension.
|
||||
|
||||
|
|
@ -250,7 +264,7 @@ async def add_extension_received_document(
|
|||
Document object if successful, None if failed
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="extension_document",
|
||||
|
|
@ -259,10 +273,10 @@ async def add_extension_received_document(
|
|||
metadata={
|
||||
"url": content.metadata.VisitedWebPageURL,
|
||||
"title": content.metadata.VisitedWebPageTitle,
|
||||
"user_id": str(user_id)
|
||||
}
|
||||
"user_id": str(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# Format document metadata in a more maintainable way
|
||||
metadata_sections = [
|
||||
|
|
@ -301,14 +315,19 @@ async def add_extension_received_document(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Extension document already exists: {content.metadata.VisitedWebPageTitle}",
|
||||
{"duplicate_detected": True, "existing_document_id": existing_document.id}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"existing_document_id": existing_document.id,
|
||||
},
|
||||
)
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
|
|
@ -356,8 +375,8 @@ async def add_extension_received_document(
|
|||
{
|
||||
"document_id": document.id,
|
||||
"content_hash": content_hash,
|
||||
"url": content.metadata.VisitedWebPageURL
|
||||
}
|
||||
"url": content.metadata.VisitedWebPageURL,
|
||||
},
|
||||
)
|
||||
|
||||
return document
|
||||
|
|
@ -368,7 +387,7 @@ async def add_extension_received_document(
|
|||
log_entry,
|
||||
f"Database error processing extension document: {content.metadata.VisitedWebPageTitle}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError"}
|
||||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
|
|
@ -377,24 +396,32 @@ async def add_extension_received_document(
|
|||
log_entry,
|
||||
f"Failed to process extension document: {content.metadata.VisitedWebPageTitle}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
raise RuntimeError(f"Failed to process extension document: {str(e)}")
|
||||
raise RuntimeError(f"Failed to process extension document: {e!s}") from e
|
||||
|
||||
|
||||
async def add_received_markdown_file_document(
|
||||
session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str
|
||||
) -> Optional[Document]:
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
file_in_markdown: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> Document | None:
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="markdown_file_document",
|
||||
source="background_task",
|
||||
message=f"Processing markdown file: {file_name}",
|
||||
metadata={"filename": file_name, "user_id": str(user_id), "content_length": len(file_in_markdown)}
|
||||
metadata={
|
||||
"filename": file_name,
|
||||
"user_id": str(user_id),
|
||||
"content_length": len(file_in_markdown),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
content_hash = generate_content_hash(file_in_markdown, search_space_id)
|
||||
|
||||
|
|
@ -403,14 +430,19 @@ async def add_received_markdown_file_document(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Markdown file document already exists: {file_name}",
|
||||
{"duplicate_detected": True, "existing_document_id": existing_document.id}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"existing_document_id": existing_document.id,
|
||||
},
|
||||
)
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
|
|
@ -459,8 +491,8 @@ async def add_received_markdown_file_document(
|
|||
"document_id": document.id,
|
||||
"content_hash": content_hash,
|
||||
"chunks_count": len(chunks),
|
||||
"summary_length": len(summary_content)
|
||||
}
|
||||
"summary_length": len(summary_content),
|
||||
},
|
||||
)
|
||||
|
||||
return document
|
||||
|
|
@ -470,7 +502,7 @@ async def add_received_markdown_file_document(
|
|||
log_entry,
|
||||
f"Database error processing markdown file: {file_name}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError"}
|
||||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
|
|
@ -479,18 +511,18 @@ async def add_received_markdown_file_document(
|
|||
log_entry,
|
||||
f"Failed to process markdown file: {file_name}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
raise RuntimeError(f"Failed to process file document: {str(e)}")
|
||||
raise RuntimeError(f"Failed to process file document: {e!s}") from e
|
||||
|
||||
|
||||
async def add_received_file_document_using_unstructured(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
unstructured_processed_elements: List[LangChainDocument],
|
||||
unstructured_processed_elements: list[LangChainDocument],
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> Optional[Document]:
|
||||
) -> Document | None:
|
||||
try:
|
||||
file_in_markdown = await convert_document_to_markdown(
|
||||
unstructured_processed_elements
|
||||
|
|
@ -503,9 +535,11 @@ async def add_received_file_document_using_unstructured(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
return existing_document
|
||||
|
||||
# TODO: Check if file_markdown exceeds token limit of embedding model
|
||||
|
|
@ -555,7 +589,7 @@ async def add_received_file_document_using_unstructured(
|
|||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process file document: {str(e)}")
|
||||
raise RuntimeError(f"Failed to process file document: {e!s}") from e
|
||||
|
||||
|
||||
async def add_received_file_document_using_llamacloud(
|
||||
|
|
@ -564,7 +598,7 @@ async def add_received_file_document_using_llamacloud(
|
|||
llamacloud_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> Optional[Document]:
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content parsed by LlamaCloud.
|
||||
|
||||
|
|
@ -588,9 +622,11 @@ async def add_received_file_document_using_llamacloud(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
|
|
@ -638,7 +674,9 @@ async def add_received_file_document_using_llamacloud(
|
|||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process file document using LlamaCloud: {str(e)}")
|
||||
raise RuntimeError(
|
||||
f"Failed to process file document using LlamaCloud: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def add_received_file_document_using_docling(
|
||||
|
|
@ -647,7 +685,7 @@ async def add_received_file_document_using_docling(
|
|||
docling_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> Optional[Document]:
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store document content parsed by Docling.
|
||||
|
||||
|
|
@ -671,9 +709,11 @@ async def add_received_file_document_using_docling(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
|
|
@ -683,12 +723,11 @@ async def add_received_file_document_using_docling(
|
|||
|
||||
# Generate summary using chunked processing for large documents
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
docling_service = create_docling_service()
|
||||
|
||||
|
||||
summary_content = await docling_service.process_large_document_summary(
|
||||
content=file_in_markdown,
|
||||
llm=user_llm,
|
||||
document_title=file_name
|
||||
content=file_in_markdown, llm=user_llm, document_title=file_name
|
||||
)
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
|
|
@ -726,7 +765,9 @@ async def add_received_file_document_using_docling(
|
|||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process file document using Docling: {str(e)}")
|
||||
raise RuntimeError(
|
||||
f"Failed to process file document using Docling: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def add_youtube_video_document(
|
||||
|
|
@ -749,23 +790,23 @@ async def add_youtube_video_document(
|
|||
RuntimeError: If the video processing fails
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="youtube_video_document",
|
||||
source="background_task",
|
||||
message=f"Starting YouTube video processing for: {url}",
|
||||
metadata={"url": url, "user_id": str(user_id)}
|
||||
metadata={"url": url, "user_id": str(user_id)},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# Extract video ID from URL
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Extracting video ID from URL: {url}",
|
||||
{"stage": "video_id_extraction"}
|
||||
{"stage": "video_id_extraction"},
|
||||
)
|
||||
|
||||
|
||||
def get_youtube_video_id(url: str):
|
||||
parsed_url = urlparse(url)
|
||||
hostname = parsed_url.hostname
|
||||
|
|
@ -790,14 +831,14 @@ async def add_youtube_video_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Video ID extracted: {video_id}",
|
||||
{"stage": "video_id_extracted", "video_id": video_id}
|
||||
{"stage": "video_id_extracted", "video_id": video_id},
|
||||
)
|
||||
|
||||
# Get video metadata
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching video metadata for: {video_id}",
|
||||
{"stage": "metadata_fetch"}
|
||||
{"stage": "metadata_fetch"},
|
||||
)
|
||||
|
||||
params = {
|
||||
|
|
@ -806,21 +847,27 @@ async def add_youtube_video_document(
|
|||
}
|
||||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
async with http_session.get(oembed_url, params=params) as response:
|
||||
video_data = await response.json()
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(oembed_url, params=params) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Video metadata fetched: {video_data.get('title', 'Unknown')}",
|
||||
{"stage": "metadata_fetched", "title": video_data.get('title'), "author": video_data.get('author_name')}
|
||||
{
|
||||
"stage": "metadata_fetched",
|
||||
"title": video_data.get("title"),
|
||||
"author": video_data.get("author_name"),
|
||||
},
|
||||
)
|
||||
|
||||
# Get video transcript
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching transcript for video: {video_id}",
|
||||
{"stage": "transcript_fetch"}
|
||||
{"stage": "transcript_fetch"},
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -834,25 +881,29 @@ async def add_youtube_video_document(
|
|||
timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]"
|
||||
transcript_segments.append(f"{timestamp} {text}")
|
||||
transcript_text = "\n".join(transcript_segments)
|
||||
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Transcript fetched successfully: {len(captions)} segments",
|
||||
{"stage": "transcript_fetched", "segments_count": len(captions), "transcript_length": len(transcript_text)}
|
||||
{
|
||||
"stage": "transcript_fetched",
|
||||
"segments_count": len(captions),
|
||||
"transcript_length": len(transcript_text),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
transcript_text = f"No captions available for this video. Error: {str(e)}"
|
||||
transcript_text = f"No captions available for this video. Error: {e!s}"
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"No transcript available for video: {video_id}",
|
||||
{"stage": "transcript_unavailable", "error": str(e)}
|
||||
{"stage": "transcript_unavailable", "error": str(e)},
|
||||
)
|
||||
|
||||
# Format document
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing video content: {video_data.get('title', 'YouTube Video')}",
|
||||
{"stage": "content_processing"}
|
||||
{"stage": "content_processing"},
|
||||
)
|
||||
|
||||
# Format document metadata in a more maintainable way
|
||||
|
|
@ -890,7 +941,7 @@ async def add_youtube_video_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Checking for duplicate video content: {video_id}",
|
||||
{"stage": "duplicate_check", "content_hash": content_hash}
|
||||
{"stage": "duplicate_check", "content_hash": content_hash},
|
||||
)
|
||||
|
||||
# Check if document with this content hash already exists
|
||||
|
|
@ -898,21 +949,27 @@ async def add_youtube_video_document(
|
|||
select(Document).where(Document.content_hash == content_hash)
|
||||
)
|
||||
existing_document = existing_doc_result.scalars().first()
|
||||
|
||||
|
||||
if existing_document:
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"YouTube video document already exists: {video_data.get('title', 'YouTube Video')}",
|
||||
{"duplicate_detected": True, "existing_document_id": existing_document.id, "video_id": video_id}
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"existing_document_id": existing_document.id,
|
||||
"video_id": video_id,
|
||||
},
|
||||
)
|
||||
logging.info(
|
||||
f"Document with content hash {content_hash} already exists. Skipping processing."
|
||||
)
|
||||
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
|
||||
return existing_document
|
||||
|
||||
# Get LLM for summary generation
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Preparing for summary generation: {video_data.get('title', 'YouTube Video')}",
|
||||
{"stage": "llm_setup"}
|
||||
{"stage": "llm_setup"},
|
||||
)
|
||||
|
||||
# Get user's long context LLM
|
||||
|
|
@ -924,7 +981,7 @@ async def add_youtube_video_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Generating summary for video: {video_data.get('title', 'YouTube Video')}",
|
||||
{"stage": "summary_generation"}
|
||||
{"stage": "summary_generation"},
|
||||
)
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
|
||||
|
|
@ -938,7 +995,7 @@ async def add_youtube_video_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing content chunks for video: {video_data.get('title', 'YouTube Video')}",
|
||||
{"stage": "chunk_processing"}
|
||||
{"stage": "chunk_processing"},
|
||||
)
|
||||
|
||||
chunks = [
|
||||
|
|
@ -953,7 +1010,7 @@ async def add_youtube_video_document(
|
|||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Creating YouTube video document in database: {video_data.get('title', 'YouTube Video')}",
|
||||
{"stage": "document_creation", "chunks_count": len(chunks)}
|
||||
{"stage": "document_creation", "chunks_count": len(chunks)},
|
||||
)
|
||||
|
||||
document = Document(
|
||||
|
|
@ -988,8 +1045,8 @@ async def add_youtube_video_document(
|
|||
"content_hash": content_hash,
|
||||
"chunks_count": len(chunks),
|
||||
"summary_length": len(summary_content),
|
||||
"has_transcript": "No captions available" not in transcript_text
|
||||
}
|
||||
"has_transcript": "No captions available" not in transcript_text,
|
||||
},
|
||||
)
|
||||
|
||||
return document
|
||||
|
|
@ -999,7 +1056,10 @@ async def add_youtube_video_document(
|
|||
log_entry,
|
||||
f"Database error while processing YouTube video: {url}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError", "video_id": video_id if 'video_id' in locals() else None}
|
||||
{
|
||||
"error_type": "SQLAlchemyError",
|
||||
"video_id": video_id if "video_id" in locals() else None,
|
||||
},
|
||||
)
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
|
|
@ -1008,7 +1068,10 @@ async def add_youtube_video_document(
|
|||
log_entry,
|
||||
f"Failed to process YouTube video: {url}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__, "video_id": video_id if 'video_id' in locals() else None}
|
||||
{
|
||||
"error_type": type(e).__name__,
|
||||
"video_id": video_id if "video_id" in locals() else None,
|
||||
},
|
||||
)
|
||||
logging.error(f"Failed to process YouTube video: {str(e)}")
|
||||
logging.error(f"Failed to process YouTube video: {e!s}")
|
||||
raise
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,33 +1,29 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State
|
||||
from app.db import Chat, Podcast
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
|
||||
async def generate_document_podcast(
|
||||
session: AsyncSession,
|
||||
document_id: int,
|
||||
search_space_id: int,
|
||||
user_id: int
|
||||
session: AsyncSession, document_id: int, search_space_id: int, user_id: int
|
||||
):
|
||||
# TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def generate_chat_podcast(
|
||||
session: AsyncSession,
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
podcast_title: str,
|
||||
user_id: int
|
||||
user_id: int,
|
||||
):
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
|
||||
|
||||
# Log task start
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="generate_chat_podcast",
|
||||
|
|
@ -37,44 +33,43 @@ async def generate_chat_podcast(
|
|||
"chat_id": chat_id,
|
||||
"search_space_id": search_space_id,
|
||||
"podcast_title": podcast_title,
|
||||
"user_id": str(user_id)
|
||||
}
|
||||
"user_id": str(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# Fetch the chat with the specified ID
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Fetching chat {chat_id} from database",
|
||||
{"stage": "fetch_chat"}
|
||||
log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
|
||||
)
|
||||
|
||||
|
||||
query = select(Chat).filter(
|
||||
Chat.id == chat_id,
|
||||
Chat.search_space_id == search_space_id
|
||||
Chat.id == chat_id, Chat.search_space_id == search_space_id
|
||||
)
|
||||
|
||||
|
||||
result = await session.execute(query)
|
||||
chat = result.scalars().first()
|
||||
|
||||
|
||||
if not chat:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Chat with id {chat_id} not found in search space {search_space_id}",
|
||||
"Chat not found",
|
||||
{"error_type": "ChatNotFound"}
|
||||
{"error_type": "ChatNotFound"},
|
||||
)
|
||||
raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}")
|
||||
|
||||
raise ValueError(
|
||||
f"Chat with id {chat_id} not found in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Create chat history structure
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing chat history for chat {chat_id}",
|
||||
{"stage": "process_chat_history", "message_count": len(chat.messages)}
|
||||
{"stage": "process_chat_history", "message_count": len(chat.messages)},
|
||||
)
|
||||
|
||||
|
||||
chat_history_str = "<chat_history>"
|
||||
|
||||
|
||||
processed_messages = 0
|
||||
for message in chat.messages:
|
||||
if message["role"] == "user":
|
||||
|
|
@ -89,18 +84,24 @@ async def generate_chat_podcast(
|
|||
# If content is a list, join it into a single string
|
||||
if isinstance(answer_text, list):
|
||||
answer_text = "\n".join(answer_text)
|
||||
chat_history_str += f"<assistant_message>{answer_text}</assistant_message>"
|
||||
chat_history_str += (
|
||||
f"<assistant_message>{answer_text}</assistant_message>"
|
||||
)
|
||||
processed_messages += 1
|
||||
|
||||
|
||||
chat_history_str += "</chat_history>"
|
||||
|
||||
|
||||
# Pass it to the SurfSense Podcaster
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing podcast generation for chat {chat_id}",
|
||||
{"stage": "initialize_podcast_generation", "processed_messages": processed_messages, "content_length": len(chat_history_str)}
|
||||
{
|
||||
"stage": "initialize_podcast_generation",
|
||||
"processed_messages": processed_messages,
|
||||
"content_length": len(chat_history_str),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"podcast_title": "SurfSense",
|
||||
|
|
@ -108,53 +109,55 @@ async def generate_chat_podcast(
|
|||
}
|
||||
}
|
||||
# Initialize state with database session and streaming service
|
||||
initial_state = State(
|
||||
source_content=chat_history_str,
|
||||
db_session=session
|
||||
)
|
||||
|
||||
initial_state = State(source_content=chat_history_str, db_session=session)
|
||||
|
||||
# Run the graph directly
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Running podcast generation graph for chat {chat_id}",
|
||||
{"stage": "run_podcast_graph"}
|
||||
{"stage": "run_podcast_graph"},
|
||||
)
|
||||
|
||||
|
||||
result = await podcaster_graph.ainvoke(initial_state, config=config)
|
||||
|
||||
|
||||
# Convert podcast transcript entries to serializable format
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing podcast transcript for chat {chat_id}",
|
||||
{"stage": "process_transcript", "transcript_entries": len(result["podcast_transcript"])}
|
||||
{
|
||||
"stage": "process_transcript",
|
||||
"transcript_entries": len(result["podcast_transcript"]),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
serializable_transcript = []
|
||||
for entry in result["podcast_transcript"]:
|
||||
serializable_transcript.append({
|
||||
"speaker_id": entry.speaker_id,
|
||||
"dialog": entry.dialog
|
||||
})
|
||||
|
||||
serializable_transcript.append(
|
||||
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
|
||||
)
|
||||
|
||||
# Create a new podcast entry
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Creating podcast database entry for chat {chat_id}",
|
||||
{"stage": "create_podcast_entry", "file_location": result.get("final_podcast_file_path")}
|
||||
{
|
||||
"stage": "create_podcast_entry",
|
||||
"file_location": result.get("final_podcast_file_path"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
podcast = Podcast(
|
||||
title=f"{podcast_title}",
|
||||
podcast_transcript=serializable_transcript,
|
||||
file_location=result["final_podcast_file_path"],
|
||||
search_space_id=search_space_id
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
|
||||
# Add to session and commit
|
||||
session.add(podcast)
|
||||
await session.commit()
|
||||
await session.refresh(podcast)
|
||||
|
||||
|
||||
# Log success
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
|
|
@ -165,10 +168,10 @@ async def generate_chat_podcast(
|
|||
"transcript_entries": len(serializable_transcript),
|
||||
"file_location": result.get("final_podcast_file_path"),
|
||||
"processed_messages": processed_messages,
|
||||
"content_length": len(chat_history_str)
|
||||
}
|
||||
"content_length": len(chat_history_str),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
return podcast
|
||||
|
||||
except ValueError as ve:
|
||||
|
|
@ -178,7 +181,7 @@ async def generate_chat_podcast(
|
|||
log_entry,
|
||||
f"Value error during podcast generation for chat {chat_id}",
|
||||
str(ve),
|
||||
{"error_type": "ValueError"}
|
||||
{"error_type": "ValueError"},
|
||||
)
|
||||
raise ve
|
||||
except SQLAlchemyError as db_error:
|
||||
|
|
@ -187,7 +190,7 @@ async def generate_chat_podcast(
|
|||
log_entry,
|
||||
f"Database error during podcast generation for chat {chat_id}",
|
||||
str(db_error),
|
||||
{"error_type": "SQLAlchemyError"}
|
||||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
|
|
@ -196,7 +199,8 @@ async def generate_chat_podcast(
|
|||
log_entry,
|
||||
f"Unexpected error during podcast generation for chat {chat_id}",
|
||||
str(e),
|
||||
{"error_type": type(e).__name__}
|
||||
{"error_type": type(e).__name__},
|
||||
)
|
||||
raise RuntimeError(f"Failed to generate podcast for chat {chat_id}: {str(e)}")
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to generate podcast for chat {chat_id}: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -1,28 +1,29 @@
|
|||
from typing import Any, AsyncGenerator, List, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.agents.researcher.graph import graph as researcher_graph
|
||||
from app.agents.researcher.state import State
|
||||
from app.services.streaming_service import StreamingService
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.researcher.configuration import SearchMode
|
||||
from app.agents.researcher.graph import graph as researcher_graph
|
||||
from app.agents.researcher.state import State
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
async def stream_connector_search_results(
|
||||
user_query: str,
|
||||
user_id: Union[str, UUID],
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
selected_connectors: List[str],
|
||||
langchain_chat_history: List[Any],
|
||||
user_query: str,
|
||||
user_id: str | UUID,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
selected_connectors: list[str],
|
||||
langchain_chat_history: list[Any],
|
||||
search_mode_str: str,
|
||||
document_ids_to_add_in_context: List[int]
|
||||
document_ids_to_add_in_context: list[int],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream connector search results to the client
|
||||
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
user_id: The user's ID (can be UUID object or string)
|
||||
|
|
@ -30,61 +31,60 @@ async def stream_connector_search_results(
|
|||
session: The database session
|
||||
research_mode: The research mode
|
||||
selected_connectors: List of selected connectors
|
||||
|
||||
|
||||
Yields:
|
||||
str: Formatted response strings
|
||||
"""
|
||||
streaming_service = StreamingService()
|
||||
|
||||
|
||||
if research_mode == "REPORT_GENERAL":
|
||||
NUM_SECTIONS = 1
|
||||
num_sections = 1
|
||||
elif research_mode == "REPORT_DEEP":
|
||||
NUM_SECTIONS = 3
|
||||
num_sections = 3
|
||||
elif research_mode == "REPORT_DEEPER":
|
||||
NUM_SECTIONS = 6
|
||||
num_sections = 6
|
||||
else:
|
||||
# Default fallback
|
||||
NUM_SECTIONS = 1
|
||||
|
||||
num_sections = 1
|
||||
|
||||
# Convert UUID to string if needed
|
||||
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
|
||||
|
||||
|
||||
if search_mode_str == "CHUNKS":
|
||||
search_mode = SearchMode.CHUNKS
|
||||
elif search_mode_str == "DOCUMENTS":
|
||||
search_mode = SearchMode.DOCUMENTS
|
||||
|
||||
|
||||
# Sample configuration
|
||||
config = {
|
||||
"configurable": {
|
||||
"user_query": user_query,
|
||||
"num_sections": NUM_SECTIONS,
|
||||
"num_sections": num_sections,
|
||||
"connectors_to_search": selected_connectors,
|
||||
"user_id": user_id_str,
|
||||
"search_space_id": search_space_id,
|
||||
"search_mode": search_mode,
|
||||
"research_mode": research_mode,
|
||||
"document_ids_to_add_in_context": document_ids_to_add_in_context
|
||||
"document_ids_to_add_in_context": document_ids_to_add_in_context,
|
||||
}
|
||||
}
|
||||
# Initialize state with database session and streaming service
|
||||
initial_state = State(
|
||||
db_session=session,
|
||||
streaming_service=streaming_service,
|
||||
chat_history=langchain_chat_history
|
||||
chat_history=langchain_chat_history,
|
||||
)
|
||||
|
||||
|
||||
# Run the graph directly
|
||||
print("\nRunning the complete researcher workflow...")
|
||||
|
||||
|
||||
# Use streaming with config parameter
|
||||
async for chunk in researcher_graph.astream(
|
||||
initial_state,
|
||||
config=config,
|
||||
stream_mode="custom",
|
||||
):
|
||||
if isinstance(chunk, dict):
|
||||
if "yield_value" in chunk:
|
||||
yield chunk["yield_value"]
|
||||
if isinstance(chunk, dict) and "yield_value" in chunk:
|
||||
yield chunk["yield_value"]
|
||||
|
||||
yield streaming_service.format_completion()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
from fastapi_users.authentication import (
|
||||
AuthenticationBackend,
|
||||
|
|
@ -10,21 +9,23 @@ from fastapi_users.authentication import (
|
|||
JWTStrategy,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.schemas import model_dump
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.config import config
|
||||
from app.db import User, get_user_db
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BearerResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
|
||||
google_oauth_client = GoogleOAuth2(
|
||||
config.GOOGLE_OAUTH_CLIENT_ID,
|
||||
config.GOOGLE_OAUTH_CLIENT_SECRET,
|
||||
|
|
@ -35,27 +36,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||
reset_password_token_secret = SECRET
|
||||
verification_token_secret = SECRET
|
||||
|
||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||
async def on_after_register(self, user: User, request: Request | None = None):
|
||||
print(f"User {user.id} has registered.")
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
self, user: User, token: str, request: Request | None = None
|
||||
):
|
||||
print(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
self, user: User, token: str, request: Request | None = None
|
||||
):
|
||||
print(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}")
|
||||
print(f"Verification requested for user {user.id}. Verification token: {token}")
|
||||
|
||||
|
||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
||||
yield UserManager(user_db)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
return JWTStrategy(secret=SECRET, lifetime_seconds=3600*24)
|
||||
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
|
||||
|
||||
|
||||
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
|
||||
|
|
@ -77,6 +77,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
|||
# get_strategy=get_jwt_strategy,
|
||||
# )
|
||||
|
||||
|
||||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
|
|
@ -87,6 +88,7 @@ class CustomBearerTransport(BearerTransport):
|
|||
else:
|
||||
return JSONResponse(model_dump(bearer_response))
|
||||
|
||||
|
||||
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")
|
||||
|
||||
|
||||
|
|
@ -98,4 +100,4 @@ auth_backend = AuthenticationBackend(
|
|||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,19 @@
|
|||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import User
|
||||
|
||||
|
||||
# Helper function to check user ownership
|
||||
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
|
||||
item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
|
||||
item = await session.execute(
|
||||
select(model).filter(model.id == item_id, model.user_id == user.id)
|
||||
)
|
||||
item = item.scalars().first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
|
||||
return item
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Item not found or you don't have permission to access it",
|
||||
)
|
||||
return item
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ async def convert_element_to_markdown(element) -> str:
|
|||
"Footer": lambda x: f"*{x}*\n\n",
|
||||
"CodeSnippet": lambda x: f"```\n{x}\n```",
|
||||
"PageNumber": lambda x: f"*Page {x}*\n\n",
|
||||
"UncategorizedText": lambda x: f"{x}\n\n"
|
||||
"UncategorizedText": lambda x: f"{x}\n\n",
|
||||
}
|
||||
|
||||
converter = markdown_mapping.get(element_category, lambda x: x)
|
||||
|
|
@ -74,7 +74,7 @@ def convert_chunks_to_langchain_documents(chunks):
|
|||
except ImportError:
|
||||
raise ImportError(
|
||||
"LangChain is not installed. Please install it with `pip install langchain langchain-core`"
|
||||
)
|
||||
) from None
|
||||
|
||||
langchain_docs = []
|
||||
|
||||
|
|
@ -92,17 +92,20 @@ def convert_chunks_to_langchain_documents(chunks):
|
|||
# Add document information to metadata
|
||||
if "document" in chunk:
|
||||
doc = chunk["document"]
|
||||
metadata.update({
|
||||
"document_id": doc.get("id"),
|
||||
"document_title": doc.get("title"),
|
||||
"document_type": doc.get("document_type"),
|
||||
})
|
||||
metadata.update(
|
||||
{
|
||||
"document_id": doc.get("id"),
|
||||
"document_title": doc.get("title"),
|
||||
"document_type": doc.get("document_type"),
|
||||
}
|
||||
)
|
||||
|
||||
# Add document metadata if available
|
||||
if "metadata" in doc:
|
||||
# Prefix document metadata keys to avoid conflicts
|
||||
doc_metadata = {f"doc_meta_{k}": v for k,
|
||||
v in doc.get("metadata", {}).items()}
|
||||
doc_metadata = {
|
||||
f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()
|
||||
}
|
||||
metadata.update(doc_metadata)
|
||||
|
||||
# Add source URL if available in metadata
|
||||
|
|
@ -131,10 +134,7 @@ def convert_chunks_to_langchain_documents(chunks):
|
|||
"""
|
||||
|
||||
# Create LangChain Document
|
||||
langchain_doc = LangChainDocument(
|
||||
page_content=new_content,
|
||||
metadata=metadata
|
||||
)
|
||||
langchain_doc = LangChainDocument(page_content=new_content, metadata=metadata)
|
||||
|
||||
langchain_docs.append(langchain_doc)
|
||||
|
||||
|
|
@ -144,4 +144,4 @@ def convert_chunks_to_langchain_documents(chunks):
|
|||
def generate_content_hash(content: str, search_space_id: int) -> str:
|
||||
"""Generate SHA-256 hash for the given content combined with search space ID."""
|
||||
combined_data = f"{search_space_id}:{content}"
|
||||
return hashlib.sha256(combined_data.encode('utf-8')).hexdigest()
|
||||
return hashlib.sha256(combined_data.encode("utf-8")).hexdigest()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import uvicorn
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.config.uvicorn import load_uvicorn_config
|
||||
|
||||
logging.basicConfig(
|
||||
|
|
|
|||
|
|
@ -36,3 +36,97 @@ dependencies = [
|
|||
"validators>=0.34.0",
|
||||
"youtube-transcript-api>=1.0.3",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"ruff>=0.12.5",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
# Exclude a variety of commonly ignored directories.
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".ipynb_checkpoints",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pyenv",
|
||||
".pytest_cache",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
".vscode",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"site-packages",
|
||||
"venv",
|
||||
]
|
||||
|
||||
line-length = 88
|
||||
indent-width = 4
|
||||
|
||||
# Python 3.12
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E4", # pycodestyle errors
|
||||
"E7", # pycodestyle errors
|
||||
"E9", # pycodestyle errors
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"UP", # pyupgrade
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"T20", # flake8-print
|
||||
"SIM", # flake8-simplify
|
||||
"RUF", # Ruff-specific rules
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"E501", # Line too long (handled by formatter)
|
||||
"B008", # Do not perform function calls in argument defaults
|
||||
"T201", # Print found (allow print statements)
|
||||
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar`
|
||||
]
|
||||
|
||||
extend-select = ["I"]
|
||||
|
||||
# Allow fix for all enabled rules (when `--fix`) is provided.
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
|
||||
# Allow unused variables when underscore-prefixed.
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
[tool.ruff.format]
|
||||
# Use double quotes for strings.
|
||||
quote-style = "double"
|
||||
|
||||
# Indent with spaces, rather than tabs.
|
||||
indent-style = "space"
|
||||
|
||||
# Respect magic trailing commas.
|
||||
skip-magic-trailing-comma = false
|
||||
|
||||
# Automatically detect the appropriate line ending.
|
||||
line-ending = "auto"
|
||||
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
# Group imports by type
|
||||
known-first-party = ["app"]
|
||||
force-single-line = false
|
||||
combine-as-imports = true
|
||||
|
|
|
|||
4507
surfsense_backend/uv.lock
generated
4507
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -9,12 +9,12 @@ import { ArrowLeft, Check, Loader2, Github } from "lucide-react";
|
|||
import { Form } from "@/components/ui/form";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
|
||||
// Import Utils, Types, Hook, and Components
|
||||
|
|
@ -27,201 +27,227 @@ import { EditSimpleTokenForm } from "@/components/editConnector/EditSimpleTokenF
|
|||
import { getConnectorIcon } from "@/components/chat";
|
||||
|
||||
export default function EditConnectorPage() {
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchSpaceId = params.search_space_id as string;
|
||||
// Ensure connectorId is parsed safely
|
||||
const connectorIdParam = params.connector_id as string;
|
||||
const connectorId = connectorIdParam ? parseInt(connectorIdParam, 10) : NaN;
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchSpaceId = params.search_space_id as string;
|
||||
// Ensure connectorId is parsed safely
|
||||
const connectorIdParam = params.connector_id as string;
|
||||
const connectorId = connectorIdParam ? parseInt(connectorIdParam, 10) : NaN;
|
||||
|
||||
// Use the custom hook to manage state and logic
|
||||
const {
|
||||
connectorsLoading,
|
||||
connector,
|
||||
isSaving,
|
||||
editForm,
|
||||
patForm, // Needed for GitHub child component
|
||||
handleSaveChanges,
|
||||
// GitHub specific props for the child component
|
||||
editMode,
|
||||
setEditMode, // Pass down if needed by GitHub component
|
||||
originalPat,
|
||||
currentSelectedRepos,
|
||||
fetchedRepos,
|
||||
setFetchedRepos,
|
||||
newSelectedRepos,
|
||||
setNewSelectedRepos,
|
||||
isFetchingRepos,
|
||||
handleFetchRepositories,
|
||||
handleRepoSelectionChange,
|
||||
} = useConnectorEditPage(connectorId, searchSpaceId);
|
||||
// Use the custom hook to manage state and logic
|
||||
const {
|
||||
connectorsLoading,
|
||||
connector,
|
||||
isSaving,
|
||||
editForm,
|
||||
patForm, // Needed for GitHub child component
|
||||
handleSaveChanges,
|
||||
// GitHub specific props for the child component
|
||||
editMode,
|
||||
setEditMode, // Pass down if needed by GitHub component
|
||||
originalPat,
|
||||
currentSelectedRepos,
|
||||
fetchedRepos,
|
||||
setFetchedRepos,
|
||||
newSelectedRepos,
|
||||
setNewSelectedRepos,
|
||||
isFetchingRepos,
|
||||
handleFetchRepositories,
|
||||
handleRepoSelectionChange,
|
||||
} = useConnectorEditPage(connectorId, searchSpaceId);
|
||||
|
||||
// Redirect if connectorId is not a valid number after parsing
|
||||
useEffect(() => {
|
||||
if (isNaN(connectorId)) {
|
||||
toast.error("Invalid Connector ID.");
|
||||
router.push(`/dashboard/${searchSpaceId}/connectors`);
|
||||
}
|
||||
}, [connectorId, router, searchSpaceId]);
|
||||
// Redirect if connectorId is not a valid number after parsing
|
||||
useEffect(() => {
|
||||
if (isNaN(connectorId)) {
|
||||
toast.error("Invalid Connector ID.");
|
||||
router.push(`/dashboard/${searchSpaceId}/connectors`);
|
||||
}
|
||||
}, [connectorId, router, searchSpaceId]);
|
||||
|
||||
// Loading State
|
||||
if (connectorsLoading || !connector) {
|
||||
// Handle NaN case before showing skeleton
|
||||
if (isNaN(connectorId)) return null;
|
||||
return <EditConnectorLoadingSkeleton />;
|
||||
}
|
||||
// Loading State
|
||||
if (connectorsLoading || !connector) {
|
||||
// Handle NaN case before showing skeleton
|
||||
if (isNaN(connectorId)) return null;
|
||||
return <EditConnectorLoadingSkeleton />;
|
||||
}
|
||||
|
||||
// Main Render using data/handlers from the hook
|
||||
return (
|
||||
<div className="container mx-auto py-8 max-w-3xl">
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="mb-6"
|
||||
onClick={() => router.push(`/dashboard/${searchSpaceId}/connectors`)}
|
||||
>
|
||||
<ArrowLeft className="mr-2 h-4 w-4" /> Back to Connectors
|
||||
</Button>
|
||||
// Main Render using data/handlers from the hook
|
||||
return (
|
||||
<div className="container mx-auto py-8 max-w-3xl">
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="mb-6"
|
||||
onClick={() => router.push(`/dashboard/${searchSpaceId}/connectors`)}
|
||||
>
|
||||
<ArrowLeft className="mr-2 h-4 w-4" /> Back to Connectors
|
||||
</Button>
|
||||
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ duration: 0.5 }}
|
||||
>
|
||||
<Card className="border-2 border-border">
|
||||
<CardHeader>
|
||||
<CardTitle className="text-2xl font-bold flex items-center gap-2">
|
||||
{getConnectorIcon(connector.connector_type)}
|
||||
Edit {getConnectorTypeDisplay(connector.connector_type)} Connector
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Modify connector name and configuration.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ duration: 0.5 }}
|
||||
>
|
||||
<Card className="border-2 border-border">
|
||||
<CardHeader>
|
||||
<CardTitle className="text-2xl font-bold flex items-center gap-2">
|
||||
{getConnectorIcon(connector.connector_type)}
|
||||
Edit {getConnectorTypeDisplay(connector.connector_type)} Connector
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Modify connector name and configuration.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<Form {...editForm}>
|
||||
{/* Pass hook's handleSaveChanges */}
|
||||
<form
|
||||
onSubmit={editForm.handleSubmit(handleSaveChanges)}
|
||||
className="space-y-6"
|
||||
>
|
||||
<CardContent className="space-y-6">
|
||||
{/* Pass form control from hook */}
|
||||
<EditConnectorNameForm control={editForm.control} />
|
||||
<Form {...editForm}>
|
||||
{/* Pass hook's handleSaveChanges */}
|
||||
<form
|
||||
onSubmit={editForm.handleSubmit(handleSaveChanges)}
|
||||
className="space-y-6"
|
||||
>
|
||||
<CardContent className="space-y-6">
|
||||
{/* Pass form control from hook */}
|
||||
<EditConnectorNameForm control={editForm.control} />
|
||||
|
||||
<hr />
|
||||
<hr />
|
||||
|
||||
<h3 className="text-lg font-semibold">Configuration</h3>
|
||||
<h3 className="text-lg font-semibold">Configuration</h3>
|
||||
|
||||
{/* == GitHub == */}
|
||||
{connector.connector_type === "GITHUB_CONNECTOR" && (
|
||||
<EditGitHubConnectorConfig
|
||||
// Pass relevant state and handlers from hook
|
||||
editMode={editMode}
|
||||
setEditMode={setEditMode} // Pass setter if child manages mode
|
||||
originalPat={originalPat}
|
||||
currentSelectedRepos={currentSelectedRepos}
|
||||
fetchedRepos={fetchedRepos}
|
||||
newSelectedRepos={newSelectedRepos}
|
||||
isFetchingRepos={isFetchingRepos}
|
||||
patForm={patForm}
|
||||
handleFetchRepositories={handleFetchRepositories}
|
||||
handleRepoSelectionChange={handleRepoSelectionChange}
|
||||
setNewSelectedRepos={setNewSelectedRepos}
|
||||
setFetchedRepos={setFetchedRepos}
|
||||
/>
|
||||
)}
|
||||
{/* == GitHub == */}
|
||||
{connector.connector_type === "GITHUB_CONNECTOR" && (
|
||||
<EditGitHubConnectorConfig
|
||||
// Pass relevant state and handlers from hook
|
||||
editMode={editMode}
|
||||
setEditMode={setEditMode} // Pass setter if child manages mode
|
||||
originalPat={originalPat}
|
||||
currentSelectedRepos={currentSelectedRepos}
|
||||
fetchedRepos={fetchedRepos}
|
||||
newSelectedRepos={newSelectedRepos}
|
||||
isFetchingRepos={isFetchingRepos}
|
||||
patForm={patForm}
|
||||
handleFetchRepositories={handleFetchRepositories}
|
||||
handleRepoSelectionChange={handleRepoSelectionChange}
|
||||
setNewSelectedRepos={setNewSelectedRepos}
|
||||
setFetchedRepos={setFetchedRepos}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* == Slack == */}
|
||||
{connector.connector_type === "SLACK_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="SLACK_BOT_TOKEN"
|
||||
fieldLabel="Slack Bot Token"
|
||||
fieldDescription="Update the Slack Bot Token if needed."
|
||||
placeholder="Begins with xoxb-..."
|
||||
/>
|
||||
)}
|
||||
{/* == Notion == */}
|
||||
{connector.connector_type === "NOTION_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="NOTION_INTEGRATION_TOKEN"
|
||||
fieldLabel="Notion Integration Token"
|
||||
fieldDescription="Update the Notion Integration Token if needed."
|
||||
placeholder="Begins with secret_..."
|
||||
/>
|
||||
)}
|
||||
{/* == Serper == */}
|
||||
{connector.connector_type === "SERPER_API" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="SERPER_API_KEY"
|
||||
fieldLabel="Serper API Key"
|
||||
fieldDescription="Update the Serper API Key if needed."
|
||||
/>
|
||||
)}
|
||||
{/* == Tavily == */}
|
||||
{connector.connector_type === "TAVILY_API" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="TAVILY_API_KEY"
|
||||
fieldLabel="Tavily API Key"
|
||||
fieldDescription="Update the Tavily API Key if needed."
|
||||
/>
|
||||
)}
|
||||
{/* == Slack == */}
|
||||
{connector.connector_type === "SLACK_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="SLACK_BOT_TOKEN"
|
||||
fieldLabel="Slack Bot Token"
|
||||
fieldDescription="Update the Slack Bot Token if needed."
|
||||
placeholder="Begins with xoxb-..."
|
||||
/>
|
||||
)}
|
||||
{/* == Notion == */}
|
||||
{connector.connector_type === "NOTION_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="NOTION_INTEGRATION_TOKEN"
|
||||
fieldLabel="Notion Integration Token"
|
||||
fieldDescription="Update the Notion Integration Token if needed."
|
||||
placeholder="Begins with secret_..."
|
||||
/>
|
||||
)}
|
||||
{/* == Serper == */}
|
||||
{connector.connector_type === "SERPER_API" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="SERPER_API_KEY"
|
||||
fieldLabel="Serper API Key"
|
||||
fieldDescription="Update the Serper API Key if needed."
|
||||
/>
|
||||
)}
|
||||
{/* == Tavily == */}
|
||||
{connector.connector_type === "TAVILY_API" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="TAVILY_API_KEY"
|
||||
fieldLabel="Tavily API Key"
|
||||
fieldDescription="Update the Tavily API Key if needed."
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* == Linear == */}
|
||||
{connector.connector_type === "LINEAR_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="LINEAR_API_KEY"
|
||||
fieldLabel="Linear API Key"
|
||||
fieldDescription="Update your Linear API Key if needed."
|
||||
placeholder="Begins with lin_api_..."
|
||||
/>
|
||||
)}
|
||||
{/* == Linear == */}
|
||||
{connector.connector_type === "LINEAR_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="LINEAR_API_KEY"
|
||||
fieldLabel="Linear API Key"
|
||||
fieldDescription="Update your Linear API Key if needed."
|
||||
placeholder="Begins with lin_api_..."
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* == Linkup == */}
|
||||
{connector.connector_type === "LINKUP_API" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="LINKUP_API_KEY"
|
||||
fieldLabel="Linkup API Key"
|
||||
fieldDescription="Update your Linkup API Key if needed."
|
||||
placeholder="Begins with linkup_..."
|
||||
/>
|
||||
)}
|
||||
{/* == Jira == */}
|
||||
{connector.connector_type === "JIRA_CONNECTOR" && (
|
||||
<div className="space-y-4">
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="JIRA_BASE_URL"
|
||||
fieldLabel="Jira Base URL"
|
||||
fieldDescription="Update your Jira instance URL if needed."
|
||||
placeholder="https://yourcompany.atlassian.net"
|
||||
/>
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="JIRA_EMAIL"
|
||||
fieldLabel="Jira Email"
|
||||
fieldDescription="Update your Atlassian account email if needed."
|
||||
placeholder="your.email@company.com"
|
||||
/>
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="JIRA_API_TOKEN"
|
||||
fieldLabel="Jira API Token"
|
||||
fieldDescription="Update your Jira API Token if needed."
|
||||
placeholder="Your Jira API Token"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* == Discord == */}
|
||||
{connector.connector_type === "DISCORD_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="DISCORD_BOT_TOKEN"
|
||||
fieldLabel="Discord Bot Token"
|
||||
fieldDescription="Update the Discord Bot Token if needed."
|
||||
placeholder="Bot token..."
|
||||
/>
|
||||
)}
|
||||
{/* == Linkup == */}
|
||||
{connector.connector_type === "LINKUP_API" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="LINKUP_API_KEY"
|
||||
fieldLabel="Linkup API Key"
|
||||
fieldDescription="Update your Linkup API Key if needed."
|
||||
placeholder="Begins with linkup_..."
|
||||
/>
|
||||
)}
|
||||
|
||||
</CardContent>
|
||||
<CardFooter className="border-t pt-6">
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSaving}
|
||||
className="w-full sm:w-auto"
|
||||
>
|
||||
{isSaving ? (
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<Check className="mr-2 h-4 w-4" />
|
||||
)}
|
||||
Save Changes
|
||||
</Button>
|
||||
</CardFooter>
|
||||
</form>
|
||||
</Form>
|
||||
</Card>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
{/* == Discord == */}
|
||||
{connector.connector_type === "DISCORD_CONNECTOR" && (
|
||||
<EditSimpleTokenForm
|
||||
control={editForm.control}
|
||||
fieldName="DISCORD_BOT_TOKEN"
|
||||
fieldLabel="Discord Bot Token"
|
||||
fieldDescription="Update the Discord Bot Token if needed."
|
||||
placeholder="Bot token..."
|
||||
/>
|
||||
)}
|
||||
</CardContent>
|
||||
<CardFooter className="border-t pt-6">
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSaving}
|
||||
className="w-full sm:w-auto"
|
||||
>
|
||||
{isSaving ? (
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<Check className="mr-2 h-4 w-4" />
|
||||
)}
|
||||
Save Changes
|
||||
</Button>
|
||||
</CardFooter>
|
||||
</form>
|
||||
</Form>
|
||||
</Card>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ import * as z from "zod";
|
|||
import { toast } from "sonner";
|
||||
import { ArrowLeft, Check, Info, Loader2 } from "lucide-react";
|
||||
|
||||
import { useSearchSourceConnectors, SearchSourceConnector } from "@/hooks/useSearchSourceConnectors";
|
||||
import {
|
||||
useSearchSourceConnectors,
|
||||
SearchSourceConnector,
|
||||
} from "@/hooks/useSearchSourceConnectors";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
|
|
@ -28,11 +31,7 @@ import {
|
|||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Alert,
|
||||
AlertDescription,
|
||||
AlertTitle,
|
||||
} from "@/components/ui/alert";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
|
||||
// Define the form schema with Zod
|
||||
const apiConnectorFormSchema = z.object({
|
||||
|
|
@ -47,13 +46,15 @@ const apiConnectorFormSchema = z.object({
|
|||
// Helper function to get connector type display name
|
||||
const getConnectorTypeDisplay = (type: string): string => {
|
||||
const typeMap: Record<string, string> = {
|
||||
"SERPER_API": "Serper API",
|
||||
"TAVILY_API": "Tavily API",
|
||||
"SLACK_CONNECTOR": "Slack Connector",
|
||||
"NOTION_CONNECTOR": "Notion Connector",
|
||||
"GITHUB_CONNECTOR": "GitHub Connector",
|
||||
"DISCORD_CONNECTOR": "Discord Connector",
|
||||
"LINKUP_API": "Linkup",
|
||||
SERPER_API: "Serper API",
|
||||
TAVILY_API: "Tavily API",
|
||||
SLACK_CONNECTOR: "Slack Connector",
|
||||
NOTION_CONNECTOR: "Notion Connector",
|
||||
GITHUB_CONNECTOR: "GitHub Connector",
|
||||
LINEAR_CONNECTOR: "Linear Connector",
|
||||
JIRA_CONNECTOR: "Jira Connector",
|
||||
DISCORD_CONNECTOR: "Discord Connector",
|
||||
LINKUP_API: "Linkup",
|
||||
// Add other connector types here as needed
|
||||
};
|
||||
return typeMap[type] || type;
|
||||
|
|
@ -67,9 +68,11 @@ export default function EditConnectorPage() {
|
|||
const params = useParams();
|
||||
const searchSpaceId = params.search_space_id as string;
|
||||
const connectorId = parseInt(params.connector_id as string, 10);
|
||||
|
||||
|
||||
const { connectors, updateConnector } = useSearchSourceConnectors();
|
||||
const [connector, setConnector] = useState<SearchSourceConnector | null>(null);
|
||||
const [connector, setConnector] = useState<SearchSourceConnector | null>(
|
||||
null,
|
||||
);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
// console.log("connector", connector);
|
||||
|
|
@ -85,24 +88,24 @@ export default function EditConnectorPage() {
|
|||
// Get API key field name based on connector type
|
||||
const getApiKeyFieldName = (connectorType: string): string => {
|
||||
const fieldMap: Record<string, string> = {
|
||||
"SERPER_API": "SERPER_API_KEY",
|
||||
"TAVILY_API": "TAVILY_API_KEY",
|
||||
"SLACK_CONNECTOR": "SLACK_BOT_TOKEN",
|
||||
"NOTION_CONNECTOR": "NOTION_INTEGRATION_TOKEN",
|
||||
"GITHUB_CONNECTOR": "GITHUB_PAT",
|
||||
"DISCORD_CONNECTOR": "DISCORD_BOT_TOKEN",
|
||||
"LINKUP_API": "LINKUP_API_KEY"
|
||||
SERPER_API: "SERPER_API_KEY",
|
||||
TAVILY_API: "TAVILY_API_KEY",
|
||||
SLACK_CONNECTOR: "SLACK_BOT_TOKEN",
|
||||
NOTION_CONNECTOR: "NOTION_INTEGRATION_TOKEN",
|
||||
GITHUB_CONNECTOR: "GITHUB_PAT",
|
||||
DISCORD_CONNECTOR: "DISCORD_BOT_TOKEN",
|
||||
LINKUP_API: "LINKUP_API_KEY",
|
||||
};
|
||||
return fieldMap[connectorType] || "";
|
||||
};
|
||||
|
||||
// Find connector in the list
|
||||
useEffect(() => {
|
||||
const currentConnector = connectors.find(c => c.id === connectorId);
|
||||
|
||||
const currentConnector = connectors.find((c) => c.id === connectorId);
|
||||
|
||||
if (currentConnector) {
|
||||
setConnector(currentConnector);
|
||||
|
||||
|
||||
// Check if connector type is supported
|
||||
const apiKeyField = getApiKeyFieldName(currentConnector.connector_type);
|
||||
if (apiKeyField) {
|
||||
|
|
@ -115,7 +118,7 @@ export default function EditConnectorPage() {
|
|||
toast.error("This connector type is not supported for editing");
|
||||
router.push(`/dashboard/${searchSpaceId}/connectors`);
|
||||
}
|
||||
|
||||
|
||||
setIsLoading(false);
|
||||
} else if (!isLoading && connectors.length > 0) {
|
||||
// If connectors are loaded but this one isn't found
|
||||
|
|
@ -127,11 +130,11 @@ export default function EditConnectorPage() {
|
|||
// Handle form submission
|
||||
const onSubmit = async (values: ApiConnectorFormValues) => {
|
||||
if (!connector) return;
|
||||
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
const apiKeyField = getApiKeyFieldName(connector.connector_type);
|
||||
|
||||
|
||||
// Only update the API key if a new one was provided
|
||||
const updatedConfig = { ...connector.config };
|
||||
if (values.api_key) {
|
||||
|
|
@ -150,7 +153,9 @@ export default function EditConnectorPage() {
|
|||
router.push(`/dashboard/${searchSpaceId}/connectors`);
|
||||
} catch (error) {
|
||||
console.error("Error updating connector:", error);
|
||||
toast.error(error instanceof Error ? error.message : "Failed to update connector");
|
||||
toast.error(
|
||||
error instanceof Error ? error.message : "Failed to update connector",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
|
|
@ -186,24 +191,30 @@ export default function EditConnectorPage() {
|
|||
<Card className="border-2 border-border">
|
||||
<CardHeader>
|
||||
<CardTitle className="text-2xl font-bold">
|
||||
Edit {connector ? getConnectorTypeDisplay(connector.connector_type) : ""} Connector
|
||||
Edit{" "}
|
||||
{connector
|
||||
? getConnectorTypeDisplay(connector.connector_type)
|
||||
: ""}{" "}
|
||||
Connector
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Update your connector settings.
|
||||
</CardDescription>
|
||||
<CardDescription>Update your connector settings.</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<Alert className="mb-6 bg-muted">
|
||||
<Info className="h-4 w-4" />
|
||||
<AlertTitle>API Key Security</AlertTitle>
|
||||
<AlertDescription>
|
||||
Your API key is stored securely. For security reasons, we don't display your existing API key.
|
||||
If you don't update the API key field, your existing key will be preserved.
|
||||
Your API key is stored securely. For security reasons, we don't
|
||||
display your existing API key. If you don't update the API key
|
||||
field, your existing key will be preserved.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-6">
|
||||
<form
|
||||
onSubmit={form.handleSubmit(onSubmit)}
|
||||
className="space-y-6"
|
||||
>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="name"
|
||||
|
|
@ -227,10 +238,10 @@ export default function EditConnectorPage() {
|
|||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>
|
||||
{connector?.connector_type === "SLACK_CONNECTOR"
|
||||
? "Slack Bot Token"
|
||||
: connector?.connector_type === "NOTION_CONNECTOR"
|
||||
? "Notion Integration Token"
|
||||
{connector?.connector_type === "SLACK_CONNECTOR"
|
||||
? "Slack Bot Token"
|
||||
: connector?.connector_type === "NOTION_CONNECTOR"
|
||||
? "Notion Integration Token"
|
||||
: connector?.connector_type === "GITHUB_CONNECTOR"
|
||||
? "GitHub Personal Access Token (PAT)"
|
||||
: connector?.connector_type === "LINKUP_API"
|
||||
|
|
@ -238,27 +249,28 @@ export default function EditConnectorPage() {
|
|||
: "API Key"}
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
<Input
|
||||
type="password"
|
||||
placeholder={
|
||||
connector?.connector_type === "SLACK_CONNECTOR"
|
||||
? "Enter new Slack Bot Token (optional)"
|
||||
: connector?.connector_type === "NOTION_CONNECTOR"
|
||||
connector?.connector_type === "SLACK_CONNECTOR"
|
||||
? "Enter new Slack Bot Token (optional)"
|
||||
: connector?.connector_type === "NOTION_CONNECTOR"
|
||||
? "Enter new Notion Token (optional)"
|
||||
: connector?.connector_type === "GITHUB_CONNECTOR"
|
||||
: connector?.connector_type ===
|
||||
"GITHUB_CONNECTOR"
|
||||
? "Enter new GitHub PAT (optional)"
|
||||
: connector?.connector_type === "LINKUP_API"
|
||||
? "Enter new Linkup API Key (optional)"
|
||||
: "Enter new API key (optional)"
|
||||
}
|
||||
{...field}
|
||||
}
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
{connector?.connector_type === "SLACK_CONNECTOR"
|
||||
? "Enter a new Slack Bot Token or leave blank to keep your existing token."
|
||||
: connector?.connector_type === "NOTION_CONNECTOR"
|
||||
? "Enter a new Notion Integration Token or leave blank to keep your existing token."
|
||||
{connector?.connector_type === "SLACK_CONNECTOR"
|
||||
? "Enter a new Slack Bot Token or leave blank to keep your existing token."
|
||||
: connector?.connector_type === "NOTION_CONNECTOR"
|
||||
? "Enter a new Notion Integration Token or leave blank to keep your existing token."
|
||||
: connector?.connector_type === "GITHUB_CONNECTOR"
|
||||
? "Enter a new GitHub PAT or leave blank to keep your existing token."
|
||||
: connector?.connector_type === "LINKUP_API"
|
||||
|
|
@ -271,8 +283,8 @@ export default function EditConnectorPage() {
|
|||
/>
|
||||
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
type="submit"
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="w-full sm:w-auto"
|
||||
>
|
||||
|
|
@ -296,4 +308,4 @@ export default function EditConnectorPage() {
|
|||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,472 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useRouter, useParams } from "next/navigation";
|
||||
import { motion } from "framer-motion";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useForm } from "react-hook-form";
|
||||
import * as z from "zod";
|
||||
import { toast } from "sonner";
|
||||
import { ArrowLeft, Check, Info, Loader2 } from "lucide-react";
|
||||
|
||||
import { useSearchSourceConnectors } from "@/hooks/useSearchSourceConnectors";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import {
|
||||
Accordion,
|
||||
AccordionContent,
|
||||
AccordionItem,
|
||||
AccordionTrigger,
|
||||
} from "@/components/ui/accordion";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
|
||||
// Define the form schema with Zod
|
||||
const jiraConnectorFormSchema = z.object({
|
||||
name: z.string().min(3, {
|
||||
message: "Connector name must be at least 3 characters.",
|
||||
}),
|
||||
base_url: z
|
||||
.string()
|
||||
.url({
|
||||
message:
|
||||
"Please enter a valid Jira URL (e.g., https://yourcompany.atlassian.net)",
|
||||
})
|
||||
.refine(
|
||||
(url) => {
|
||||
return url.includes("atlassian.net") || url.includes("jira");
|
||||
},
|
||||
{
|
||||
message: "Please enter a valid Jira instance URL",
|
||||
},
|
||||
),
|
||||
email: z.string().email({
|
||||
message: "Please enter a valid email address.",
|
||||
}),
|
||||
api_token: z.string().min(10, {
|
||||
message: "Jira API Token is required and must be valid.",
|
||||
}),
|
||||
});
|
||||
|
||||
// Define the type for the form values
|
||||
type JiraConnectorFormValues = z.infer<typeof jiraConnectorFormSchema>;
|
||||
|
||||
export default function JiraConnectorPage() {
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchSpaceId = params.search_space_id as string;
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const { createConnector } = useSearchSourceConnectors();
|
||||
|
||||
// Initialize the form
|
||||
const form = useForm<JiraConnectorFormValues>({
|
||||
resolver: zodResolver(jiraConnectorFormSchema),
|
||||
defaultValues: {
|
||||
name: "Jira Connector",
|
||||
base_url: "",
|
||||
email: "",
|
||||
api_token: "",
|
||||
},
|
||||
});
|
||||
|
||||
// Handle form submission
|
||||
const onSubmit = async (values: JiraConnectorFormValues) => {
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
await createConnector({
|
||||
name: values.name,
|
||||
connector_type: "JIRA_CONNECTOR",
|
||||
config: {
|
||||
JIRA_BASE_URL: values.base_url,
|
||||
JIRA_EMAIL: values.email,
|
||||
JIRA_API_TOKEN: values.api_token,
|
||||
},
|
||||
is_indexable: true,
|
||||
last_indexed_at: null,
|
||||
});
|
||||
|
||||
toast.success("Jira connector created successfully!");
|
||||
|
||||
// Navigate back to connectors page
|
||||
router.push(`/dashboard/${searchSpaceId}/connectors`);
|
||||
} catch (error) {
|
||||
console.error("Error creating connector:", error);
|
||||
toast.error(
|
||||
error instanceof Error ? error.message : "Failed to create connector",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="container mx-auto py-8 max-w-3xl">
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="mb-6"
|
||||
onClick={() =>
|
||||
router.push(`/dashboard/${searchSpaceId}/connectors/add`)
|
||||
}
|
||||
>
|
||||
<ArrowLeft className="mr-2 h-4 w-4" />
|
||||
Back to Connectors
|
||||
</Button>
|
||||
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ duration: 0.5 }}
|
||||
>
|
||||
<Tabs defaultValue="connect" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-2 mb-6">
|
||||
<TabsTrigger value="connect">Connect</TabsTrigger>
|
||||
<TabsTrigger value="documentation">Documentation</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="connect">
|
||||
<Card className="border-2 border-border">
|
||||
<CardHeader>
|
||||
<CardTitle className="text-2xl font-bold">
|
||||
Connect Jira Instance
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Integrate with Jira to search and retrieve information from
|
||||
your issues, tickets, and comments. This connector can index
|
||||
your Jira content for search.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<Alert className="mb-6 bg-muted">
|
||||
<Info className="h-4 w-4" />
|
||||
<AlertTitle>Jira Personal Access Token Required</AlertTitle>
|
||||
<AlertDescription>
|
||||
You'll need a Jira Personal Access Token to use this
|
||||
connector. You can create one from{" "}
|
||||
<a
|
||||
href="https://id.atlassian.com/manage-profile/security/api-tokens"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="font-medium underline underline-offset-4"
|
||||
>
|
||||
Atlassian Account Settings
|
||||
</a>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<Form {...form}>
|
||||
<form
|
||||
onSubmit={form.handleSubmit(onSubmit)}
|
||||
className="space-y-6"
|
||||
>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="name"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Connector Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input placeholder="My Jira Connector" {...field} />
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
A friendly name to identify this connector.
|
||||
</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="base_url"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Jira Instance URL</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
placeholder="https://yourcompany.atlassian.net"
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
Your Jira instance URL. For Atlassian Cloud, this is
|
||||
typically https://yourcompany.atlassian.net
|
||||
</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="email"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Email Address</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="email"
|
||||
placeholder="your.email@company.com"
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
Your Atlassian account email address.
|
||||
</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="api_token"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>API Token</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Your Jira API Token"
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
Your Jira API Token will be encrypted and stored securely.
|
||||
</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="w-full sm:w-auto"
|
||||
>
|
||||
{isSubmitting ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Connecting...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Check className="mr-2 h-4 w-4" />
|
||||
Connect Jira
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
</Form>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-col items-start border-t bg-muted/50 px-6 py-4">
|
||||
<h4 className="text-sm font-medium">
|
||||
What you get with Jira integration:
|
||||
</h4>
|
||||
<ul className="mt-2 list-disc pl-5 text-sm text-muted-foreground">
|
||||
<li>Search through all your Jira issues and tickets</li>
|
||||
<li>
|
||||
Access issue descriptions, comments, and full discussion
|
||||
threads
|
||||
</li>
|
||||
<li>
|
||||
Connect your team's project management directly to your
|
||||
search space
|
||||
</li>
|
||||
<li>
|
||||
Keep your search results up-to-date with latest Jira content
|
||||
</li>
|
||||
<li>
|
||||
Index your Jira issues for enhanced search capabilities
|
||||
</li>
|
||||
<li>
|
||||
Search by issue keys, status, priority, and assignee
|
||||
information
|
||||
</li>
|
||||
</ul>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="documentation">
|
||||
<Card className="border-2 border-border">
|
||||
<CardHeader>
|
||||
<CardTitle className="text-2xl font-bold">
|
||||
Jira Connector Documentation
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Learn how to set up and use the Jira connector to index your
|
||||
project management data.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-6">
|
||||
<div>
|
||||
<h3 className="text-xl font-semibold mb-2">How it works</h3>
|
||||
<p className="text-muted-foreground">
|
||||
The Jira connector uses the Jira REST API with Basic Authentication
|
||||
to fetch all issues and comments that your account has
|
||||
access to within your Jira instance.
|
||||
</p>
|
||||
<ul className="mt-2 list-disc pl-5 text-muted-foreground">
|
||||
<li>
|
||||
For follow up indexing runs, the connector retrieves
|
||||
issues and comments that have been updated since the last
|
||||
indexing attempt.
|
||||
</li>
|
||||
<li>
|
||||
Indexing is configured to run periodically, so updates
|
||||
should appear in your search results within minutes.
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<Accordion type="single" collapsible className="w-full">
|
||||
<AccordionItem value="authorization">
|
||||
<AccordionTrigger className="text-lg font-medium">
|
||||
Authorization
|
||||
</AccordionTrigger>
|
||||
<AccordionContent className="space-y-4">
|
||||
<Alert className="bg-muted">
|
||||
<Info className="h-4 w-4" />
|
||||
<AlertTitle>Read-Only Access is Sufficient</AlertTitle>
|
||||
<AlertDescription>
|
||||
You only need read access for this connector to work.
|
||||
The API Token will only be used to read your Jira data.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<h4 className="font-medium mb-2">
|
||||
Step 1: Create an API Token
|
||||
</h4>
|
||||
<ol className="list-decimal pl-5 space-y-3">
|
||||
<li>Log in to your Atlassian account</li>
|
||||
<li>
|
||||
Navigate to{" "}
|
||||
<a
|
||||
href="https://id.atlassian.com/manage-profile/security/api-tokens"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="font-medium underline underline-offset-4"
|
||||
>
|
||||
https://id.atlassian.com/manage-profile/security/api-tokens
|
||||
</a>
|
||||
</li>
|
||||
<li>
|
||||
Click <strong>Create API token</strong>
|
||||
</li>
|
||||
<li>
|
||||
Enter a label for your token (like "SurfSense
|
||||
Connector")
|
||||
</li>
|
||||
<li>
|
||||
Click <strong>Create</strong>
|
||||
</li>
|
||||
<li>
|
||||
Copy the generated token as it will only be shown
|
||||
once
|
||||
</li>
|
||||
</ol>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h4 className="font-medium mb-2">
|
||||
Step 2: Grant necessary access
|
||||
</h4>
|
||||
<p className="text-muted-foreground mb-3">
|
||||
The API Token will have access to all projects and
|
||||
issues that your user account can see. Make sure your
|
||||
account has appropriate permissions for the projects
|
||||
you want to index.
|
||||
</p>
|
||||
<Alert className="bg-muted">
|
||||
<Info className="h-4 w-4" />
|
||||
<AlertTitle>Data Privacy</AlertTitle>
|
||||
<AlertDescription>
|
||||
Only issues, comments, and basic metadata will be
|
||||
indexed. Jira attachments and linked files are not
|
||||
indexed by this connector.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</div>
|
||||
</div>
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
|
||||
<AccordionItem value="indexing">
|
||||
<AccordionTrigger className="text-lg font-medium">
|
||||
Indexing
|
||||
</AccordionTrigger>
|
||||
<AccordionContent className="space-y-4">
|
||||
<ol className="list-decimal pl-5 space-y-3">
|
||||
<li>
|
||||
Navigate to the Connector Dashboard and select the{" "}
|
||||
<strong>Jira</strong> Connector.
|
||||
</li>
|
||||
<li>
|
||||
Enter your <strong>Jira Instance URL</strong> (e.g.,
|
||||
https://yourcompany.atlassian.net)
|
||||
</li>
|
||||
<li>
|
||||
Place your <strong>Personal Access Token</strong> in
|
||||
the form field.
|
||||
</li>
|
||||
<li>
|
||||
Click <strong>Connect</strong> to establish the
|
||||
connection.
|
||||
</li>
|
||||
<li>
|
||||
Once connected, your Jira issues will be indexed
|
||||
automatically.
|
||||
</li>
|
||||
</ol>
|
||||
|
||||
<Alert className="bg-muted">
|
||||
<Info className="h-4 w-4" />
|
||||
<AlertTitle>What Gets Indexed</AlertTitle>
|
||||
<AlertDescription>
|
||||
<p className="mb-2">
|
||||
The Jira connector indexes the following data:
|
||||
</p>
|
||||
<ul className="list-disc pl-5">
|
||||
<li>Issue keys and summaries (e.g., PROJ-123)</li>
|
||||
<li>Issue descriptions</li>
|
||||
<li>Issue comments and discussion threads</li>
|
||||
<li>
|
||||
Issue status, priority, and type information
|
||||
</li>
|
||||
<li>Assignee and reporter information</li>
|
||||
<li>Project information</li>
|
||||
</ul>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
</Accordion>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue