mirror of
https://github.com/rushil-thareja/dp-fusion-lib.git
synced 2026-04-24 12:06:23 +02:00
Initial release v0.1.0

- Token-level differential privacy for LLMs
- Integration with Document Privacy API
- Comprehensive test suite and documentation
- Examples and Jupyter notebook included
This commit is contained in: commit d012046d85
31 changed files with 4480 additions and 0 deletions
47 .github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
@@ -0,0 +1,47 @@
---
name: Bug Report
about: Report a bug or unexpected behavior
title: "[BUG] "
labels: bug
assignees: ''
---

## Description

A clear description of the bug.

## Environment

- Python version:
- PyTorch version:
- Transformers version:
- dp-fusion-lib version:
- OS:
- GPU (if applicable):

## Steps to Reproduce

```python
# Minimal code to reproduce the issue
from dp_fusion_lib import DPFusion

# ...
```

## Expected Behavior

What you expected to happen.

## Actual Behavior

What actually happened.

## Error Traceback

```
Paste full traceback here
```

## Additional Context

Any other context about the problem.
27 .github/ISSUE_TEMPLATE/feature_request.md vendored Normal file
@@ -0,0 +1,27 @@
---
name: Feature Request
about: Suggest a new feature or enhancement
title: "[FEATURE] "
labels: enhancement
assignees: ''
---

## Description

A clear description of the feature you'd like.

## Motivation

Why is this feature needed? What problem does it solve?

## Proposed Solution

If you have ideas on how to implement this, describe them here.

## Alternatives Considered

Any alternative solutions or features you've considered.

## Additional Context

Any other context, examples, or references.
28 .github/PULL_REQUEST_TEMPLATE.md vendored Normal file
@@ -0,0 +1,28 @@
## Description

Brief description of what this PR does.

## Type of Change

- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Refactoring (no functional changes)

## Checklist

- [ ] I have read the [CONTRIBUTING](CONTRIBUTING.md) guidelines
- [ ] My code follows the project's code style
- [ ] I have added tests that prove my fix/feature works
- [ ] All new and existing tests pass
- [ ] I have updated documentation as needed
- [ ] My changes don't introduce new warnings

## Testing

Describe how you tested your changes.

## Related Issues

Fixes #(issue number)
56 .github/workflows/publish.yml vendored Normal file
@@ -0,0 +1,56 @@
name: Publish to PyPI

on:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install build tools
        run: |
          python -m pip install --upgrade pip
          pip install build twine

      - name: Build package
        run: |
          python -m build

      - name: Check package
        run: |
          twine check dist/*

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist/

  publish:
    needs: build
    runs-on: ubuntu-latest
    environment: pypi
    permissions:
      id-token: write  # For trusted publishing

    steps:
      - name: Download artifacts
        uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist/

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        # Uses trusted publishing (no API token needed if configured in PyPI)
        # Alternatively, use:
        # with:
        #   password: ${{ secrets.PYPI_API_TOKEN }}
63 .github/workflows/tests.yml vendored Normal file
@@ -0,0 +1,63 @@
name: Tests

on:
  push:
    branches: [main, master]
  pull_request:
    branches: [main, master]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Run linter
        run: |
          ruff check src/

      - name: Run tests
        run: |
          pytest tests/ -v --cov=dp_fusion_lib --cov-report=term-missing

      - name: Check import works
        run: |
          python -c "from dp_fusion_lib import DPFusion, Tagger, compute_epsilon_single_group; print('Import successful')"

  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install linting tools
        run: |
          python -m pip install --upgrade pip
          pip install ruff black

      - name: Check formatting with black
        run: |
          black --check src/ tests/ examples/

      - name: Lint with ruff
        run: |
          ruff check src/ tests/ examples/
82 .gitignore vendored Normal file
@@ -0,0 +1,82 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
ENV/
.venv/
env/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~
.project
.pydevproject
.settings/

# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
coverage.xml
*.cover
*.py,cover

# Documentation
site/
docs/_build/

# OS
.DS_Store
Thumbs.db
*.bak

# Jupyter
.ipynb_checkpoints/
*.ipynb_checkpoints

# Project specific
*.log
models/
*.pkl
*.pt
*.pth
*.bin
*.safetensors
*.ckpt

# Environment files (may contain secrets)
.env
.env.local
*.env

# Cache
.cache/

# Build artifacts
*.whl
MANIFEST
37 CHANGELOG.md Normal file
@@ -0,0 +1,37 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [0.1.0] - 2025-01-01

### Added

- Initial release of DP-Fusion-Lib
- `DPFusion` class for differentially private text generation
  - Message-based context building with `add_message()`
  - Direct context generation with `generate()`
  - Token-level generation with `generate_from_tokens()`
- `Tagger` class for automatic private phrase extraction
  - Integration with Document Privacy API
  - Support for multiple document types (HEALTH, FINANCE, LEGAL)
- Privacy accounting functions
  - `compute_epsilon_single_group()` for single-group privacy guarantees
  - `compute_dp_epsilon()` for multi-group scenarios
- Utility functions for advanced usage
  - `compute_renyi_divergence_clipped_symmetric()` for divergence computation
  - `find_lambda()` for binary search of the mixing parameter
  - `replace_sequences_with_placeholder_fast()` for token-level redaction
- Support for HuggingFace transformers models
- Incremental decoding with KV-cache optimization
- Comprehensive documentation and examples

### Dependencies

- PyTorch >= 2.0.0
- Transformers >= 4.25.0
- Requests >= 2.25.0
42 CITATION.cff Normal file
@@ -0,0 +1,42 @@
cff-version: 1.2.0
title: "DP-Fusion-Lib: Token-Level Differentially Private Inference for Large Language Models"
message: "If you use this software, please cite it as below."
type: software
authors:
  - family-names: "Thareja"
    given-names: "Rushil"
    email: "rushil.thareja@mbzuai.ac.ae"
    affiliation: "MBZUAI"
repository-code: "https://github.com/rushil-thareja/dp-fusion-lib"
url: "https://github.com/rushil-thareja/dp-fusion-lib"
license: "LicenseRef-NonCommercial"
version: "0.1.0"
date-released: "2025-01-01"
keywords:
  - differential-privacy
  - text-generation
  - large-language-models
  - privacy
  - machine-learning
  - renyi-divergence
  - nlp
preferred-citation:
  type: article
  authors:
    - family-names: "Thareja"
      given-names: "Rushil"
    - family-names: "Lukas"
      given-names: "Nils"
    - family-names: "Baba"
      given-names: "Sarim"
    - family-names: "Abbasi"
      given-names: "Ahmed"
    - family-names: "Asokan"
      given-names: "N."
  title: "DP-Fusion: Token-Level Differentially Private Inference for Large Language Models"
  year: 2025
  url: "https://arxiv.org/abs/2507.04531"
  identifiers:
    - type: other
      value: "arXiv:2507.04531"
      description: "arXiv preprint"
0 CLAUDE.md Normal file
112 CONTRIBUTING.md Normal file
@@ -0,0 +1,112 @@
# Contributing to DP-Fusion-Lib

Thank you for your interest in contributing to DP-Fusion-Lib! This document provides guidelines for contributing.

## Code of Conduct

Please be respectful and constructive in all interactions.

## Getting Started

### Development Setup

1. Clone the repository:
   ```bash
   git clone https://github.com/rushil-thareja/dp-fusion-lib.git
   cd dp-fusion-lib
   ```

2. Create a virtual environment:
   ```bash
   python -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   ```

3. Install in development mode:
   ```bash
   pip install -e ".[dev]"
   ```

4. Install pre-commit hooks (optional):
   ```bash
   pre-commit install
   ```

### Running Tests

```bash
pytest tests/ -v
```

With coverage:
```bash
pytest tests/ -v --cov=dp_fusion_lib --cov-report=term-missing
```

### Code Style

We use:
- **black** for code formatting
- **ruff** for linting

Format code:
```bash
black src/ tests/ examples/
```

Check linting:
```bash
ruff check src/ tests/ examples/
```

## Making Changes

### Branch Naming

- `feature/description` - New features
- `fix/description` - Bug fixes
- `docs/description` - Documentation changes
- `refactor/description` - Code refactoring

### Commit Messages

Use clear, descriptive commit messages:
- `feat: Add support for batch generation`
- `fix: Handle edge case in lambda search`
- `docs: Update installation instructions`
- `test: Add tests for epsilon computation`

### Pull Requests

1. Create a feature branch from `main`
2. Make your changes
3. Add tests for new functionality
4. Ensure all tests pass
5. Update documentation if needed
6. Submit a pull request

## Reporting Issues

### Bug Reports

Please include:
- Python version
- Package versions (torch, transformers, dp-fusion-lib)
- Minimal code to reproduce the issue
- Expected vs actual behavior
- Full error traceback

### Feature Requests

Please include:
- Clear description of the feature
- Use case / motivation
- Possible implementation approach (optional)

## Questions

For questions about using the library, please open a GitHub issue with the "question" label.

## License

By contributing, you agree that your contributions will be licensed under the same license as the project (see LICENSE file).
49 LICENSE Normal file
@@ -0,0 +1,49 @@
DP-Fusion-Lib Non-Commercial License

Copyright (c) 2025 Rushil Thareja

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to use,
copy, modify, and distribute the Software for non-commercial purposes only,
subject to the following conditions:

1. NON-COMMERCIAL USE ONLY

The Software may only be used for:
- Academic research and publications
- Educational purposes and coursework
- Personal projects and experimentation
- Non-profit organizations

2. COMMERCIAL USE REQUIRES LICENSE

Any commercial use, including but not limited to:
- Use in commercial products or services
- Use by for-profit companies or entities
- Integration into proprietary software
- Offering the Software as a service (SaaS)

requires a separate commercial license. Contact rushil.thareja@mbzuai.ac.ae
for commercial licensing inquiries.

3. ATTRIBUTION

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

4. CITATION

Academic use must cite the associated paper:

Thareja et al. "DP-Fusion: Token-Level Differentially Private Inference
for Large Language Models" (arXiv:2507.04531)

5. NO WARRANTY

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
46 LICENSE-COMMERCIAL.md Normal file
@@ -0,0 +1,46 @@
# Commercial Licensing

DP-Fusion-Lib is available under a dual license model.

## Non-Commercial Use (Free)

The Software is free for:
- Academic research and publications
- Educational purposes
- Personal projects
- Non-profit organizations

See [LICENSE](LICENSE) for full terms.

## Commercial Use (Paid License)

For commercial use in products, services, or by for-profit entities,
a separate commercial license is required.

### What Requires a Commercial License?

- Using DP-Fusion-Lib in commercial products or services
- Integrating DP-Fusion-Lib into proprietary software
- Offering DP-Fusion-Lib functionality as a service (SaaS)
- Use by for-profit companies or entities
- Any use that generates revenue directly or indirectly

### Commercial License Benefits

- Full commercial usage rights
- Priority support and bug fixes
- Custom integration assistance
- License terms tailored to your needs

### Contact

For commercial licensing inquiries, please contact:

**Email**: rushil.thareja@mbzuai.ac.ae

Please include:
- Your company name
- Intended use case
- Expected scale of deployment

We typically respond within 2-3 business days.
221 README.md Normal file
@@ -0,0 +1,221 @@
# DP-Fusion-Lib

[](https://pypi.org/project/dp-fusion-lib/)
[](https://www.python.org/downloads/)
[](LICENSE)
[](https://arxiv.org/abs/2507.04531)
[](https://www.documentprivacy.com/)
[](https://console.documentprivacy.com/)

**DP-Fusion-Lib** enables Large Language Model inference with mathematically provable differential privacy guarantees. Based on our research paper [*"DP-Fusion: Token-Level Differentially Private Inference for Large Language Models"*](https://arxiv.org/abs/2507.04531), this library provides formal (ε, δ)-DP protection for sensitive text generation workflows.

Differential privacy is the core foundation, but the library addresses the **full spectrum of text and document privacy**. Its **PII detection and rewriting tools** can be used **with or without DP**, offering practical privacy protection by default and **formal guarantees** when DP is enabled.

**[Try the Live Demo](https://www.documentprivacy.com)**

**[Run the example Colab notebook](https://colab.research.google.com/drive/1hzoUAXF_jsFU9E3D6U5ceZdYZ3wfXPPd?usp=sharing)**

---
## Overview

Traditional privacy approaches for LLMs rely on heuristic redaction or post-hoc filtering. **DP-Fusion-Lib** goes further by providing a complete privacy framework with three levels of protection:

| Level | Approach | Protection |
|-------|----------|------------|
| 1 | **Redaction** | Automatic PII detection and replacement via Constitutional Tagger API |
| 2 | **Paraphrasing** | Context rewriting to obscure stylistic and contextual signatures |
| 3 | **Differential Privacy** | Formal (ε, δ)-DP guarantees via controlled distribution fusion |

The library achieves Level 3 protection by fusing token probability distributions from private and redacted contexts, bounding the Rényi divergence at each generation step to provide provable privacy guarantees.
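To build intuition for this fusion step, here is a minimal illustrative sketch (not the library's internal code): two next-token distributions are mixed with a parameter λ, and the Rényi divergence of the fused distribution from the public one is measured. The helper names `renyi_divergence` and `fuse` are hypothetical stand-ins.

```python
import numpy as np

def renyi_divergence(p, q, alpha=2.0):
    """D_alpha(p || q) = log(sum_i p_i^alpha / q_i^(alpha-1)) / (alpha - 1)."""
    return np.log(np.sum(p**alpha / q**(alpha - 1))) / (alpha - 1)

def fuse(p_private, p_public, lam):
    """Mix the two next-token distributions; lam=0 means fully public."""
    return lam * p_private + (1.0 - lam) * p_public

# Toy next-token distributions over a 4-token vocabulary
p_private = np.array([0.70, 0.10, 0.10, 0.10])
p_public  = np.array([0.25, 0.25, 0.25, 0.25])

for lam in (0.0, 0.5, 1.0):
    fused = fuse(p_private, p_public, lam)
    d = renyi_divergence(fused, p_public, alpha=2.0)
    print(f"lambda={lam:.1f}  D_2(fused || public) = {d:.4f}")
```

At λ=0 the fused distribution equals the public one and the divergence is zero; as λ grows, more private signal leaks and the divergence rises, which is exactly the quantity the per-token budget bounds.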
---

## Technical Approach

DP-Fusion operates by maintaining two parallel contexts during generation:

- **Private Context**: The original document containing sensitive information
- **Public Context**: A redacted version with sensitive phrases replaced by placeholders

At each token generation step, the algorithm:

1. Computes next-token probability distributions for both contexts
2. Performs a binary search to find the optimal mixing parameter λ
3. Ensures the fused distribution satisfies the Rényi divergence bound
4. Samples from the privacy-preserving mixed distribution

This approach guarantees that the output distribution is statistically similar regardless of the specific private information present, providing formal differential privacy.
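The λ search in step 2 can be sketched as follows. This is an illustrative stand-in for the library's `find_lambda()`, not its actual implementation; it relies on the fact that the divergence grows monotonically in λ, so the largest budget-respecting λ can be found by bisection.

```python
import numpy as np

def renyi_divergence(p, q, alpha=2.0):
    return np.log(np.sum(p**alpha / q**(alpha - 1))) / (alpha - 1)

def find_lambda(p_private, p_public, beta, alpha=2.0, iters=50):
    """Binary-search the largest lambda whose fused distribution
    stays within the per-token Renyi divergence budget beta."""
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        fused = mid * p_private + (1 - mid) * p_public
        if renyi_divergence(fused, p_public, alpha) <= beta:
            lo = mid   # budget respected: mix in more private mass
        else:
            hi = mid   # budget exceeded: back off toward the public dist.
    return lo

p_private = np.array([0.70, 0.10, 0.10, 0.10])
p_public  = np.array([0.25, 0.25, 0.25, 0.25])
lam = find_lambda(p_private, p_public, beta=0.01)
fused = lam * p_private + (1 - lam) * p_public
print(f"lambda = {lam:.4f}, divergence = {renyi_divergence(fused, p_public):.6f}")
```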
---

## Installation

```bash
pip install dp-fusion-lib
```

**Hardware Requirements**: This library requires PyTorch. For production deployments, NVIDIA GPU acceleration is recommended. The `Qwen/Qwen2.5-7B-Instruct` model provides an effective balance between generation quality and privacy utility.

```bash
# For CUDA 12.1 environments
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install dp-fusion-lib
```
---

## Quick Start

For a complete working example, see the [basic usage script](examples/basic_usage.py) or run the interactive [Jupyter notebook](examples/basic_usage.ipynb).

### Step 1: Initialize Components

The Tagger API provides automated sensitive phrase detection using Constitutional AI. API keys are available at [console.documentprivacy.com](https://console.documentprivacy.com).

```python
from dp_fusion_lib import DPFusion, Tagger, compute_epsilon_single_group
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Initialize Tagger
tagger = Tagger(api_key="your_api_key")
tagger.set_constitution("LEGAL")  # Options: LEGAL, HEALTH, FINANCE
```
### Step 2: Build Context

The library applies differential privacy only to segments marked as private, allowing precise control over which content receives protection.

```python
dpf = DPFusion(model=model, tokenizer=tokenizer, tagger=tagger)

# Sample document with sensitive information
document = """The applicant was born in 1973 and currently resides in
Les Salles-sur-Verdon, France. In the early 1990s, a new criminal
phenomenon emerged in Denmark known as 'tax asset stripping cases'."""

# Build context with privacy annotations
dpf.add_message("system", "Your job is to re-write documents for privacy. You will be provided a document; output a paraphrase that preserves privacy and doesn't leak personally identifiable information. Just output the paraphrase only, nothing else.", is_private=False)
dpf.add_message("user", document, is_private=True)
dpf.add_message("user", "I just passed the document to you, you can paraphrase it for privacy.", is_private=False)
dpf.add_message("assistant", "Here is the paraphrased document:", is_private=False)
```
### Step 3: Execute Private Generation

```python
# Run tagger to identify and redact sensitive phrases
dpf.run_tagger()

# Generate with differential privacy
output = dpf.generate(
    alpha=2.0,           # Rényi order
    beta=0.01,           # Per-token privacy budget
    max_new_tokens=100
)

print(output['text'])
```
### Step 4: Compute Privacy Guarantee

The library provides two epsilon values for comprehensive privacy accounting:

```python
alpha = 2.0
beta = 0.01
delta = 1e-5

eps_result = compute_epsilon_single_group(
    divergences=output['divergences']['PRIVATE'],
    alpha=alpha,
    delta=delta,
    beta=beta
)

print(f"(ε, δ)-DP Guarantee (α={alpha}, δ={delta}, T={eps_result['T']} tokens):")
print(f"  Empirical ε = {eps_result['empirical']:.4f} (from actual divergences)")
print(f"  Theoretical ε = {eps_result['theoretical']:.4f} (worst-case, β={beta} per step)")
```

| Epsilon Type | Description | Use Case |
|--------------|-------------|----------|
| **Empirical ε** | Computed from actual per-step divergences observed during generation | Tighter bound reflecting real privacy cost |
| **Theoretical ε** | Worst-case bound assuming maximum divergence (α·β) at every step | Conservative upper bound for compliance reporting |
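For intuition, a textbook Rényi-DP accounting of such numbers sums the per-step divergences at order α and converts the total to an (ε, δ) bound via ε = ε_RDP + log(1/δ)/(α−1) (Mironov, 2017). The sketch below follows that recipe under the assumptions stated in the comments; the exact formula inside `compute_epsilon_single_group` may differ, and the observed divergence values here are placeholders.

```python
import math

def rdp_to_dp(total_rdp, alpha, delta):
    """Convert an order-alpha RDP total to an (epsilon, delta)-DP bound:
    eps = eps_RDP + log(1/delta) / (alpha - 1)."""
    return total_rdp + math.log(1.0 / delta) / (alpha - 1)

alpha, beta, delta = 2.0, 0.01, 1e-5
T = 100                      # number of generated tokens

# Empirical: sum the divergences actually observed during generation
observed = [0.004] * T       # placeholder per-step divergences for illustration
eps_empirical = rdp_to_dp(sum(observed), alpha, delta)

# Theoretical: assume the maximum divergence (alpha * beta) at every step,
# as described in the table above
eps_theoretical = rdp_to_dp(T * alpha * beta, alpha, delta)

print(f"empirical   eps = {eps_empirical:.4f}")
print(f"theoretical eps = {eps_theoretical:.4f}")
```

Because the observed divergences rarely hit the per-step cap, the empirical ε is the tighter of the two numbers, matching the table's distinction.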
---

## Privacy Parameters

| Parameter | Symbol | Description | Trade-off |
|-----------|--------|-------------|-----------|
| Beta | β | Maximum Rényi divergence per token | Lower β → Stronger privacy, reduced utility |
| Alpha | α | Rényi divergence order (must be > 1) | Higher α → Tighter bounds, different privacy regime |
| Delta | δ | Probability of privacy failure | Lower δ → Stronger guarantee, higher ε |
| Epsilon | ε | Total privacy budget (computed) | Lower ε → Stronger privacy guarantee |

**Recommendation**: For most applications, start with `alpha=2.0` and `beta=0.01`. Adjust based on your privacy-utility requirements.
---

## Data Privacy

While `dp-fusion-lib` executes entirely on your infrastructure, the Tagger API requires an external call for sensitive phrase detection. If you have strict data residency or compliance requirements, please reach out and I will help you find a suitable setup.

Contact [rushil.thareja@mbzuai.ac.ae](mailto:rushil.thareja@mbzuai.ac.ae).
## Citation

If you use this library in academic work, please cite:

```bibtex
@misc{thareja2025dpfusion,
  title={DP-Fusion: Token-Level Differentially Private Inference for Large Language Models},
  author={Rushil Thareja and Preslav Nakov and Praneeth Vepakomma and Nils Lukas},
  year={2025},
  eprint={2507.04531},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2507.04531}
}
```

---

## License

DP-Fusion-Lib is available under a dual license:

| Use Case | License | Cost |
|----------|---------|------|
| Academic research | Non-Commercial License | Free |
| Educational use | Non-Commercial License | Free |
| Commercial products | Commercial License | Contact for pricing |

For commercial inquiries, contact [rushil.thareja@mbzuai.ac.ae](mailto:rushil.thareja@mbzuai.ac.ae).

---

## Support

- **Documentation**: [GitHub Repository](https://github.com/rushil-thareja/dp-fusion-lib)
- **Issues**: [GitHub Issues](https://github.com/rushil-thareja/dp-fusion-lib/issues)
- **Any queries? Just email me**: [rushil.thareja@mbzuai.ac.ae](mailto:rushil.thareja@mbzuai.ac.ae)
394 environment.yml Normal file
@@ -0,0 +1,394 @@
name: myTorch
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2025.2.25=h06a4308_0
  - expat=2.7.1=h6a678d5_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - libxcb=1.17.0=h9b100fa_0
  - ncurses=6.5=h7934f7d_0
  - openssl=3.0.17=h5eee18b_0
  - pip=25.1=pyhc872135_2
  - pthread-stubs=0.3=h0ce48e5_1
  - python=3.11.13=h1a3bd86_0
  - readline=8.2=h5eee18b_0
  - setuptools=78.1.1=py311h06a4308_0
  - sqlite=3.50.2=hb25bd0a_1
  - tk=8.6.14=h993c535_1
  - wheel=0.45.1=py311h06a4308_0
  - xorg-libx11=1.8.12=h9b100fa_1
  - xorg-libxau=1.0.12=h9b100fa_0
  - xorg-libxdmcp=1.1.5=h9b100fa_0
  - xorg-xorgproto=2024.1=h5eee18b_1
  - xz=5.6.4=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - pip:
    - absl-py==2.3.1
    - accelerate==1.9.0
    - agentlightning==0.2.1
    - agentops==0.4.21
    - aiofiles==24.1.0
    - aiohappyeyeballs==2.6.1
    - aiohttp==3.12.15
    - aiohttp-cors==0.8.1
    - aiosignal==1.4.0
    - alembic==1.16.4
    - altair==5.5.0
    - annotated-doc==0.0.3
    - annotated-types==0.7.0
    - antlr4-python3-runtime==4.9.3
    - anyio==4.10.0
    - apscheduler==3.11.1
    - asgiref==3.10.0
    - astor==0.8.1
    - attrs==25.3.0
    - authlib==1.6.5
    - azure-core==1.36.0
    - azure-identity==1.25.1
    - azure-storage-blob==12.27.1
    - backoff==2.2.1
    - bidict==0.23.1
    - bitsandbytes==0.47.0
    - blake3==1.0.6
    - blinker==1.9.0
    - boto3==1.36.0
    - botocore==1.36.26
    - brotli==1.1.0
    - cachetools==5.5.2
    - cbor2==5.7.0
    - certifi==2025.8.3
    - cffi==2.0.0
    - charset-normalizer==3.4.3
    - click==8.2.1
    - cloudpickle==3.1.1
    - codetiming==1.4.0
    - colorful==0.5.8
    - compressed-tensors==0.11.0
    - contourpy==1.3.3
    - croniter==6.0.0
    - cryptography==46.0.3
    - cupy-cuda12x==13.6.0
    - cut-cross-entropy==25.1.1
    - cycler==0.12.1
    - cyclopts==4.2.1
    - dataclasses-json==0.6.7
    - datasets==3.6.0
    - dateparser==1.2.2
    - datetime==5.5
    - depyf==0.19.0
    - diffusers==0.35.1
    - dill==0.3.8
    - diskcache==5.6.3
    - distlib==0.4.0
    - distro==1.9.0
    - django==5.2.7
    - dnspython==2.8.0
    - docstring-parser==0.17.0
    - docutils==0.22.2
    - einops==0.8.1
    - email-validator==2.3.0
    - exceptiongroup==1.3.0
    - fastapi==0.121.1
    - fastapi-cli==0.0.13
    - fastapi-cloud-cli==0.2.1
    - fastapi-sso==0.16.0
    - fastmcp==2.11.1
    - fastrlock==0.8.3
    - fastuuid==0.14.0
    - ffmpy==0.6.1
    - filelock==3.19.1
    - flask==3.1.2
    - fonttools==4.59.2
    - frozendict==2.4.6
    - frozenlist==1.7.0
    - fsspec==2025.3.0
    - geojson==2.5.0
    - gguf==0.17.1
    - gitdb==4.0.12
    - gitpython==3.1.45
    - google-api-core==2.25.1
    - google-api-python-client==2.179.0
    - google-auth==2.40.3
    - google-auth-httplib2==0.2.0
    - google-auth-oauthlib==1.2.2
    - googleapis-common-protos==1.70.0
    - gradio==5.40.0
    - gradio-client==1.11.0
    - granian==2.5.0
    - graphviz==0.21
    - greenlet==3.2.3
    - groovy==0.1.2
    - groq==0.31.1
    - grpcio==1.76.0
    - gunicorn==23.0.0
    - h11==0.16.0
    - hf-transfer==0.1.9
    - hf-xet==1.1.10
    - httpcore==1.0.9
    - httpdbg==2.1.3
    - httplib2==0.22.0
    - httptools==0.6.4
    - httpx==0.28.1
    - httpx-sse==0.4.1
    - huggingface-hub==0.35.1
    - hydra-core==1.3.2
    - idna==3.10
    - importlib-metadata==6.11.0
    - interegular==0.3.3
    - isodate==0.7.2
    - itsdangerous==2.2.0
    - jinja2==3.1.6
    - jiter==0.10.0
    - jmespath==1.0.1
    - joblib==1.5.2
    - jsonpatch==1.33
    - jsonpointer==3.0.0
    - jsonschema==4.25.0
    - jsonschema-path==0.3.4
    - jsonschema-specifications==2025.4.1
    - kiwisolver==1.4.9
    - langchain==0.3.27
    - langchain-community==0.3.31
    - langchain-core==0.3.80
    - langchain-groq==0.3.8
    - langchain-ollama==0.3.6
    - langchain-openai==0.3.28
    - langchain-text-splitters==0.3.9
    - langgraph==0.6.3
    - langgraph-checkpoint==2.1.1
    - langgraph-prebuilt==0.6.3
    - langgraph-sdk==0.2.0
    - langsmith==0.4.11
    - lark==1.2.2
    - lazy-object-proxy==1.12.0
    - litellm==1.79.3
    - litellm-enterprise==0.1.20
    - litellm-proxy-extras==0.4.3
    - llguidance==0.7.30
    - llvmlite==0.44.0
    - lm-format-enforcer==0.11.3
    - mako==1.3.10
    - markdown==3.10
    - markdown-it-py==3.0.0
    - markupsafe==3.0.2
    - marshmallow==3.26.1
    - matplotlib==3.10.6
    - mcp==1.12.3
    - mdurl==0.1.2
    - mistral-common==1.8.5
    - more-itertools==10.8.0
    - mpmath==1.3.0
    - msal==1.34.0
    - msal-extensions==1.3.1
    - msgpack==1.1.1
    - msgspec==0.19.0
    - multidict==6.6.4
    - multiprocess==0.70.16
    - mypy-extensions==1.1.0
    - nano==1.0.0
    - narwhals==2.0.1
    - networkx==3.5
    - ninja==1.13.0
    - nltk==3.9.2
    - numba==0.61.2
    - numpy==1.26.4
    - nvidia-cublas-cu12==12.8.4.1
    - nvidia-cuda-cupti-cu12==12.8.90
    - nvidia-cuda-nvrtc-cu12==12.8.93
    - nvidia-cuda-runtime-cu12==12.8.90
    - nvidia-cudnn-cu12==9.10.2.21
    - nvidia-cufft-cu12==11.3.3.83
    - nvidia-cufile-cu12==1.13.1.3
    - nvidia-curand-cu12==10.3.9.90
    - nvidia-cusolver-cu12==11.7.3.90
    - nvidia-cusparse-cu12==12.5.8.93
    - nvidia-cusparselt-cu12==0.7.1
    - nvidia-nccl-cu12==2.27.3
    - nvidia-nvjitlink-cu12==12.8.93
    - nvidia-nvtx-cu12==12.8.90
    - oauthlib==3.3.1
    - ollama==0.5.1
    - omegaconf==2.3.0
    - openai==1.109.1
    - openai-harmony==0.0.4
    - openapi-core==0.19.5
    - openapi-pydantic==0.5.1
    - openapi-schema-validator==0.6.3
    - openapi-spec-validator==0.7.2
    - opencensus==0.11.4
    - opencensus-context==0.1.3
    - opencv-python-headless==4.12.0.88
    - opentelemetry-api==1.38.0
    - opentelemetry-exporter-otlp==1.38.0
    - opentelemetry-exporter-otlp-proto-common==1.38.0
    - opentelemetry-exporter-otlp-proto-grpc==1.38.0
    - opentelemetry-exporter-otlp-proto-http==1.38.0
    - opentelemetry-exporter-prometheus==0.59b0
    - opentelemetry-instrumentation==0.59b0
    - opentelemetry-proto==1.38.0
    - opentelemetry-sdk==1.38.0
    - opentelemetry-semantic-conventions==0.59b0
    - ordered-set==4.1.0
    - orjson==3.11.1
    - ormsgpack==1.10.0
    - outlines-core==0.2.11
    - packaging==25.0
    - pandas==2.3.2
    - parse==1.20.2
    - partial-json-parser==0.2.1.1.post6
    - pathable==0.4.4
    - peft==0.17.1
    - pillow==10.4.0
    - platformdirs==4.3.8
    - plotly==5.17.0
    - polars==1.35.1
    - polars-runtime-32==1.35.1
    - prometheus-client==0.23.1
    - prometheus-fastapi-instrumentator==7.1.0
    - propcache==0.3.2
    - proto-plus==1.26.1
    - protobuf==6.33.0
    - psutil==7.0.0
    - psycopg2-binary==2.9.10
    - py-cpuinfo==9.0.0
    - py-spy==0.4.1
    - pyarrow==21.0.0
    - pyasn1==0.6.1
    - pyasn1-modules==0.4.2
    - pybase64==1.4.2
    - pybind11==3.0.1
    - pycountry==24.6.1
    - pycparser==2.23
    - pydantic==2.11.7
    - pydantic-core==2.33.2
    - pydantic-extra-types==2.10.5
    - pydantic-settings==2.10.1
    - pydeck==0.9.1
    - pydub==0.25.1
    - pygments==2.19.2
    - pyjwt==2.10.1
    - pylatexenc==2.10
    - pymongo==4.14.0
    - pynacl==1.6.0
    - pynano==1.0.1
    - pyowm==3.3.0
    - pyparsing==3.2.3
    - pyperclip==1.11.0
    - pysocks==1.7.1
    - python-dateutil==2.9.0.post0
    - python-dotenv==1.1.1
    - python-engineio==4.12.2
    - python-json-logger==3.3.0
    - python-multipart==0.0.20
    - python-socketio==5.13.0
    - pytz==2025.2
    - pyvers==0.1.0
    - pyyaml==6.0.3
    - pyzmq==27.1.0
    - ray==2.49.2
    - redis==6.4.0
    - referencing==0.36.2
    - reflex==0.8.7
    - reflex-hosting-cli==0.1.55
    - regex==2025.9.18
    - requests==2.32.5
    - requests-oauthlib==2.0.0
    - requests-toolbelt==1.0.0
    - rfc3339-validator==0.1.4
    - rich==13.9.4
    - rich-rst==1.3.2
    - rich-toolkit==0.15.1
    - rignore==0.6.4
    - rouge-score==0.1.2
    - rpds-py==0.26.0
    - rq==2.6.0
    - rsa==4.9.1
    - ruff==0.12.7
    - s3transfer==0.11.3
    - safehttpx==0.1.6
    - safetensors==0.6.2
    - scikit-learn==1.7.2
    - scipy==1.16.2
    - seaborn==0.13.2
    - semantic-version==2.10.0
    - sentence-transformers==5.1.0
    - sentencepiece==0.2.1
    - sentry-sdk==2.39.0
    - setproctitle==1.3.7
    - shellingham==1.5.4
    - shtab==1.7.2
    - simple-websocket==1.1.0
    - six==1.17.0
    - smart-open==7.5.0
    - smmap==5.0.2
    - sniffio==1.3.1
    - soundfile==0.12.1
    - soxr==1.0.0
    - sqlalchemy==2.0.42
    - sqlmodel==0.0.24
    - sqlparse==0.5.3
    - sse-starlette==3.0.2
    - starlette==0.47.2
    - streamlit==1.48.1
    - sympy==1.14.0
    - tenacity==8.5.0
    - tensorboard==2.20.0
    - tensorboard-data-server==0.7.2
    - tensordict==0.10.0
    - termcolor==2.4.0
    - threadpoolctl==3.6.0
    - tiktoken==0.9.0
    - tokenizers==0.22.1
    - toml==0.10.2
    - tomlkit==0.13.3
    - torch==2.8.0
    - torchao==0.13.0
    - torchaudio==2.8.0
    - torchdata==0.11.0
    - torchvision==0.23.0
    - tornado==6.5.2
    - tqdm==4.67.1
    - transformers==4.56.2
    - triton==3.4.0
    - trl==0.25.1
    - typeguard==4.4.4
    - typer==0.16.0
    - typing-extensions==4.15.0
    - typing-inspect==0.9.0
    - typing-inspection==0.4.1
    - tyro==0.9.32
    - tzdata==2025.2
    - tzlocal==5.3.1
    - unsloth==2025.9.9
    - unsloth-zoo==2025.9.12
    - uritemplate==4.2.0
    - urllib3==2.5.0
    - uvicorn==0.29.0
    - uvloop==0.21.0
    - validators==0.35.0
    - verl==0.6.0
    - virtualenv==20.35.4
    - vllm==0.10.2
    - wandb==0.22.3
    - watchdog==6.0.0
    - watchfiles==1.1.0
    - websockets==13.1
    - werkzeug==3.1.1
    - wrapt==1.17.3
    - wsproto==1.2.0
    - xformers==0.0.32.post1
    - xgrammar==0.1.23
    - xmltodict==1.0.2
    - xxhash==3.5.0
    - yarl==1.20.1
    - zipp==3.23.0
    - zope-interface==7.2
    - zstandard==0.23.0
1300
examples/basic_usage.ipynb
Normal file
File diff suppressed because it is too large
149
examples/basic_usage.py
Normal file
@@ -0,0 +1,149 @@
"""
Basic usage example for DP-Fusion-Lib with Tagger API

Demonstrates the Tagger integration for fine-grained privacy redaction.

Requirements:
    pip install dp-fusion-lib transformers torch

Note: This example requires a GPU for reasonable performance.
"""

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dp_fusion_lib import DPFusion, Tagger, compute_epsilon_single_group

# Model config
# This instruct model works well for the example.
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

# API config - Get your free key at console.documentprivacy.com
API_KEY = "put your key here"


def main():
    print("=" * 60)
    print("DP-Fusion Library Example (with Tagger API)")
    print("=" * 60)

    # Load tokenizer
    print(f"\nLoading tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True
    )

    # Load model on GPU
    print(f"Loading model: {MODEL_ID}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.eval()

    print("Model loaded successfully")

    # Initialize Tagger with API key (verbose=True to see input/output)
    print("\nInitializing Tagger...")
    tagger = Tagger(api_key=API_KEY, verbose=True)

    # List available models
    print("\nAvailable models:")
    available_models = tagger.get_available_models()
    for m in available_models:
        print(f"  - {m}")

    # Configure tagger
    # gpt-oss-120b is a strong default tagging model.
    tagger.set_model("gpt-oss-120b")
    tagger.set_constitution("LEGAL")

    # Initialize DPFusion with tagger
    print("Initializing DPFusion with Tagger...")
    dpf = DPFusion(model=model, tokenizer=tokenizer, max_tokens=100, tagger=tagger)

    # Example private text (ECHR-style legal document)
    private_text = """The applicant was born in 1973 and currently resides in Les Salles-sur-Verdon, France.
In the early 1990s, a new criminal phenomenon emerged in Denmark known as 'tax asset stripping cases' (selskabstømmersager)."""

    print(f"\nPrivate text ({len(private_text)} characters):")
    print(private_text)

    # Build context using the message API
    dpf.add_message("system", "You are a helpful assistant that paraphrases text.", is_private=False)
    dpf.add_message("user", private_text, is_private=True)
    dpf.add_message("system", "Now paraphrase this text for privacy", is_private=False)
    dpf.add_message("assistant", "Sure, here is the paraphrase of the above text that ensures privacy:", is_private=False)

    # Run tagger to extract and redact private phrases
    print("\n" + "-" * 60)
    print("Running Tagger API to extract private phrases...")
    print("-" * 60)
    dpf.run_tagger()

    # Show both contexts
    print("\n" + "-" * 60)
    print("Private Context (full text):")
    print("-" * 60)
    print(dpf.private_context)

    print("\n" + "-" * 60)
    print("Public Context (redacted):")
    print("-" * 60)
    print(dpf.public_context)

    # Run DP-Fusion generation
    print("\n" + "-" * 60)
    print("Running DP-Fusion generation...")
    print("-" * 60)

    output = dpf.generate(
        alpha=2.0,
        beta=0.01,
        temperature=1.0,
        max_new_tokens=100,
        debug=True
    )

    print("\n" + "=" * 60)
    print("Results:")
    print("=" * 60)
    print(f"\nGenerated text:\n{output['text']}\n")

    # Print some stats
    if output['lambdas'].get('PRIVATE'):
        lambdas = output['lambdas']['PRIVATE']
        print(f"Lambda stats: Mean={sum(lambdas)/len(lambdas):.4f}, Min={min(lambdas):.4f}, Max={max(lambdas):.4f}")

    if output['divergences'].get('PRIVATE'):
        divs = output['divergences']['PRIVATE']
        print(f"Divergence stats: Mean={sum(divs)/len(divs):.4f}, Min={min(divs):.4f}, Max={max(divs):.4f}")

    # Compute (ε, δ)-DP guarantee
    print("\n" + "-" * 60)
    print("Computing (ε, δ)-DP guarantees:")
    print("-" * 60)

    alpha = 2.0   # Rényi order (same as used in generation)
    beta = 0.01   # Divergence bound (same as used in generation)
    delta = 1e-5  # Target δ for (ε, δ)-DP

    if output['divergences'].get('PRIVATE'):
        eps_result = compute_epsilon_single_group(
            divergences=output['divergences']['PRIVATE'],
            alpha=alpha,
            delta=delta,
            beta=beta
        )
        print(f"\n(ε, δ)-DP guarantees (α={alpha}, δ={delta}, T={eps_result['T']} tokens):")
        print(f"  Empirical ε = {eps_result['empirical']:.4f} (from actual divergences)")
        print(f"  Theoretical ε = {eps_result['theoretical']:.4f} (worst-case, β={beta} per step)")

    print("\nExample completed successfully!")


if __name__ == "__main__":
    main()
BIN
images/demo_docscan_page.jpg
Normal file
Binary file not shown. (new file, 468 KiB)
BIN
images/dp-fusion-main-new_page.jpg
Normal file
Binary file not shown. (new file, 1.4 MiB)
BIN
images/eyecatcher_v2_page.jpg
Normal file
Binary file not shown. (new file, 484 KiB)
72
pyproject.toml
Normal file
@@ -0,0 +1,72 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "dp-fusion-lib"
version = "0.1.0"
description = "Token-Level Differentially Private Inference for Large Language Models"
readme = "README.md"
license = {file = "LICENSE"}
authors = [
    {name = "Rushil Thareja", email = "rushil.thareja@mbzuai.ac.ae"}
]
keywords = ["differential-privacy", "llm", "text-generation", "privacy", "nlp", "renyi-divergence"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "License :: Other/Proprietary License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Security",
]
requires-python = ">=3.8"

dependencies = [
    "torch>=2.0.0",
    "transformers>=4.25.0",
    "accelerate>=0.20.0",
    "requests>=2.25.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "ruff>=0.1.0",
    "pre-commit>=3.0.0",
]
docs = [
    "mkdocs>=1.5.0",
    "mkdocs-material>=9.0.0",
    "mkdocstrings[python]>=0.24.0",
]

[project.urls]
Homepage = "https://github.com/rushil-thareja/dp-fusion-lib"
Documentation = "https://rushil-thareja.github.io/dp-fusion-lib"
Repository = "https://github.com/rushil-thareja/dp-fusion-lib"
Issues = "https://github.com/rushil-thareja/dp-fusion-lib/issues"
Paper = "https://arxiv.org/abs/2507.04531"

[tool.setuptools.packages.find]
where = ["src"]

[tool.black]
line-length = 100
target-version = ["py38", "py39", "py310", "py311"]

[tool.ruff]
line-length = 100
select = ["E", "F", "I", "N", "W"]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
addopts = "-v --cov=dp_fusion_lib --cov-report=term-missing"
85
src/dp_fusion_lib/__init__.py
Normal file
@@ -0,0 +1,85 @@
"""
DP-Fusion-Lib: Token-Level Differentially Private Inference for LLMs

Generate text with formal (epsilon, delta)-differential privacy guarantees
using distribution fusion techniques.

This library implements the DP-Fusion algorithm from:

    Thareja et al. "DP-Fusion: Token-Level Differentially Private
    Inference for Large Language Models" (arXiv:2507.04531)

Quick Start:
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
    >>> from dp_fusion_lib import DPFusion, compute_epsilon_single_group
    >>>
    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    >>>
    >>> dpf = DPFusion(model=model, tokenizer=tokenizer)
    >>> dpf.add_message("system", "You are a helpful assistant.", is_private=False)
    >>> dpf.add_message("user", "My SSN is 123-45-6789. Summarize my info.", is_private=True)
    >>>
    >>> output = dpf.generate(alpha=2.0, beta=0.1, max_new_tokens=100)
    >>> print(output["text"])
    >>>
    >>> # Compute privacy guarantee
    >>> eps = compute_epsilon_single_group(
    ...     divergences=output["divergences"]["PRIVATE"],
    ...     alpha=2.0,
    ...     delta=1e-5,
    ...     beta=0.1
    ... )
    >>> print(f"Privacy: epsilon={eps['empirical']:.2f} at delta=1e-5")
"""

try:
    import torch
except ImportError as e:
    raise ImportError(
        "PyTorch is required but not installed. Install it first:\n"
        "  pip install torch\n"
        "  or with CUDA: pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
        "  or visit https://pytorch.org/get-started/locally/"
    ) from e

# Core classes and functions
from dp_fusion_lib.core import DPFusion, generate_dp_text
from dp_fusion_lib.tagger import Tagger, find_phrase_offsets
from dp_fusion_lib.epsilon import compute_epsilon_single_group, compute_dp_epsilon
from dp_fusion_lib._version import __version__

# Utility functions (advanced usage)
from dp_fusion_lib.utils import (
    compute_renyi_divergence_clipped_symmetric,
    find_lambda,
    replace_sequences_with_placeholder_fast,
    dp_fusion_groups_incremental,
    format_prompt_new_template,
    DEFAULT_BETA_DICT,
    ENTITY_TYPES,
    PLACEHOLDER_TOKEN,
)

__all__ = [
    # Main API
    "DPFusion",
    "Tagger",
    "generate_dp_text",
    # Epsilon computation
    "compute_epsilon_single_group",
    "compute_dp_epsilon",
    # Utility functions (advanced)
    "find_phrase_offsets",
    "compute_renyi_divergence_clipped_symmetric",
    "find_lambda",
    "replace_sequences_with_placeholder_fast",
    "dp_fusion_groups_incremental",
    "format_prompt_new_template",
    # Constants
    "DEFAULT_BETA_DICT",
    "ENTITY_TYPES",
    "PLACEHOLDER_TOKEN",
    # Version
    "__version__",
]
3
src/dp_fusion_lib/_version.py
Normal file
@@ -0,0 +1,3 @@
"""Version information for dp-fusion-lib."""

__version__ = "0.1.0"
482
src/dp_fusion_lib/core.py
Normal file
@@ -0,0 +1,482 @@
|
||||||
|
"""
|
||||||
|
Core DP-Fusion generation module.
|
||||||
|
|
||||||
|
This module provides the main DPFusion class and convenience functions
|
||||||
|
for differentially private text generation using distribution fusion.
|
||||||
|
|
||||||
|
Theory:
|
||||||
|
DP-Fusion mixes token probability distributions from:
|
||||||
|
1. Private context: Full sensitive document
|
||||||
|
2. Public context: Redacted version with placeholders
|
||||||
|
|
||||||
|
The mixing is controlled via λ to bound the Rényi divergence,
|
||||||
|
providing formal (ε, δ)-differential privacy guarantees.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
Thareja et al. "DP-Fusion: Token-Level Differentially Private
|
||||||
|
Inference for Large Language Models" (arXiv:2507.04531)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from dp_fusion_lib.tagger import Tagger, find_phrase_offsets
|
||||||
|
from dp_fusion_lib.utils import (
|
||||||
|
dp_fusion_groups_incremental,
|
||||||
|
format_prompt_new_template,
|
||||||
|
replace_sequences_with_placeholder_fast,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DPFusion:
|
||||||
|
"""
|
||||||
|
DP-Fusion wrapper for differentially private text generation.
|
||||||
|
|
||||||
|
This class provides a clean API for mixing private and public distributions
|
||||||
|
to generate text with differential privacy guarantees.
|
||||||
|
|
||||||
|
The workflow supports two modes:
|
||||||
|
1. **Message-based**: Build context with `add_message()`, run `run_tagger()`
|
||||||
|
for automatic phrase extraction, then `generate()`.
|
||||||
|
2. **Direct context**: Pass `private_context` and `public_context` directly
|
||||||
|
to `generate()`.
|
||||||
|
|
||||||
|
Example (Message-based with Tagger):
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
>>> from dp_fusion_lib import DPFusion, Tagger
|
||||||
|
>>>
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||||
|
>>> tagger = Tagger(api_key="sk_...")
|
||||||
|
>>>
|
||||||
|
>>> dpf = DPFusion(model=model, tokenizer=tokenizer, tagger=tagger)
|
||||||
|
>>> dpf.add_message("system", "You are a helpful assistant.", is_private=False)
|
||||||
|
>>> dpf.add_message("user", "My SSN is 123-45-6789.", is_private=True)
|
||||||
|
>>> dpf.run_tagger()
|
||||||
|
>>> output = dpf.generate(alpha=2.0, beta=0.1)
|
||||||
|
>>> print(output["text"])
|
||||||
|
|
||||||
|
Example (Direct context):
|
||||||
|
>>> dpf = DPFusion(model=model, tokenizer=tokenizer)
|
||||||
|
>>> output = dpf.generate(
|
||||||
|
... private_context="John Doe's SSN is 123-45-6789.",
|
||||||
|
... public_context="_'s SSN is _.",
|
||||||
|
... alpha=2.0,
|
||||||
|
... beta=0.1
|
||||||
|
... )
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A HuggingFace CausalLM model (on any device)
|
||||||
|
tokenizer: Corresponding HuggingFace tokenizer
|
||||||
|
max_tokens: Maximum number of tokens to generate (default: 100)
|
||||||
|
placeholder: Placeholder character for redacted content (default: "_")
|
||||||
|
tagger: Optional Tagger instance for automatic phrase extraction
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
max_tokens: int = 100,
|
||||||
|
placeholder: str = "_",
|
||||||
|
tagger: Optional[Tagger] = None
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.placeholder = placeholder
|
||||||
|
self.tagger = tagger
|
||||||
|
|
||||||
|
# Auto-detect device from model parameters
|
||||||
|
self.device = next(model.parameters()).device
|
||||||
|
|
||||||
|
# Ensure tokenizer has pad token
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
# Message storage for building context
|
||||||
|
self._messages: List[Dict] = []
|
||||||
|
|
||||||
|
# Cached contexts (populated by run_tagger)
|
||||||
|
self._private_context: Optional[str] = None
|
||||||
|
self._public_context: Optional[str] = None
|
||||||
|
self._private_tokens: Optional[torch.Tensor] = None
|
||||||
|
self._public_tokens: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def add_message(self, role: str, content: str, is_private: bool = False):
|
||||||
|
"""
|
||||||
|
Add a message to the conversation context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role: Message role - "system", "user", or "assistant"
|
||||||
|
content: The message text
|
||||||
|
is_private: If True, content is sensitive and will be redacted
|
||||||
|
in the public context
|
||||||
|
"""
|
||||||
|
self._messages.append({
|
||||||
|
"role": role,
|
||||||
|
"content": content,
|
||||||
|
"is_private": is_private
|
||||||
|
})
|
||||||
|
|
||||||
|
def clear_messages(self):
|
||||||
|
"""Clear all stored messages and cached contexts."""
|
||||||
|
self._messages = []
|
||||||
|
self._private_context = None
|
||||||
|
self._public_context = None
|
||||||
|
self._private_tokens = None
|
||||||
|
self._public_tokens = None
|
||||||
|
|
||||||
|
    def run_tagger(self):
        """
        Run the tagger on all private messages to extract and redact private phrases.

        This method calls the privacy API to identify sensitive phrases in messages
        marked as private, then builds both private and public contexts with
        fine-grained redaction at the token level to ensure alignment.

        Must be called before generate() if using fine-grained redaction.
        Populates self._private_context, self._public_context, and the token versions.

        Raises:
            ValueError: If no tagger is configured or no messages were added
            requests.RequestException: If the API call fails
        """
        if self.tagger is None:
            raise ValueError("No tagger configured. Pass tagger to DPFusion.__init__")

        if not self._messages:
            raise ValueError("No messages added. Use add_message() first.")

        # Collect all private phrases from private messages
        all_phrases = []
        for msg in self._messages:
            if msg["is_private"]:
                phrases = self.tagger.extract_private_phrases(msg["content"])
                all_phrases.extend(phrases)

        # Build the full private prompt text
        private_msgs = [{"role": msg["role"], "content": msg["content"]} for msg in self._messages]
        self._private_context = self.tokenizer.apply_chat_template(
            private_msgs, tokenize=False, add_generation_prompt=True
        )

        # Tokenize the full private context
        self._private_tokens = self.tokenizer.encode(self._private_context, return_tensors="pt")[0]

        if all_phrases:
            # Find phrase offsets in the FULL prompt text
            offsets = find_phrase_offsets(self._private_context, all_phrases)

            # Get public tokens directly - SAME LENGTH as private tokens!
            public_token_ids = replace_sequences_with_placeholder_fast(
                self._private_context, offsets, self.placeholder, self.tokenizer
            )
            self._public_tokens = torch.tensor(public_token_ids)

            # Decode for display purposes only
            self._public_context = self.tokenizer.decode(self._public_tokens, skip_special_tokens=False)
        else:
            # No private phrases found, so public = private
            self._public_tokens = self._private_tokens.clone()
            self._public_context = self._private_context

    @property
    def private_context(self) -> str:
        """
        Get the private context (full text with no redaction).

        Call run_tagger() first to populate this property.

        Returns:
            Formatted prompt string with full private content

        Raises:
            ValueError: If run_tagger() hasn't been called
        """
        if self._private_context is None:
            raise ValueError("No context available. Call run_tagger() first.")
        return self._private_context

    @property
    def public_context(self) -> str:
        """
        Get the public context (text with private phrases redacted).

        Call run_tagger() first to populate this property.

        Returns:
            Formatted prompt string with redacted content

        Raises:
            ValueError: If run_tagger() hasn't been called
        """
        if self._public_context is None:
            raise ValueError("No context available. Call run_tagger() first.")
        return self._public_context

    def _build_contexts(self):
        """
        Build private and public contexts from stored messages.

        This is used when run_tagger() hasn't been called, providing
        a simple full-message redaction fallback.

        Returns:
            Tuple of (private_messages, public_messages) for apply_chat_template.
        """
        private_msgs = []
        public_msgs = []

        for msg in self._messages:
            private_msgs.append({"role": msg["role"], "content": msg["content"]})
            if msg["is_private"]:
                # Redact the entire content with the placeholder
                public_msgs.append({"role": msg["role"], "content": self.placeholder})
            else:
                public_msgs.append({"role": msg["role"], "content": msg["content"]})

        return private_msgs, public_msgs

    def get_context_text(self) -> str:
        """
        Get formatted context text using the tokenizer's chat template.

        Returns:
            Formatted prompt string with special tokens
        """
        msgs = [{"role": msg["role"], "content": msg["content"]} for msg in self._messages]

        return self.tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True
        )

    def generate(
        self,
        private_context: Optional[str] = None,
        public_context: Optional[str] = None,
        alpha: float = 2.0,
        beta: float = 0.5,
        temperature: float = 1.0,
        max_new_tokens: Optional[int] = None,
        debug: bool = False
    ) -> Dict[str, Union[str, dict]]:
        """
        Generate text using DP-Fusion mixing of private and public distributions.

        Can be called in two ways:
        1. **With stored messages** (via add_message): `generate(alpha=2.0, beta=0.5)`
        2. **With explicit contexts**: `generate(private_context="...", public_context="...")`

        Args:
            private_context: The full sensitive document text (optional if using messages)
            public_context: The redacted document text (optional if using messages)
            alpha: Renyi divergence order, must be > 1 (default: 2.0)
            beta: Divergence threshold - lower = more privacy (default: 0.5).
                The internal bound is alpha * beta per the paper notation.
            temperature: Softmax temperature for sampling (default: 1.0)
            max_new_tokens: Override max tokens for this call (optional)
            debug: Enable debug printing (default: False)

        Returns:
            dict with keys:
                - "text": Generated text (str)
                - "lambdas": Per-step lambda values per group (dict)
                - "divergences": Per-step divergence values per group (dict)

        Raises:
            ValueError: If no context is available (neither messages nor explicit contexts)
        """
        if private_context is None and public_context is None:
            # Check whether run_tagger() was called - use the pre-computed tokens directly
            if self._private_tokens is not None:
                private_tokens = self._private_tokens
                public_tokens = self._public_tokens
            else:
                # Use stored messages with the default _build_contexts behavior
                if not self._messages:
                    raise ValueError(
                        "No messages added. Use add_message() or provide "
                        "private_context/public_context."
                    )

                private_msgs, public_msgs = self._build_contexts()

                private_prompt = self.tokenizer.apply_chat_template(
                    private_msgs,
                    tokenize=False,
                    add_generation_prompt=True
                )
                public_prompt = self.tokenizer.apply_chat_template(
                    public_msgs,
                    tokenize=False,
                    add_generation_prompt=True
                )
                private_tokens = self.tokenizer.encode(private_prompt, return_tensors="pt")[0]
                public_tokens = self.tokenizer.encode(public_prompt, return_tensors="pt")[0]
        else:
            # Use the provided contexts
            private_prompt = format_prompt_new_template(
                self.tokenizer,
                private_context,
                self.placeholder
            )
            public_prompt = format_prompt_new_template(
                self.tokenizer,
                public_context,
                self.placeholder
            )
            private_tokens = self.tokenizer.encode(private_prompt, return_tensors="pt")[0]
            public_tokens = self.tokenizer.encode(public_prompt, return_tensors="pt")[0]

        # Create the token groups dict:
        # "PUBLIC" is the redacted version, "PRIVATE" is the full sensitive version
        token_ids_groups = {
            "PUBLIC": public_tokens,
            "PRIVATE": private_tokens
        }

        # Beta dict for the private group.
        # Paper notation: D_alpha <= alpha * beta, so the internal bound = alpha * beta
        internal_beta = alpha * beta
        beta_dict = {"PRIVATE": internal_beta}

        # Determine max tokens
        tokens_to_generate = max_new_tokens if max_new_tokens else self.max_tokens

        # Run DP-Fusion generation
        generated_text, lambdas, divergences = dp_fusion_groups_incremental(
            token_ids_groups=token_ids_groups,
            beta_dict=beta_dict,
            alpha=alpha,
            model=self.model,
            tokenizer=self.tokenizer,
            temperature=temperature,
            max_new_tokens=tokens_to_generate,
            debug_mode=debug
        )

        return {
            "text": generated_text,
            "lambdas": lambdas,
            "divergences": divergences
        }

    def generate_from_tokens(
        self,
        private_tokens: torch.Tensor,
        public_tokens: torch.Tensor,
        alpha: float = 2.0,
        beta: float = 0.5,
        temperature: float = 1.0,
        max_new_tokens: Optional[int] = None,
        debug: bool = False
    ) -> Dict[str, Union[str, dict]]:
        """
        Generate text from pre-tokenized inputs.

        This is useful when you want more control over tokenization
        or are processing batches.

        Args:
            private_tokens: Token IDs for the private context (1D tensor)
            public_tokens: Token IDs for the public/redacted context (1D tensor)
            alpha: Renyi divergence order (default: 2.0)
            beta: Divergence threshold (default: 0.5)
            temperature: Softmax temperature (default: 1.0)
            max_new_tokens: Override max tokens (optional)
            debug: Enable debug printing (default: False)

        Returns:
            dict: Same as generate()
        """
        token_ids_groups = {
            "PUBLIC": public_tokens,
            "PRIVATE": private_tokens
        }

        # Paper notation: D_alpha <= alpha * beta, so the internal bound = alpha * beta
        internal_beta = alpha * beta
        beta_dict = {"PRIVATE": internal_beta}

        tokens_to_generate = max_new_tokens if max_new_tokens else self.max_tokens

        generated_text, lambdas, divergences = dp_fusion_groups_incremental(
            token_ids_groups=token_ids_groups,
            beta_dict=beta_dict,
            alpha=alpha,
            model=self.model,
            tokenizer=self.tokenizer,
            temperature=temperature,
            max_new_tokens=tokens_to_generate,
            debug_mode=debug
        )

        return {
            "text": generated_text,
            "lambdas": lambdas,
            "divergences": divergences
        }


def generate_dp_text(
    model,
    tokenizer,
    private_context: str,
    public_context: str,
    alpha: float = 2.0,
    beta: float = 0.5,
    temperature: float = 1.0,
    max_new_tokens: int = 100,
    debug: bool = False
) -> Dict[str, Union[str, dict]]:
    """
    Convenience function for one-off DP-Fusion generation.

    This is a shortcut that creates a temporary DPFusion instance
    and generates text in one call.

    Args:
        model: HuggingFace CausalLM model
        tokenizer: HuggingFace tokenizer
        private_context: Full sensitive document text
        public_context: Redacted document text with placeholders
        alpha: Renyi divergence order (default: 2.0)
        beta: Divergence threshold - paper notation where the bound = alpha * beta (default: 0.5)
        temperature: Softmax temperature (default: 1.0)
        max_new_tokens: Max tokens to generate (default: 100)
        debug: Enable debug printing (default: False)

    Returns:
        dict: {"text": str, "lambdas": dict, "divergences": dict}

    Example:
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> from dp_fusion_lib import generate_dp_text
        >>>
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>>
        >>> output = generate_dp_text(
        ...     model=model,
        ...     tokenizer=tokenizer,
        ...     private_context="John Doe's SSN is 123-45-6789.",
        ...     public_context="_'s SSN is _.",
        ...     alpha=2.0,
        ...     beta=0.1
        ... )
        >>> print(output["text"])
    """
    dpf = DPFusion(model=model, tokenizer=tokenizer)
    return dpf.generate(
        private_context=private_context,
        public_context=public_context,
        alpha=alpha,
        beta=beta,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        debug=debug
    )
143 src/dp_fusion_lib/epsilon.py Normal file
@@ -0,0 +1,143 @@
"""
|
||||||
|
Epsilon computation for differential privacy guarantees.
|
||||||
|
|
||||||
|
This module provides functions to compute (ε, δ)-DP guarantees from
|
||||||
|
per-step Rényi divergences, following the theory in:
|
||||||
|
|
||||||
|
Thareja et al. "DP-Fusion: Token-Level Differentially Private
|
||||||
|
Inference for Large Language Models" (arXiv:2507.04531)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
|
||||||
|
def compute_epsilon_single_group(
|
||||||
|
divergences: List[float],
|
||||||
|
alpha: float,
|
||||||
|
delta: float,
|
||||||
|
beta: float = None
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute (ε, δ)-DP guarantee for a single private group.
|
||||||
|
|
||||||
|
For a single group (N=1), the per-step RDP formula simplifies to:
|
||||||
|
eps_step = 4 * β_t
|
||||||
|
|
||||||
|
where β_t = divergence_t / α (paper notation).
|
||||||
|
|
||||||
|
Total epsilon:
|
||||||
|
ε = (4/α) * Σ(divergences) + log(1/δ)/(α-1)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
divergences: List of per-step D_α values (bounded by α·β internally).
|
||||||
|
alpha: Rényi order (>1).
|
||||||
|
delta: Target δ in (ε, δ)-DP.
|
||||||
|
beta: Paper's β (where internal bound = α·β). If provided,
|
||||||
|
also computes theoretical epsilon.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with:
|
||||||
|
- "empirical": ε computed from actual divergences
|
||||||
|
- "theoretical": ε assuming divergence = α·β at each step (if beta provided)
|
||||||
|
- "T": number of tokens generated
|
||||||
|
"""
|
||||||
|
if alpha <= 1.0:
|
||||||
|
raise ValueError("alpha must be > 1")
|
||||||
|
if delta <= 0.0 or delta >= 1.0:
|
||||||
|
raise ValueError("delta must be in (0,1)")
|
||||||
|
|
||||||
|
T = len(divergences)
|
||||||
|
log_delta_term = math.log(1.0 / delta) / (alpha - 1.0)
|
||||||
|
|
||||||
|
# Empirical: divergences are bounded by α·β, so β_t = d/α
|
||||||
|
# eps_t = 4 * β_t = 4 * (d / α)
|
||||||
|
empirical_rdp = sum(4.0 * (d / alpha) for d in divergences)
|
||||||
|
epsilon_empirical = empirical_rdp + log_delta_term
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"empirical": epsilon_empirical,
|
||||||
|
"T": T
|
||||||
|
}
|
||||||
|
|
||||||
|
# Theoretical: worst-case is divergence = α·β each step
|
||||||
|
# β_t = β, so eps_t = 4 * β
|
||||||
|
if beta is not None:
|
||||||
|
theoretical_rdp = T * 4.0 * beta
|
||||||
|
epsilon_theoretical = theoretical_rdp + log_delta_term
|
||||||
|
result["theoretical"] = epsilon_theoretical
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
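As a standalone sanity check on the closed form above (a sketch with made-up numbers, not part of the library), plugging α = 2, δ = 1e-5 and four steps of divergence 0.2 into ε = (4/α)·Σd + log(1/δ)/(α−1) gives roughly 13.11, almost all of it from the log(1/δ) term:

```python
import math

# Hypothetical run: 4 generated tokens, each step's D_alpha = 0.2
alpha, delta = 2.0, 1e-5
divergences = [0.2, 0.2, 0.2, 0.2]

# eps = (4/alpha) * sum(divergences) + log(1/delta) / (alpha - 1)
empirical_rdp = sum(4.0 * (d / alpha) for d in divergences)      # = 1.6
epsilon = empirical_rdp + math.log(1.0 / delta) / (alpha - 1.0)  # 1.6 + ln(1e5)
print(round(epsilon, 4))  # 13.1129
```

At these settings the δ term dominates, which is why longer generations change ε only slowly.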
def compute_dp_epsilon(
    divergences: Dict[str, List[float]],
    alpha: float,
    delta: float,
    mode: str = "global"
) -> Union[float, Dict[str, float]]:
    """
    Compute the (ε, δ)-DP guarantee from per-step Rényi divergences.

    Supports multi-group privacy with either global or per-group guarantees.

    Args:
        divergences: Mapping group_name -> list of β_t values (length = T).
            The key "PUBLIC" (if present) is ignored.
        alpha: Rényi order (>1).
        delta: Target δ in (ε, δ)-DP.
        mode: "global" for one ε protecting all groups (worst case per step),
            "per_group" for a dict of ε_i per group.

    Returns:
        If mode == "global": float ε.
        If mode == "per_group": dict of {group: ε_i}.
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    if delta <= 0.0 or delta >= 1.0:
        raise ValueError("delta must be in (0, 1)")

    # Filter out PUBLIC and ensure there is at least one private group
    priv_div = {g: lst for g, lst in divergences.items() if g != "PUBLIC"}
    if not priv_div:
        raise ValueError("No private groups provided")

    # Ensure all groups have the same number of steps
    step_counts = {len(lst) for lst in priv_div.values()}
    if len(step_counts) != 1:
        raise ValueError(f"Divergence lists have unequal lengths: {step_counts}")

    T = step_counts.pop()
    N = len(priv_div)

    def eps_step(beta: float) -> float:
        """Compute the per-step RDP cost."""
        if beta is None:
            raise ValueError("Found None in divergence list")
        arg = (N - 1.0) / N + (1.0 / N) * math.exp((alpha - 1.0) * 4.0 * beta)
        if arg <= 0.0:
            raise ValueError(f"Non-positive argument for log: {arg}")
        return (1.0 / (alpha - 1.0)) * math.log(arg)

    if mode == "global":
        total_rdp = 0.0
        for t in range(T):
            betas = [div_list[t] for div_list in priv_div.values()]
            beta_max = max(betas)
            total_rdp += eps_step(beta_max)
        epsilon = total_rdp + math.log(1.0 / delta) / (alpha - 1.0)
        return epsilon

    elif mode == "per_group":
        epsilons = {}
        for group, div_list in priv_div.items():
            total_rdp_g = 0.0
            for beta_t in div_list:
                total_rdp_g += eps_step(beta_t)
            eps_group = total_rdp_g + math.log(1.0 / delta) / (alpha - 1.0)
            epsilons[group] = eps_group
        return epsilons

    else:
        raise ValueError("mode must be 'global' or 'per_group'")
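A quick check (a standalone sketch, independent of the library code) that the per-step expression above reduces to the single-group cost 4β when N = 1: the (N−1)/N term vanishes, so the log and exp cancel exactly:

```python
import math

alpha, beta = 2.0, 0.1

# eps_step with N = 1: arg = exp((alpha - 1) * 4 * beta)
N = 1
arg = (N - 1.0) / N + (1.0 / N) * math.exp((alpha - 1.0) * 4.0 * beta)
eps_step = (1.0 / (alpha - 1.0)) * math.log(arg)

# Reduces to 4 * beta, matching compute_epsilon_single_group
assert abs(eps_step - 4.0 * beta) < 1e-12
```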
172 src/dp_fusion_lib/tagger.py Normal file
@@ -0,0 +1,172 @@
"""
|
||||||
|
Private phrase extraction using the Document Privacy API.
|
||||||
|
|
||||||
|
This module provides the Tagger class for automatically identifying
|
||||||
|
sensitive/private phrases in documents using an external API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def find_phrase_offsets(text: str, phrases: List[str]) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Find all occurrences of phrases in text and return [start, end] offsets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The full text to search in
|
||||||
|
phrases: List of phrases to find
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of [start_char, end_char] offsets for all phrase occurrences
|
||||||
|
"""
|
||||||
|
offsets = []
|
||||||
|
for phrase in phrases:
|
||||||
|
start = 0
|
||||||
|
while True:
|
||||||
|
idx = text.find(phrase, start)
|
||||||
|
if idx == -1:
|
||||||
|
break
|
||||||
|
offsets.append([idx, idx + len(phrase)])
|
||||||
|
start = idx + 1
|
||||||
|
return offsets
|
||||||
|
|
||||||
|
|
||||||
|
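For illustration, here is the same scan run standalone (the function body is copied inline so the snippet is self-contained): every occurrence of each phrase is reported, and advancing with `start = idx + 1` rather than `idx + len(phrase)` means overlapping matches are also found.

```python
def find_phrase_offsets(text, phrases):
    # Same scan as the library function: all [start, end) char spans per phrase.
    offsets = []
    for phrase in phrases:
        start = 0
        while True:
            idx = text.find(phrase, start)
            if idx == -1:
                break
            offsets.append([idx, idx + len(phrase)])
            start = idx + 1
    return offsets

text = "John met John on 01/01/1990."
print(find_phrase_offsets(text, ["John", "01/01/1990"]))
# [[0, 4], [9, 13], [17, 27]]
```

Note the result is grouped by phrase, not sorted by position; callers that need positional order (like the token replacement below in utils.py) sort the offsets themselves.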
class Tagger:
    """
    Private phrase extraction using the Document Privacy API.

    The Tagger uses an external API to identify sensitive information
    in documents. It supports different extraction models and document
    types (constitutions).

    Example:
        >>> tagger = Tagger(api_key="sk_...")
        >>> tagger.set_model("llama3.1-8b")
        >>> tagger.set_constitution("HEALTH")
        >>> phrases = tagger.extract_private_phrases("John Doe visited on 01/01/1990.")
        >>> print(phrases)
        ['John Doe', '01/01/1990']

    Args:
        api_key: API key for the Document Privacy API
        verbose: If True, log the input/output of API calls (default: False)
    """

    def __init__(self, api_key: str, verbose: bool = False):
        """
        Initialize the Tagger with an API key.

        Args:
            api_key: API key for the Document Privacy API
            verbose: If True, log the input/output of API calls (default: False)
        """
        self.api_key = api_key
        self.api_base = "https://api.documentprivacy.com"
        self._model = "llama3.1-8b"
        self._constitution = "HEALTH"
        self.verbose = verbose

    def set_model(self, model: str):
        """
        Set the extraction model.

        Args:
            model: Model identifier (e.g., 'llama3.1-8b')
        """
        self._model = model

    def set_constitution(self, constitution: str):
        """
        Set the document type/constitution.

        Available constitutions depend on the API. Common options:
        - 'HEALTH': Medical/healthcare documents
        - 'FINANCE': Financial documents
        - 'LEGAL': Legal documents

        Args:
            constitution: Document type identifier
        """
        self._constitution = constitution

    def get_available_models(self) -> List[str]:
        """
        Get the list of available models from the API.

        Returns:
            List of available model identifiers

        Raises:
            requests.RequestException: If the API call fails
        """
        url = f"{self.api_base}/models"
        headers = {
            "X-API-KEY": self.api_key
        }

        if self.verbose:
            print(f"[Tagger] GET {url}")

        response = requests.get(url, headers=headers)
        response.raise_for_status()

        result = response.json()

        if self.verbose:
            print(f"[Tagger] Response: {result}")

        return result

    def extract_private_phrases(self, document: str) -> List[str]:
        """
        Extract private phrases from a document using the API.

        This method sends the document to the Document Privacy API,
        which uses the configured model and constitution to identify
        sensitive information.

        Args:
            document: The text document to analyze

        Returns:
            List of detected private/sensitive phrases

        Raises:
            requests.RequestException: If the API call fails

        Example:
            >>> tagger = Tagger(api_key="sk_...")
            >>> phrases = tagger.extract_private_phrases(
            ...     "Patient John Smith, DOB 05/15/1980, was diagnosed with diabetes."
            ... )
            >>> print(phrases)
            ['John Smith', '05/15/1980', 'diabetes']
        """
        url = f"{self.api_base}/extract"
        headers = {
            "X-API-KEY": self.api_key,
            "Content-Type": "application/json"
        }
        payload = {
            "document": document,
            "model": self._model,
            "type": self._constitution
        }

        if self.verbose:
            print(f"[Tagger] POST {url}")
            print(f"[Tagger] Input document: {document[:200]}{'...' if len(document) > 200 else ''}")
            print(f"[Tagger] Model: {self._model}, Constitution: {self._constitution}")

        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()

        data = response.json()
        private_phrases = data.get("private_phrases", [])

        if self.verbose:
            print(f"[Tagger] Extracted phrases: {private_phrases}")

        return private_phrases
394 src/dp_fusion_lib/utils.py Normal file
@@ -0,0 +1,394 @@
"""
|
||||||
|
Utility functions for DP-Fusion.
|
||||||
|
|
||||||
|
This module contains the core algorithmic components:
|
||||||
|
- Rényi divergence computation
|
||||||
|
- Lambda search for privacy-utility tradeoff
|
||||||
|
- Token replacement for redaction
|
||||||
|
- Incremental DP-Fusion generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from bisect import bisect_right
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# Default beta values for different entity types
|
||||||
|
DEFAULT_BETA_DICT = {
|
||||||
|
"PERSON": 0.5,
|
||||||
|
"CODE": 0.5,
|
||||||
|
"LOC": 0.5,
|
||||||
|
"ORG": 0.5,
|
||||||
|
"DEM": 0.5,
|
||||||
|
"DATETIME": 0.5,
|
||||||
|
"QUANTITY": 0.5,
|
||||||
|
"MISC": 0.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Entity types available
|
||||||
|
ENTITY_TYPES = [
|
||||||
|
"PERSON", "CODE", "LOC", "ORG", "DEM",
|
||||||
|
"DATETIME", "QUANTITY", "MISC"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default placeholder token for redaction
|
||||||
|
PLACEHOLDER_TOKEN = "_"
|
||||||
|
|
||||||
|
|
||||||
|
def replace_sequences_with_placeholder_fast(
    text: str,
    word_offsets: List[List[int]],
    placeholder: str,
    tokenizer
) -> List[int]:
    """
    Replace tokens falling within the provided word offset ranges with placeholder tokens.

    Args:
        text: Original text string
        word_offsets: List of [start_char, end_char] offsets for words to replace
        placeholder: Placeholder token to use (e.g., "_")
        tokenizer: Tokenizer that returns 'input_ids' and 'offset_mapping'

    Returns:
        Token IDs with the specified words replaced by the placeholder token ID
    """
    placeholder_id = tokenizer.convert_tokens_to_ids(placeholder)

    encoded = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    input_ids = encoded['input_ids']
    offsets = encoded['offset_mapping']

    word_offsets = sorted(word_offsets, key=lambda x: x[0])
    starts = [wo[0] for wo in word_offsets]
    ends = [wo[1] for wo in word_offsets]

    for i, (t_start, t_end) in enumerate(offsets):
        if t_start == t_end:
            continue

        idx = bisect_right(starts, t_end)

        while idx > 0:
            idx -= 1
            w_start, w_end = starts[idx], ends[idx]

            if w_end > t_start and w_start < t_end:
                input_ids[i] = placeholder_id
                break

    return input_ids

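The overlap test at the heart of this function can be exercised without a real tokenizer by faking the offset mapping (all spans and IDs below are made up for illustration): a token is replaced whenever its character span intersects a redaction span, and because replacement happens in place, the token count never changes — the alignment property the private and public contexts rely on.

```python
from bisect import bisect_right

# Fake token char spans and a fake redaction span - no tokenizer needed.
token_offsets = [(0, 4), (4, 8), (8, 12)]   # three "tokens"
input_ids = [101, 102, 103]                 # their fake IDs
word_offsets = sorted([[5, 10]])            # redact characters 5..10
starts = [w[0] for w in word_offsets]
ends = [w[1] for w in word_offsets]
PLACEHOLDER_ID = 0

for i, (t_start, t_end) in enumerate(token_offsets):
    if t_start == t_end:
        continue  # special tokens have empty spans
    idx = bisect_right(starts, t_end)
    while idx > 0:
        idx -= 1
        w_start, w_end = starts[idx], ends[idx]
        # Half-open interval intersection test
        if w_end > t_start and w_start < t_end:
            input_ids[i] = PLACEHOLDER_ID
            break

print(input_ids)  # [101, 0, 0] - same length, overlapping tokens masked
```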
def compute_renyi_divergence_clipped_symmetric(
    p: torch.Tensor,
    q: torch.Tensor,
    alpha: float,
    eps: float = 1e-10
) -> torch.Tensor:
    """
    Compute the symmetric Rényi divergence D↔_α(p‖q) = max{D_α(p‖q), D_α(q‖p)}.

    Args:
        p: Probability vector (the last dimension is the support)
        q: Probability vector (the last dimension is the support)
        alpha: Rényi order (must be > 1)
        eps: Small constant for numerical stability

    Returns:
        D↔_α(p, q) with shape p.shape[:-1]
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")

    p = p.float().clamp_min(eps)
    q = q.float().clamp_min(eps)

    # Forward direction D_α(p‖q)
    term_pq = torch.sum(p.pow(alpha) * q.pow(1.0 - alpha), dim=-1).clamp_min(eps)
    div_pq = (1.0 / (alpha - 1.0)) * torch.log(term_pq)

    # Reverse direction D_α(q‖p)
    term_qp = torch.sum(q.pow(alpha) * p.pow(1.0 - alpha), dim=-1).clamp_min(eps)
    div_qp = (1.0 / (alpha - 1.0)) * torch.log(term_qp)

    return torch.maximum(div_pq, div_qp)

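A torch-free sketch of the same quantity on a toy two-outcome distribution (illustrative only, without the eps clamping): D_α(p‖q) = log(Σ p_i^α q_i^(1−α)) / (α−1), symmetrized by taking the max of both directions.

```python
import math

def renyi(p, q, alpha):
    # D_alpha(p || q) = log(sum p_i^alpha * q_i^(1 - alpha)) / (alpha - 1)
    s = sum(pi ** alpha * qi ** (1.0 - alpha) for pi, qi in zip(p, q))
    return math.log(s) / (alpha - 1.0)

p, q, alpha = [0.7, 0.3], [0.5, 0.5], 2.0
d_sym = max(renyi(p, q, alpha), renyi(q, p, alpha))

assert abs(renyi(p, p, alpha)) < 1e-12  # a distribution has zero divergence to itself
assert d_sym >= renyi(p, q, alpha)      # the symmetric version dominates both directions
```

The symmetrization matters because the two directions generally differ (here the reverse direction is the larger one).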
def find_lambda(
    p_priv: torch.Tensor,
    p_pub: torch.Tensor,
    alpha: float,
    beta: float,
    debug_mode: bool = False,
    max_iter: int = 20,
    tol: float = 1e-6
) -> Tuple[float, float]:
    """
    Binary search for the λ in [0, 1] that satisfies the divergence bound.

    Finds λ such that:

        D_α((1-λ)*p_pub + λ*p_priv || p_pub) <= beta

    Args:
        p_priv: Private distribution (already softmaxed & temperature-scaled)
        p_pub: Public distribution (already softmaxed & temperature-scaled)
        alpha: Rényi order (> 1)
        beta: Divergence threshold (>= 0)
        debug_mode: Whether to print debug information
        max_iter: Maximum number of binary search iterations
        tol: Tolerance for convergence

    Returns:
        Tuple of (lambda_value, divergence)
    """
    if beta <= 0:
        return 0.0, 0.0

    div_at_1 = compute_renyi_divergence_clipped_symmetric(p_priv, p_pub, alpha)

    if div_at_1 <= beta:
        return 1.0, div_at_1.item() if hasattr(div_at_1, 'item') else div_at_1

    left, right = 0.0, 1.0
    for _ in range(max_iter):
        mid = 0.5 * (left + right)
        mixture = mid * p_priv + (1 - mid) * p_pub
        div = compute_renyi_divergence_clipped_symmetric(mixture, p_pub, alpha)

        if div > beta:
            right = mid
        else:
            left = mid

        if (right - left) < tol:
            break

    final_lambda = left
    mixture = final_lambda * p_priv + (1 - final_lambda) * p_pub
    final_div = compute_renyi_divergence_clipped_symmetric(mixture, p_pub, alpha)

    return final_lambda, final_div.item() if hasattr(final_div, 'item') else final_div

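The same bisection can be sketched in pure Python on toy two-outcome distributions (illustrative values; the library operates on full-vocabulary tensors). At λ = 0 the mixture equals p_pub and the divergence is 0, so the left endpoint always satisfies the bound and the search returns the largest feasible λ:

```python
import math

def renyi_sym(p, q, alpha):
    # Symmetric Renyi divergence, as in the library's helper (no clamping here)
    def d(a, b):
        return math.log(sum(x ** alpha * y ** (1.0 - alpha)
                            for x, y in zip(a, b))) / (alpha - 1.0)
    return max(d(p, q), d(q, p))

p_priv, p_pub = [0.9, 0.1], [0.5, 0.5]
alpha, beta = 2.0, 0.05

left, right = 0.0, 1.0
for _ in range(40):  # bisection on lambda
    mid = 0.5 * (left + right)
    mix = [mid * a + (1 - mid) * b for a, b in zip(p_priv, p_pub)]
    if renyi_sym(mix, p_pub, alpha) > beta:
        right = mid
    else:
        left = mid

lam = left
mix = [lam * a + (1 - lam) * b for a, b in zip(p_priv, p_pub)]
assert renyi_sym(mix, p_pub, alpha) <= beta  # the returned lambda respects the bound
```

A smaller β forces the mixture closer to the public distribution (smaller λ), which is exactly the privacy-utility dial that `generate()` exposes.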
def dp_fusion_groups_incremental(
|
||||||
|
token_ids_groups: Dict[str, torch.Tensor],
|
||||||
|
beta_dict: Dict[str, float],
|
||||||
|
alpha: float,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
max_new_tokens: int = 50,
|
||||||
|
debug_mode: bool = False,
|
||||||
|
device_map=None,
|
||||||
|
batch_override=None
|
||||||
|
) -> Tuple[str, Dict[str, List[float]], Dict[str, List[float]]]:
|
||||||
|
"""
|
||||||
|
DP-Fusion generation with incremental decoding using KV-cache.
|
||||||
|
|
||||||
|
Supports multi-group privacy where each group can have different β thresholds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_groups: Dict mapping group names to token ID tensors.
|
||||||
|
Must include "PUBLIC" key for the redacted version.
|
||||||
|
beta_dict: Mapping from group name to β threshold.
|
||||||
|
alpha: Rényi divergence order (>1).
|
||||||
|
model: HuggingFace CausalLM model.
|
||||||
|
tokenizer: Corresponding tokenizer.
|
||||||
|
temperature: Temperature for scaling logits.
|
||||||
|
max_new_tokens: Maximum tokens to generate.
|
||||||
|
debug_mode: Whether to print debug information.
|
||||||
|
device_map: Optional device map.
|
||||||
|
batch_override: Optional batch settings override.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (generated_text, lambdas_dict, divergences_dict)
|
||||||
|
"""
|
||||||
|
    eos_id = tokenizer.eos_token_id

    going_lambdas: Dict[str, List[float]] = {}
    going_divergence: Dict[str, List[float]] = {}

    if "PUBLIC" not in token_ids_groups:
        raise ValueError("Must have a 'PUBLIC' key in token_ids_groups.")

    private_groups = [g for g in token_ids_groups if g != "PUBLIC"]
    if not private_groups:
        raise ValueError("No private groups besides 'PUBLIC' – need at least one for DP-Fusion.")

    if device_map:
        first_device = next(iter(device_map.values()))
        device = torch.device(f"cuda:{first_device}" if isinstance(first_device, int) else first_device)
    else:
        device = model.device

    for group, tokens in token_ids_groups.items():
        if not isinstance(tokens, torch.Tensor):
            tokens = torch.tensor(tokens, dtype=torch.long)
        token_ids_groups[group] = tokens.to(device)

    if debug_mode:
        print(f"[DP-Fusion] Starting generation. Private groups: {private_groups}")
        for g in token_ids_groups:
            print(f"[Initial] Prefix shape for group {g}: {token_ids_groups[g].shape}")

    group_order = list(token_ids_groups.keys())
    num_groups = len(group_order)

    # Initial pass: process each group's full prefix to build the KV-cache
    prefix_batches = [token_ids_groups[g] for g in group_order]
    input_batch = torch.nn.utils.rnn.pad_sequence(
        prefix_batches, batch_first=True, padding_value=tokenizer.pad_token_id
    )

    if debug_mode:
        print(f"[Initial] Input batch shape: {input_batch.shape}")

    with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
        outputs = model(input_ids=input_batch, use_cache=True, past_key_values=None)

    past = outputs.past_key_values
    # NOTE: taking logits at the padded final position assumes all group
    # prefixes have the same length; for unequal lengths, pass an
    # attention_mask and gather each sequence's true last position instead.
    last_logits = outputs.logits[:, input_batch.size(1) - 1, :]
    group_logits = {g: last_logits[i] for i, g in enumerate(group_order)}

    pub_scaled = group_logits["PUBLIC"] / temperature
    p_pub = F.softmax(pub_scaled, dim=-1)

    p_priv_dict = {}
    for pg in private_groups:
        priv_scaled = group_logits[pg] / temperature
        p_priv_dict[pg] = F.softmax(priv_scaled, dim=-1)

    # DP-Fusion: find lambdas and form the fused distribution
    lambdas = {}
    for pg in private_groups:
        beta_val = beta_dict.get(pg)
        lam_pg, got_div = find_lambda(p_priv_dict[pg], p_pub, alpha, beta_val, debug_mode=debug_mode)
        lambdas[pg] = lam_pg
        # Record the initial step too, so the privacy accounting covers every sampled token.
        going_lambdas.setdefault(pg, []).append(lam_pg)
        going_divergence.setdefault(pg, []).append(got_div)
        if debug_mode:
            print(f"[Initial] Selected Lambda for group {pg}: {lam_pg}, Divergence: {got_div}")

    sum_out = torch.zeros_like(p_pub)
    for pg in private_groups:
        lam_g = lambdas[pg]
        mix_g = lam_g * p_priv_dict[pg] + (1 - lam_g) * p_pub
        sum_out += mix_g
    p_out_avg = sum_out / len(private_groups)

    next_token = torch.multinomial(p_out_avg, 1).item()

    if debug_mode:
        token_str = tokenizer.decode([next_token])
        print(f"[Initial] Sampled token '{token_str}' (ID={next_token})")

    for g in group_order:
        token_ids_groups[g] = torch.cat(
            [token_ids_groups[g], torch.tensor([next_token], device=device)], dim=0
        )

    # Incremental loop
    for step in range(1, max_new_tokens):
        new_tokens_batch = torch.tensor([[next_token]] * num_groups, device=device)

        with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
            outputs = model(input_ids=new_tokens_batch, past_key_values=past, use_cache=True)

        past = outputs.past_key_values
        last_logits = outputs.logits[:, -1, :]
        group_logits = {g: last_logits[i] for i, g in enumerate(group_order)}

        pub_scaled = group_logits["PUBLIC"] / temperature
        p_pub = F.softmax(pub_scaled, dim=-1)

        p_priv_dict = {}
        for pg in private_groups:
            priv_scaled = group_logits[pg] / temperature
            p_priv_dict[pg] = F.softmax(priv_scaled, dim=-1)

        lambdas = {}
        for pg in private_groups:
            beta_val = beta_dict.get(pg)
            lam_pg, div_got = find_lambda(p_priv_dict[pg], p_pub, alpha, beta_val, debug_mode=debug_mode)
            lambdas[pg] = lam_pg

            if debug_mode:
                print(f"[Step {step}] Selected Lambda for group {pg}: {lam_pg}, Divergence: {div_got}")

            going_lambdas.setdefault(pg, []).append(lam_pg)
            going_divergence.setdefault(pg, []).append(div_got)

        sum_out = torch.zeros_like(p_pub)
        for pg in private_groups:
            mix_g = lambdas[pg] * p_priv_dict[pg] + (1 - lambdas[pg]) * p_pub
            sum_out += mix_g
        p_out_avg = sum_out / len(private_groups)

        next_token = torch.multinomial(p_out_avg, 1).item()

        for g in group_order:
            token_ids_groups[g] = torch.cat(
                [token_ids_groups[g], torch.tensor([next_token], device=device)], dim=0
            )

        del outputs, last_logits, group_logits
        torch.cuda.empty_cache()

        if next_token == eos_id:
            break

    final_text = tokenizer.decode(token_ids_groups["PUBLIC"], skip_special_tokens=True)

    if debug_mode:
        print("[DP-Fusion] Generation complete.")

    torch.cuda.empty_cache()

    return final_text, going_lambdas, going_divergence
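The per-step fusion performed inside the loop, mixing each private distribution with the public one at its own λ and then averaging across groups, can be sketched standalone (group names and λ values here are illustrative, as if returned by `find_lambda` per group):

```python
import torch

p_pub = torch.tensor([0.5, 0.3, 0.2])
p_priv_dict = {
    "PERSON": torch.tensor([0.6, 0.3, 0.1]),
    "LOCATION": torch.tensor([0.4, 0.2, 0.4]),
}
lambdas = {"PERSON": 0.4, "LOCATION": 0.25}  # illustrative per-group lambdas

# Each mixture is a convex combination of two distributions, so the
# average over groups remains a valid probability distribution.
fused = torch.zeros_like(p_pub)
for g, p_priv in p_priv_dict.items():
    fused += lambdas[g] * p_priv + (1 - lambdas[g]) * p_pub
fused /= len(p_priv_dict)
```

Sampling then draws from `fused` exactly as the loop does with `torch.multinomial`.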


def format_prompt_new_template(tokenizer, private_doc: str, placeholder: str) -> str:
    """
    Create a prompt using the HuggingFace chat template.

    Args:
        tokenizer: HuggingFace tokenizer.
        private_doc: Document to paraphrase (may contain placeholders).
        placeholder: Placeholder character (e.g., "_").

    Returns:
        Formatted prompt for the model.
    """
    system_text = (
        "You are given a passage that may contain placeholders (underscores) or incomplete data. "
        "Your job is to produce a natural paraphrase. "
        "Do not use any underscores or placeholders in your output. "
        "If data is missing, just omit it or paraphrase gracefully. "
        "Do not output anything except the paraphrase. "
        "Make sure to retain all information from the source document."
    )

    user_task = (
        f"Document:\n\n{private_doc}\n\n"
        f"Paraphrase the above text. Whenever a placeholder, i.e. {placeholder}, exists, you must completely ignore that information, "
        f"as {placeholder} indicates redacted text. To ensure the generated text is as natural as possible, "
        f"you must never output the {placeholder} characters themselves."
    )

    messages = [
        {"role": "user", "content": f"{system_text}\n\n{user_task}"},
        {"role": "assistant", "content": "Sure. Here is the paraphrased document without underscores or placeholders:"},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    return prompt
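A quick sketch of the message structure this function builds, with an abbreviated instruction and a hypothetical document; the trailing assistant "primer" turn steers the model straight into the paraphrase:

```python
private_doc = "Patient ____ was admitted on ____ with mild symptoms."
placeholder = "_"

# Abbreviated stand-ins for the real system_text / user_task strings
system_text = "Produce a natural paraphrase; never output underscores or placeholders."
user_task = (
    f"Document:\n\n{private_doc}\n\n"
    f"Paraphrase the above text, ignoring any {placeholder} placeholders."
)

messages = [
    {"role": "user", "content": f"{system_text}\n\n{user_task}"},
    {"role": "assistant", "content": "Sure. Here is the paraphrased document without underscores or placeholders:"},
]
# tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
# would serialize these turns into the model's chat format.
```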
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for dp-fusion-lib."""
27
tests/conftest.py
Normal file
@@ -0,0 +1,27 @@
"""Pytest configuration and fixtures for dp-fusion-lib tests."""

import pytest


@pytest.fixture
def sample_divergences():
    """Sample divergence values for testing epsilon computation."""
    return [0.05, 0.08, 0.03, 0.06, 0.04, 0.07, 0.02, 0.05, 0.09, 0.04]


@pytest.fixture
def alpha():
    """Standard Renyi order for tests."""
    return 2.0


@pytest.fixture
def beta():
    """Standard beta value for tests."""
    return 0.1


@pytest.fixture
def delta():
    """Standard delta value for tests."""
    return 1e-5
227
tests/test_epsilon.py
Normal file
@@ -0,0 +1,227 @@
"""Tests for epsilon computation functions."""

import math
import pytest

from dp_fusion_lib import compute_epsilon_single_group, compute_dp_epsilon


class TestComputeEpsilonSingleGroup:
    """Tests for compute_epsilon_single_group function."""

    def test_basic_computation(self, sample_divergences, alpha, delta, beta):
        """Test basic epsilon computation returns expected structure."""
        result = compute_epsilon_single_group(
            divergences=sample_divergences,
            alpha=alpha,
            delta=delta,
            beta=beta
        )

        assert "empirical" in result
        assert "T" in result
        assert "theoretical" in result
        assert result["T"] == len(sample_divergences)
        assert result["empirical"] >= 0
        assert result["theoretical"] >= 0

    def test_empirical_less_than_theoretical(self, sample_divergences, alpha, delta, beta):
        """Test that empirical epsilon <= theoretical (worst-case)."""
        result = compute_epsilon_single_group(
            divergences=sample_divergences,
            alpha=alpha,
            delta=delta,
            beta=beta
        )

        # Empirical should be <= theoretical since theoretical is worst-case
        assert result["empirical"] <= result["theoretical"] + 1e-9

    def test_empty_divergences(self, alpha, delta, beta):
        """Test with an empty divergence list."""
        result = compute_epsilon_single_group(
            divergences=[],
            alpha=alpha,
            delta=delta,
            beta=beta
        )

        assert result["T"] == 0
        # With no tokens, epsilon is just the log(1/delta) term
        expected = math.log(1.0 / delta) / (alpha - 1.0)
        assert abs(result["empirical"] - expected) < 1e-9

    def test_alpha_validation(self, sample_divergences, delta):
        """Test that alpha <= 1 raises ValueError."""
        with pytest.raises(ValueError, match="alpha must be > 1"):
            compute_epsilon_single_group(
                divergences=sample_divergences,
                alpha=1.0,
                delta=delta
            )

        with pytest.raises(ValueError, match="alpha must be > 1"):
            compute_epsilon_single_group(
                divergences=sample_divergences,
                alpha=0.5,
                delta=delta
            )

    def test_delta_validation(self, sample_divergences, alpha):
        """Test that invalid delta raises ValueError."""
        with pytest.raises(ValueError, match="delta must be in"):
            compute_epsilon_single_group(
                divergences=sample_divergences,
                alpha=alpha,
                delta=0.0
            )

        with pytest.raises(ValueError, match="delta must be in"):
            compute_epsilon_single_group(
                divergences=sample_divergences,
                alpha=alpha,
                delta=1.0
            )

    def test_without_beta(self, sample_divergences, alpha, delta):
        """Test that theoretical is not computed when beta is not provided."""
        result = compute_epsilon_single_group(
            divergences=sample_divergences,
            alpha=alpha,
            delta=delta,
            beta=None
        )

        assert "empirical" in result
        assert "T" in result
        assert "theoretical" not in result

    def test_higher_divergences_higher_epsilon(self, alpha, delta, beta):
        """Test that higher divergences lead to higher epsilon."""
        low_div = [0.01, 0.02, 0.01]
        high_div = [0.1, 0.2, 0.1]

        result_low = compute_epsilon_single_group(low_div, alpha, delta, beta)
        result_high = compute_epsilon_single_group(high_div, alpha, delta, beta)

        assert result_low["empirical"] < result_high["empirical"]

    def test_more_tokens_higher_epsilon(self, alpha, delta, beta):
        """Test that more tokens lead to higher epsilon."""
        short_div = [0.05] * 10
        long_div = [0.05] * 100

        result_short = compute_epsilon_single_group(short_div, alpha, delta, beta)
        result_long = compute_epsilon_single_group(long_div, alpha, delta, beta)

        assert result_short["empirical"] < result_long["empirical"]


class TestComputeDpEpsilon:
    """Tests for compute_dp_epsilon function (multi-group)."""

    def test_single_group_global(self, alpha, delta):
        """Test global mode with a single group."""
        divergences = {"GROUP1": [0.05, 0.06, 0.04]}

        epsilon = compute_dp_epsilon(
            divergences=divergences,
            alpha=alpha,
            delta=delta,
            mode="global"
        )

        assert isinstance(epsilon, float)
        assert epsilon > 0

    def test_multi_group_global(self, alpha, delta):
        """Test global mode with multiple groups."""
        divergences = {
            "GROUP1": [0.05, 0.06, 0.04],
            "GROUP2": [0.03, 0.08, 0.05]
        }

        epsilon = compute_dp_epsilon(
            divergences=divergences,
            alpha=alpha,
            delta=delta,
            mode="global"
        )

        assert isinstance(epsilon, float)
        assert epsilon > 0

    def test_per_group_mode(self, alpha, delta):
        """Test that per_group mode returns a dict."""
        divergences = {
            "GROUP1": [0.05, 0.06, 0.04],
            "GROUP2": [0.03, 0.08, 0.05]
        }

        epsilons = compute_dp_epsilon(
            divergences=divergences,
            alpha=alpha,
            delta=delta,
            mode="per_group"
        )

        assert isinstance(epsilons, dict)
        assert "GROUP1" in epsilons
        assert "GROUP2" in epsilons
        assert epsilons["GROUP1"] > 0
        assert epsilons["GROUP2"] > 0

    def test_public_group_ignored(self, alpha, delta):
        """Test that the PUBLIC group is ignored."""
        divergences = {
            "PUBLIC": [0.99, 0.99, 0.99],  # Should be ignored
            "PRIVATE": [0.05, 0.06, 0.04]
        }

        epsilon = compute_dp_epsilon(
            divergences=divergences,
            alpha=alpha,
            delta=delta,
            mode="global"
        )

        # Should compute based on PRIVATE only
        assert isinstance(epsilon, float)
        assert epsilon > 0

    def test_no_private_groups_error(self, alpha, delta):
        """Test error when only the PUBLIC group exists."""
        divergences = {"PUBLIC": [0.05, 0.06, 0.04]}

        with pytest.raises(ValueError, match="No private groups"):
            compute_dp_epsilon(
                divergences=divergences,
                alpha=alpha,
                delta=delta
            )

    def test_unequal_lengths_error(self, alpha, delta):
        """Test error when groups have different lengths."""
        divergences = {
            "GROUP1": [0.05, 0.06, 0.04],
            "GROUP2": [0.03, 0.08]  # Different length
        }

        with pytest.raises(ValueError, match="unequal lengths"):
            compute_dp_epsilon(
                divergences=divergences,
                alpha=alpha,
                delta=delta
            )

    def test_invalid_mode_error(self, alpha, delta):
        """Test error for an invalid mode."""
        divergences = {"GROUP1": [0.05, 0.06, 0.04]}

        with pytest.raises(ValueError, match="mode must be"):
            compute_dp_epsilon(
                divergences=divergences,
                alpha=alpha,
                delta=delta,
                mode="invalid"
            )
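The accounting these tests exercise can be checked by hand. A minimal sketch, assuming the library sums per-token Rényi divergences (sequential composition) and converts via the standard RDP-to-(ε, δ) bound, which is what the `test_empty_divergences` expectation implies; this is not the library's own code:

```python
import math

def epsilon_from_divergences(divergences, alpha, delta):
    # Sum per-token Renyi divergences (sequential composition), then
    # convert the RDP guarantee to (epsilon, delta)-DP.
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    if not (0.0 < delta < 1.0):
        raise ValueError("delta must be in (0, 1)")
    return sum(divergences) + math.log(1.0 / delta) / (alpha - 1.0)

eps = epsilon_from_divergences([0.05] * 10, alpha=2.0, delta=1e-5)
```

Under this form, epsilon grows linearly in both the per-token divergences and the number of generated tokens, matching the monotonicity tests above.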
93
tests/test_tagger.py
Normal file
@@ -0,0 +1,93 @@
"""Tests for the Tagger class and phrase extraction."""

import pytest

from dp_fusion_lib import find_phrase_offsets


class TestFindPhraseOffsets:
    """Tests for find_phrase_offsets function."""

    def test_single_phrase(self):
        """Test finding a single phrase."""
        text = "My name is John Smith and I live here."
        phrases = ["John Smith"]

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 1
        assert text[offsets[0][0]:offsets[0][1]] == "John Smith"

    def test_multiple_phrases(self):
        """Test finding multiple different phrases."""
        text = "John Smith visited New York on Monday."
        phrases = ["John Smith", "New York", "Monday"]

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 3

        found_texts = [text[o[0]:o[1]] for o in offsets]
        assert "John Smith" in found_texts
        assert "New York" in found_texts
        assert "Monday" in found_texts

    def test_repeated_phrase(self):
        """Test finding a phrase that appears multiple times."""
        text = "John met John at John's house."
        phrases = ["John"]

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 3
        for offset in offsets:
            assert text[offset[0]:offset[1]] == "John"

    def test_no_match(self):
        """Test when the phrase is not found."""
        text = "Hello world."
        phrases = ["John Smith"]

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 0

    def test_empty_phrases(self):
        """Test with an empty phrase list."""
        text = "Some text here."
        phrases = []

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 0

    def test_overlapping_matches(self):
        """Test phrases that overlap in the text."""
        text = "New York City is great."
        phrases = ["New York", "York City"]

        offsets = find_phrase_offsets(text, phrases)

        # Both should be found
        assert len(offsets) == 2

    def test_offset_values(self):
        """Test that offset values are correct."""
        text = "Hello John!"
        phrases = ["John"]

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 1
        assert offsets[0][0] == 6  # "John" starts at index 6
        assert offsets[0][1] == 10  # "John" ends at index 10

    def test_case_sensitive(self):
        """Test that matching is case-sensitive."""
        text = "John and john are different."
        phrases = ["John"]

        offsets = find_phrase_offsets(text, phrases)

        assert len(offsets) == 1
        assert text[offsets[0][0]:offsets[0][1]] == "John"
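A regex-based sketch consistent with the behavior these tests pin down (case-sensitive, every occurrence, phrases matched independently so overlaps are allowed); the library's actual `find_phrase_offsets` may be implemented differently:

```python
import re

def find_phrase_offsets_sketch(text, phrases):
    # Return (start, end) character offsets for every occurrence of every phrase.
    # Each phrase is escaped and searched independently, so matches from
    # different phrases may overlap.
    offsets = []
    for phrase in phrases:
        for m in re.finditer(re.escape(phrase), text):
            offsets.append((m.start(), m.end()))
    return offsets

offs = find_phrase_offsets_sketch("Hello John!", ["John"])
# → [(6, 10)]
```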
128
tests/test_utils.py
Normal file
@@ -0,0 +1,128 @@
"""Tests for utility functions."""

import pytest
import torch

from dp_fusion_lib import (
    compute_renyi_divergence_clipped_symmetric,
    find_lambda,
)


class TestRenyiDivergence:
    """Tests for Renyi divergence computation."""

    def test_identical_distributions(self):
        """Test that identical distributions have zero divergence."""
        p = torch.tensor([0.25, 0.25, 0.25, 0.25])
        q = torch.tensor([0.25, 0.25, 0.25, 0.25])

        div = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)

        assert div.item() < 1e-6

    def test_symmetric(self):
        """Test that the symmetric divergence gives the same result for (p, q) and (q, p)."""
        p = torch.tensor([0.4, 0.3, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.3, 0.4])

        div_pq = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)
        div_qp = compute_renyi_divergence_clipped_symmetric(q, p, alpha=2.0)

        # Symmetric divergence should be the same
        assert abs(div_pq.item() - div_qp.item()) < 1e-6

    def test_positive_divergence(self):
        """Test that divergence is non-negative."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.7])

        div = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)

        assert div.item() >= 0

    def test_higher_alpha_different_result(self):
        """Test that different alpha values give different results."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.3, 0.4, 0.3])

        div_2 = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)
        div_5 = compute_renyi_divergence_clipped_symmetric(p, q, alpha=5.0)

        # Results should be different for different alpha
        assert abs(div_2.item() - div_5.item()) > 1e-6

    def test_alpha_validation(self):
        """Test that alpha <= 1 raises an error."""
        p = torch.tensor([0.5, 0.5])
        q = torch.tensor([0.5, 0.5])

        with pytest.raises(ValueError, match="alpha must be > 1"):
            compute_renyi_divergence_clipped_symmetric(p, q, alpha=1.0)

        with pytest.raises(ValueError, match="alpha must be > 1"):
            compute_renyi_divergence_clipped_symmetric(p, q, alpha=0.5)

    def test_batch_computation(self):
        """Test that batch computation works."""
        p = torch.tensor([[0.5, 0.5], [0.7, 0.3]])
        q = torch.tensor([[0.5, 0.5], [0.3, 0.7]])

        div = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)

        assert div.shape == (2,)
        assert div[0].item() < 1e-6  # Identical
        assert div[1].item() > 0  # Different


class TestFindLambda:
    """Tests for the lambda search function."""

    def test_identical_distributions_lambda_1(self):
        """Test that identical distributions give lambda=1."""
        p = torch.tensor([0.25, 0.25, 0.25, 0.25])
        q = torch.tensor([0.25, 0.25, 0.25, 0.25])

        lam, div = find_lambda(p, q, alpha=2.0, beta=0.1)

        assert lam == 1.0
        assert div < 1e-6

    def test_beta_zero_lambda_zero(self):
        """Test that beta=0 gives lambda=0."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.7])

        lam, div = find_lambda(p, q, alpha=2.0, beta=0.0)

        assert lam == 0.0
        assert div == 0.0

    def test_lambda_in_range(self):
        """Test that lambda is in [0, 1]."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.7])

        lam, div = find_lambda(p, q, alpha=2.0, beta=0.5)

        assert 0.0 <= lam <= 1.0

    def test_divergence_respects_bound(self):
        """Test that the returned divergence is <= beta."""
        p = torch.tensor([0.6, 0.3, 0.1])
        q = torch.tensor([0.2, 0.3, 0.5])
        beta = 0.3

        lam, div = find_lambda(p, q, alpha=2.0, beta=beta)

        assert div <= beta + 1e-6  # Allow small numerical error

    def test_higher_beta_higher_lambda(self):
        """Test that a higher beta allows a higher lambda."""
        p = torch.tensor([0.8, 0.15, 0.05])
        q = torch.tensor([0.1, 0.2, 0.7])

        lam_low, _ = find_lambda(p, q, alpha=2.0, beta=0.1)
        lam_high, _ = find_lambda(p, q, alpha=2.0, beta=0.5)

        assert lam_low <= lam_high
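One way to realize the properties tested above: clamp and renormalize both distributions for numerical stability, compute the Rényi divergence in both directions, and return the larger one so the result is symmetric. A sketch only, with an assumed clip value; the library's internals may differ:

```python
import torch

def renyi_div_symmetric(p, q, alpha=2.0, clip=1e-6):
    # Symmetrized, clipped Renyi divergence: clamp tiny probabilities,
    # renormalize, evaluate both directions, return the maximum.
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    p = p.clamp_min(clip) / p.clamp_min(clip).sum(dim=-1, keepdim=True)
    q = q.clamp_min(clip) / q.clamp_min(clip).sum(dim=-1, keepdim=True)
    d_pq = torch.log((p.pow(alpha) * q.pow(1.0 - alpha)).sum(dim=-1)) / (alpha - 1.0)
    d_qp = torch.log((q.pow(alpha) * p.pow(1.0 - alpha)).sum(dim=-1)) / (alpha - 1.0)
    return torch.maximum(d_pq, d_qp)
```

Because the computation reduces over the last dimension only, the same function handles both single distributions and batches, matching `test_batch_computation`.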