mirror of https://github.com/rushil-thareja/dp-fusion-lib.git
synced 2026-04-26 12:26:22 +02:00

Initial release v0.1.0

- Token-level differential privacy for LLMs
- Integration with Document Privacy API
- Comprehensive test suite and documentation
- Examples and Jupyter notebook included

This commit is contained in:
commit d012046d85

31 changed files with 4480 additions and 0 deletions
85  src/dp_fusion_lib/__init__.py  Normal file

@@ -0,0 +1,85 @@
"""
DP-Fusion-Lib: Token-Level Differentially Private Inference for LLMs

Generate text with formal (epsilon, delta)-differential privacy guarantees
using distribution fusion techniques.

This library implements the DP-Fusion algorithm from:

    Thareja et al. "DP-Fusion: Token-Level Differentially Private
    Inference for Large Language Models" (arXiv:2507.04531)

Quick Start:
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
    >>> from dp_fusion_lib import DPFusion, compute_epsilon_single_group
    >>>
    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    >>>
    >>> dpf = DPFusion(model=model, tokenizer=tokenizer)
    >>> dpf.add_message("system", "You are a helpful assistant.", is_private=False)
    >>> dpf.add_message("user", "My SSN is 123-45-6789. Summarize my info.", is_private=True)
    >>>
    >>> output = dpf.generate(alpha=2.0, beta=0.1, max_new_tokens=100)
    >>> print(output["text"])
    >>>
    >>> # Compute privacy guarantee
    >>> eps = compute_epsilon_single_group(
    ...     divergences=output["divergences"]["PRIVATE"],
    ...     alpha=2.0,
    ...     delta=1e-5,
    ...     beta=0.1
    ... )
    >>> print(f"Privacy: epsilon={eps['empirical']:.2f} at delta=1e-5")
"""

try:
    import torch
except ImportError as e:
    raise ImportError(
        "PyTorch is required but not installed. Install it first:\n"
        " pip install torch\n"
        " or with CUDA: pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
        " or visit https://pytorch.org/get-started/locally/"
    ) from e

# Core classes and functions
from dp_fusion_lib.core import DPFusion, generate_dp_text
from dp_fusion_lib.tagger import Tagger, find_phrase_offsets
from dp_fusion_lib.epsilon import compute_epsilon_single_group, compute_dp_epsilon
from dp_fusion_lib._version import __version__

# Utility functions (advanced usage)
from dp_fusion_lib.utils import (
    compute_renyi_divergence_clipped_symmetric,
    find_lambda,
    replace_sequences_with_placeholder_fast,
    dp_fusion_groups_incremental,
    format_prompt_new_template,
    DEFAULT_BETA_DICT,
    ENTITY_TYPES,
    PLACEHOLDER_TOKEN,
)

__all__ = [
    # Main API
    "DPFusion",
    "Tagger",
    "generate_dp_text",
    # Epsilon computation
    "compute_epsilon_single_group",
    "compute_dp_epsilon",
    # Utility functions (advanced)
    "find_phrase_offsets",
    "compute_renyi_divergence_clipped_symmetric",
    "find_lambda",
    "replace_sequences_with_placeholder_fast",
    "dp_fusion_groups_incremental",
    "format_prompt_new_template",
    # Constants
    "DEFAULT_BETA_DICT",
    "ENTITY_TYPES",
    "PLACEHOLDER_TOKEN",
    # Version
    "__version__",
]
3  src/dp_fusion_lib/_version.py  Normal file

@@ -0,0 +1,3 @@
"""Version information for dp-fusion-lib."""

__version__ = "0.1.0"
482  src/dp_fusion_lib/core.py  Normal file

@@ -0,0 +1,482 @@
"""
Core DP-Fusion generation module.

This module provides the main DPFusion class and convenience functions
for differentially private text generation using distribution fusion.

Theory:
    DP-Fusion mixes token probability distributions from:
    1. Private context: Full sensitive document
    2. Public context: Redacted version with placeholders

    The mixing is controlled via λ to bound the Rényi divergence,
    providing formal (ε, δ)-differential privacy guarantees.

Reference:
    Thareja et al. "DP-Fusion: Token-Level Differentially Private
    Inference for Large Language Models" (arXiv:2507.04531)
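
Illustration (a minimal sketch of one fusion step; the probabilities and the
λ value below are made up for exposition, not drawn from the paper):

    >>> import torch
    >>> p_priv = torch.tensor([0.7, 0.2, 0.1])  # next-token dist., private context
    >>> p_pub = torch.tensor([0.4, 0.4, 0.2])   # next-token dist., redacted context
    >>> lam = 0.3  # in practice chosen by binary search against the divergence bound
    >>> p_out = lam * p_priv + (1 - lam) * p_pub  # fused distribution for sampling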
"""

from typing import Dict, List, Optional, Union

import torch

from dp_fusion_lib.tagger import Tagger, find_phrase_offsets
from dp_fusion_lib.utils import (
    dp_fusion_groups_incremental,
    format_prompt_new_template,
    replace_sequences_with_placeholder_fast,
)


class DPFusion:
    """
    DP-Fusion wrapper for differentially private text generation.

    This class provides a clean API for mixing private and public distributions
    to generate text with differential privacy guarantees.

    The workflow supports two modes:
    1. **Message-based**: Build context with `add_message()`, run `run_tagger()`
       for automatic phrase extraction, then `generate()`.
    2. **Direct context**: Pass `private_context` and `public_context` directly
       to `generate()`.

    Example (Message-based with Tagger):
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> from dp_fusion_lib import DPFusion, Tagger
        >>>
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>> tagger = Tagger(api_key="sk_...")
        >>>
        >>> dpf = DPFusion(model=model, tokenizer=tokenizer, tagger=tagger)
        >>> dpf.add_message("system", "You are a helpful assistant.", is_private=False)
        >>> dpf.add_message("user", "My SSN is 123-45-6789.", is_private=True)
        >>> dpf.run_tagger()
        >>> output = dpf.generate(alpha=2.0, beta=0.1)
        >>> print(output["text"])

    Example (Direct context):
        >>> dpf = DPFusion(model=model, tokenizer=tokenizer)
        >>> output = dpf.generate(
        ...     private_context="John Doe's SSN is 123-45-6789.",
        ...     public_context="_'s SSN is _.",
        ...     alpha=2.0,
        ...     beta=0.1
        ... )

    Args:
        model: A HuggingFace CausalLM model (on any device)
        tokenizer: Corresponding HuggingFace tokenizer
        max_tokens: Maximum number of tokens to generate (default: 100)
        placeholder: Placeholder character for redacted content (default: "_")
        tagger: Optional Tagger instance for automatic phrase extraction
    """
    def __init__(
        self,
        model,
        tokenizer,
        max_tokens: int = 100,
        placeholder: str = "_",
        tagger: Optional[Tagger] = None
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.placeholder = placeholder
        self.tagger = tagger

        # Auto-detect device from model parameters
        self.device = next(model.parameters()).device

        # Ensure tokenizer has a pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Message storage for building context
        self._messages: List[Dict] = []

        # Cached contexts (populated by run_tagger)
        self._private_context: Optional[str] = None
        self._public_context: Optional[str] = None
        self._private_tokens: Optional[torch.Tensor] = None
        self._public_tokens: Optional[torch.Tensor] = None

    def add_message(self, role: str, content: str, is_private: bool = False):
        """
        Add a message to the conversation context.

        Args:
            role: Message role - "system", "user", or "assistant"
            content: The message text
            is_private: If True, content is sensitive and will be redacted
                in the public context
        """
        self._messages.append({
            "role": role,
            "content": content,
            "is_private": is_private
        })

    def clear_messages(self):
        """Clear all stored messages and cached contexts."""
        self._messages = []
        self._private_context = None
        self._public_context = None
        self._private_tokens = None
        self._public_tokens = None
    def run_tagger(self):
        """
        Run the tagger on all private messages to extract and redact private phrases.

        This method calls the privacy API to identify sensitive phrases in messages
        marked as private, then builds both private and public contexts with
        fine-grained redaction at the token level to ensure alignment.

        Must be called before generate() if using fine-grained redaction.
        Populates self._private_context, self._public_context, and token versions.

        Raises:
            ValueError: If no tagger is configured or no messages added
            requests.RequestException: If the API call fails
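
        Example (sketch; assumes a configured tagger and messages already added):
            >>> dpf.run_tagger()
            >>> print(dpf.public_context)  # private phrases replaced by placeholders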
        """
        if self.tagger is None:
            raise ValueError("No tagger configured. Pass tagger to DPFusion.__init__")

        if not self._messages:
            raise ValueError("No messages added. Use add_message() first.")

        # Collect all private phrases from private messages
        all_phrases = []
        for msg in self._messages:
            if msg["is_private"]:
                phrases = self.tagger.extract_private_phrases(msg["content"])
                all_phrases.extend(phrases)

        # Build the full private prompt text
        private_msgs = [{"role": msg["role"], "content": msg["content"]} for msg in self._messages]
        self._private_context = self.tokenizer.apply_chat_template(
            private_msgs, tokenize=False, add_generation_prompt=True
        )

        # Tokenize the full private context
        self._private_tokens = self.tokenizer.encode(self._private_context, return_tensors="pt")[0]

        if all_phrases:
            # Find phrase offsets in the FULL prompt text
            offsets = find_phrase_offsets(self._private_context, all_phrases)

            # Get public tokens directly - SAME LENGTH as private tokens!
            public_token_ids = replace_sequences_with_placeholder_fast(
                self._private_context, offsets, self.placeholder, self.tokenizer
            )
            self._public_tokens = torch.tensor(public_token_ids)

            # Decode for display purposes only
            self._public_context = self.tokenizer.decode(self._public_tokens, skip_special_tokens=False)
        else:
            # No private phrases found, so public = private
            self._public_tokens = self._private_tokens.clone()
            self._public_context = self._private_context
    @property
    def private_context(self) -> str:
        """
        Get the private context (full text with no redaction).

        Call run_tagger() first to populate this property.

        Returns:
            Formatted prompt string with full private content

        Raises:
            ValueError: If run_tagger() hasn't been called
        """
        if self._private_context is None:
            raise ValueError("No context available. Call run_tagger() first.")
        return self._private_context

    @property
    def public_context(self) -> str:
        """
        Get the public context (text with private phrases redacted).

        Call run_tagger() first to populate this property.

        Returns:
            Formatted prompt string with redacted content

        Raises:
            ValueError: If run_tagger() hasn't been called
        """
        if self._public_context is None:
            raise ValueError("No context available. Call run_tagger() first.")
        return self._public_context

    def _build_contexts(self):
        """
        Build private and public contexts from stored messages.

        This is used when run_tagger() hasn't been called, providing
        a simple full-message redaction fallback.

        Returns:
            Tuple of (private_messages, public_messages) for apply_chat_template.
        """
        private_msgs = []
        public_msgs = []

        for msg in self._messages:
            private_msgs.append({"role": msg["role"], "content": msg["content"]})
            if msg["is_private"]:
                # Redact entire content with placeholder
                public_msgs.append({"role": msg["role"], "content": self.placeholder})
            else:
                public_msgs.append({"role": msg["role"], "content": msg["content"]})

        return private_msgs, public_msgs

    def get_context_text(self) -> str:
        """
        Get formatted context text using the tokenizer's chat template.

        Returns:
            Formatted prompt string with special tokens
        """
        msgs = [{"role": msg["role"], "content": msg["content"]} for msg in self._messages]

        return self.tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True
        )
    def generate(
        self,
        private_context: Optional[str] = None,
        public_context: Optional[str] = None,
        alpha: float = 2.0,
        beta: float = 0.5,
        temperature: float = 1.0,
        max_new_tokens: Optional[int] = None,
        debug: bool = False
    ) -> Dict[str, Union[str, dict]]:
        """
        Generate text using DP-Fusion mixing of private and public distributions.

        Can be called in two ways:
        1. **With stored messages** (via add_message): `generate(alpha=2.0, beta=0.5)`
        2. **With explicit contexts**: `generate(private_context="...", public_context="...")`

        Args:
            private_context: The full sensitive document text (optional if using messages)
            public_context: The redacted document text (optional if using messages)
            alpha: Rényi divergence order, must be > 1 (default: 2.0)
            beta: Divergence threshold - lower means more privacy (default: 0.5).
                The internal bound is alpha * beta, following the paper's notation.
            temperature: Softmax temperature for sampling (default: 1.0)
            max_new_tokens: Override max tokens for this call (optional)
            debug: Enable debug printing (default: False)

        Returns:
            dict with keys:
            - "text": Generated text (str)
            - "lambdas": Per-step lambda values per group (dict)
            - "divergences": Per-step divergence values per group (dict)

        Raises:
            ValueError: If no context is available (neither messages nor explicit contexts)
        """
        if private_context is None and public_context is None:
            # Check if run_tagger() was called - use the pre-computed tokens directly
            if self._private_tokens is not None:
                private_tokens = self._private_tokens
                public_tokens = self._public_tokens
            else:
                # Use stored messages with the default _build_contexts behavior
                if not self._messages:
                    raise ValueError(
                        "No messages added. Use add_message() or provide "
                        "private_context/public_context."
                    )

                private_msgs, public_msgs = self._build_contexts()

                private_prompt = self.tokenizer.apply_chat_template(
                    private_msgs,
                    tokenize=False,
                    add_generation_prompt=True
                )
                public_prompt = self.tokenizer.apply_chat_template(
                    public_msgs,
                    tokenize=False,
                    add_generation_prompt=True
                )
                private_tokens = self.tokenizer.encode(private_prompt, return_tensors="pt")[0]
                public_tokens = self.tokenizer.encode(public_prompt, return_tensors="pt")[0]
        else:
            # Use the provided contexts
            private_prompt = format_prompt_new_template(
                self.tokenizer,
                private_context,
                self.placeholder
            )
            public_prompt = format_prompt_new_template(
                self.tokenizer,
                public_context,
                self.placeholder
            )
            private_tokens = self.tokenizer.encode(private_prompt, return_tensors="pt")[0]
            public_tokens = self.tokenizer.encode(public_prompt, return_tensors="pt")[0]

        # Create the token groups dict:
        # "PUBLIC" is the redacted version, "PRIVATE" is the full sensitive version
        token_ids_groups = {
            "PUBLIC": public_tokens,
            "PRIVATE": private_tokens
        }

        # Beta dict for the private group.
        # Paper notation: D_alpha <= alpha * beta, so the internal bound is alpha * beta
        internal_beta = alpha * beta
        beta_dict = {"PRIVATE": internal_beta}

        # Determine max tokens (use an explicit None check so 0 is a valid override)
        tokens_to_generate = max_new_tokens if max_new_tokens is not None else self.max_tokens

        # Run DP-Fusion generation
        generated_text, lambdas, divergences = dp_fusion_groups_incremental(
            token_ids_groups=token_ids_groups,
            beta_dict=beta_dict,
            alpha=alpha,
            model=self.model,
            tokenizer=self.tokenizer,
            temperature=temperature,
            max_new_tokens=tokens_to_generate,
            debug_mode=debug
        )

        return {
            "text": generated_text,
            "lambdas": lambdas,
            "divergences": divergences
        }
    def generate_from_tokens(
        self,
        private_tokens: torch.Tensor,
        public_tokens: torch.Tensor,
        alpha: float = 2.0,
        beta: float = 0.5,
        temperature: float = 1.0,
        max_new_tokens: Optional[int] = None,
        debug: bool = False
    ) -> Dict[str, Union[str, dict]]:
        """
        Generate text from pre-tokenized inputs.

        This is useful when you want more control over tokenization
        or are processing batches.

        Args:
            private_tokens: Token IDs for the private context (1D tensor)
            public_tokens: Token IDs for the public/redacted context (1D tensor)
            alpha: Rényi divergence order (default: 2.0)
            beta: Divergence threshold (default: 0.5)
            temperature: Softmax temperature (default: 1.0)
            max_new_tokens: Override max tokens (optional)
            debug: Enable debug printing (default: False)

        Returns:
            dict: Same as generate()
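
        Example (sketch; assumes `dpf` and `tokenizer` from the class example above):
            >>> priv = tokenizer.encode("John Doe's SSN is 123-45-6789.", return_tensors="pt")[0]
            >>> pub = tokenizer.encode("_'s SSN is _.", return_tensors="pt")[0]
            >>> out = dpf.generate_from_tokens(priv, pub, alpha=2.0, beta=0.1)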
        """
        token_ids_groups = {
            "PUBLIC": public_tokens,
            "PRIVATE": private_tokens
        }

        # Paper notation: D_alpha <= alpha * beta, so the internal bound is alpha * beta
        internal_beta = alpha * beta
        beta_dict = {"PRIVATE": internal_beta}

        tokens_to_generate = max_new_tokens if max_new_tokens is not None else self.max_tokens

        generated_text, lambdas, divergences = dp_fusion_groups_incremental(
            token_ids_groups=token_ids_groups,
            beta_dict=beta_dict,
            alpha=alpha,
            model=self.model,
            tokenizer=self.tokenizer,
            temperature=temperature,
            max_new_tokens=tokens_to_generate,
            debug_mode=debug
        )

        return {
            "text": generated_text,
            "lambdas": lambdas,
            "divergences": divergences
        }
def generate_dp_text(
    model,
    tokenizer,
    private_context: str,
    public_context: str,
    alpha: float = 2.0,
    beta: float = 0.5,
    temperature: float = 1.0,
    max_new_tokens: int = 100,
    debug: bool = False
) -> Dict[str, Union[str, dict]]:
    """
    Convenience function for one-off DP-Fusion generation.

    This is a shortcut that creates a temporary DPFusion instance
    and generates text in one call.

    Args:
        model: HuggingFace CausalLM model
        tokenizer: HuggingFace tokenizer
        private_context: Full sensitive document text
        public_context: Redacted document text with placeholders
        alpha: Rényi divergence order (default: 2.0)
        beta: Divergence threshold - paper notation where the bound = alpha * beta (default: 0.5)
        temperature: Softmax temperature (default: 1.0)
        max_new_tokens: Max tokens to generate (default: 100)
        debug: Enable debug printing (default: False)

    Returns:
        dict: {"text": str, "lambdas": dict, "divergences": dict}

    Example:
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> from dp_fusion_lib import generate_dp_text
        >>>
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>>
        >>> output = generate_dp_text(
        ...     model=model,
        ...     tokenizer=tokenizer,
        ...     private_context="John Doe's SSN is 123-45-6789.",
        ...     public_context="_'s SSN is _.",
        ...     alpha=2.0,
        ...     beta=0.1
        ... )
        >>> print(output["text"])
    """
    dpf = DPFusion(model=model, tokenizer=tokenizer)
    return dpf.generate(
        private_context=private_context,
        public_context=public_context,
        alpha=alpha,
        beta=beta,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        debug=debug
    )
143  src/dp_fusion_lib/epsilon.py  Normal file

@@ -0,0 +1,143 @@
"""
Epsilon computation for differential privacy guarantees.

This module provides functions to compute (ε, δ)-DP guarantees from
per-step Rényi divergences, following the theory in:

    Thareja et al. "DP-Fusion: Token-Level Differentially Private
    Inference for Large Language Models" (arXiv:2507.04531)
"""

import math
from typing import Dict, List, Optional, Union


def compute_epsilon_single_group(
    divergences: List[float],
    alpha: float,
    delta: float,
    beta: Optional[float] = None
) -> Dict[str, float]:
    """
    Compute the (ε, δ)-DP guarantee for a single private group.

    For a single group (N=1), the per-step RDP formula simplifies to:
        eps_step = 4 * β_t

    where β_t = divergence_t / α (paper notation).

    Total epsilon:
        ε = (4/α) * Σ(divergences) + log(1/δ)/(α-1)

    Args:
        divergences: List of per-step D_α values (bounded by α·β internally).
        alpha: Rényi order (>1).
        delta: Target δ in (ε, δ)-DP.
        beta: Paper's β (where the internal bound = α·β). If provided,
            the theoretical epsilon is computed as well.

    Returns:
        Dict with:
        - "empirical": ε computed from the actual divergences
        - "theoretical": ε assuming divergence = α·β at each step (if beta provided)
        - "T": number of tokens generated
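
    Example (illustrative numbers only):
        >>> eps = compute_epsilon_single_group([0.2] * 100, alpha=2.0, delta=1e-5, beta=0.1)
        >>> round(eps["empirical"], 2)  # (4/2) * 20 + ln(1e5) / (2 - 1)
        51.51
        >>> eps["T"]
        100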
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    if delta <= 0.0 or delta >= 1.0:
        raise ValueError("delta must be in (0, 1)")

    T = len(divergences)
    log_delta_term = math.log(1.0 / delta) / (alpha - 1.0)

    # Empirical: divergences are bounded by α·β, so β_t = d/α
    # and eps_t = 4 * β_t = 4 * (d / α)
    empirical_rdp = sum(4.0 * (d / alpha) for d in divergences)
    epsilon_empirical = empirical_rdp + log_delta_term

    result = {
        "empirical": epsilon_empirical,
        "T": T
    }

    # Theoretical: the worst case is divergence = α·β at each step,
    # i.e. β_t = β, so eps_t = 4 * β
    if beta is not None:
        theoretical_rdp = T * 4.0 * beta
        epsilon_theoretical = theoretical_rdp + log_delta_term
        result["theoretical"] = epsilon_theoretical

    return result
def compute_dp_epsilon(
    divergences: Dict[str, List[float]],
    alpha: float,
    delta: float,
    mode: str = "global"
) -> Union[float, Dict[str, float]]:
    """
    Compute the (ε, δ)-DP guarantee from per-step Rényi divergences.

    Supports multi-group privacy with either global or per-group guarantees.

    Args:
        divergences: Mapping group_name -> list of β_t values (length T).
            The key "PUBLIC" (if present) will be ignored.
        alpha: Rényi order (>1).
        delta: Target δ in (ε, δ)-DP.
        mode: "global" for one ε protecting all groups (worst case per step),
            "per_group" for a dict of ε_i per group.

    Returns:
        If mode == "global": float ε.
        If mode == "per_group": dict of {group: ε_i}.
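
    Example (illustrative numbers; a single private group over 50 steps):
        >>> round(compute_dp_epsilon({"PRIVATE": [0.1] * 50}, alpha=2.0, delta=1e-5), 2)
        31.51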
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    if delta <= 0.0 or delta >= 1.0:
        raise ValueError("delta must be in (0, 1)")

    # Filter out PUBLIC and ensure at least one private group
    priv_div = {g: lst for g, lst in divergences.items() if g != "PUBLIC"}
    if not priv_div:
        raise ValueError("No private groups provided")

    # Ensure all groups have the same number of steps
    step_counts = {len(lst) for lst in priv_div.values()}
    if len(step_counts) != 1:
        raise ValueError(f"Divergence lists have unequal lengths: {step_counts}")

    T = step_counts.pop()
    N = len(priv_div)

    def eps_step(beta: float) -> float:
        """Compute the per-step RDP cost."""
        if beta is None:
            raise ValueError("Found None in divergence list")
        arg = (N - 1.0) / N + (1.0 / N) * math.exp((alpha - 1.0) * 4.0 * beta)
        if arg <= 0.0:
            raise ValueError(f"Non-positive argument for log: {arg}")
        return (1.0 / (alpha - 1.0)) * math.log(arg)

    if mode == "global":
        total_rdp = 0.0
        for t in range(T):
            betas = [div_list[t] for div_list in priv_div.values()]
            beta_max = max(betas)
            total_rdp += eps_step(beta_max)
        epsilon = total_rdp + math.log(1.0 / delta) / (alpha - 1.0)
        return epsilon

    elif mode == "per_group":
        epsilons = {}
        for group, div_list in priv_div.items():
            total_rdp_g = 0.0
            for beta_t in div_list:
                total_rdp_g += eps_step(beta_t)
            eps_group = total_rdp_g + math.log(1.0 / delta) / (alpha - 1.0)
            epsilons[group] = eps_group
        return epsilons

    else:
        raise ValueError("mode must be 'global' or 'per_group'")
172  src/dp_fusion_lib/tagger.py  Normal file

@@ -0,0 +1,172 @@
"""
Private phrase extraction using the Document Privacy API.

This module provides the Tagger class for automatically identifying
sensitive/private phrases in documents using an external API.
"""

from typing import List

import requests


def find_phrase_offsets(text: str, phrases: List[str]) -> List[List[int]]:
    """
    Find all occurrences of phrases in text and return [start, end] offsets.

    Args:
        text: The full text to search in
        phrases: List of phrases to find

    Returns:
        List of [start_char, end_char] offsets for all phrase occurrences
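
    Example:
        >>> find_phrase_offsets("John met John.", ["John"])
        [[0, 4], [9, 13]]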
    """
    offsets = []
    for phrase in phrases:
        start = 0
        while True:
            idx = text.find(phrase, start)
            if idx == -1:
                break
            offsets.append([idx, idx + len(phrase)])
            start = idx + 1
    return offsets


class Tagger:
    """
    Private phrase extraction using the Document Privacy API.

    The Tagger uses an external API to identify sensitive information
    in documents. It supports different extraction models and document
    types (constitutions).

    Example:
        >>> tagger = Tagger(api_key="sk_...")
        >>> tagger.set_model("llama3.1-8b")
        >>> tagger.set_constitution("HEALTH")
        >>> phrases = tagger.extract_private_phrases("John Doe visited on 01/01/1990.")
        >>> print(phrases)
        ['John Doe', '01/01/1990']

    Args:
        api_key: API key for the Document Privacy API
        verbose: If True, log the input/output of API calls (default: False)
    """

    def __init__(self, api_key: str, verbose: bool = False):
        """
        Initialize the Tagger with an API key.

        Args:
            api_key: API key for the Document Privacy API
            verbose: If True, log the input/output of API calls (default: False)
        """
        self.api_key = api_key
        self.api_base = "https://api.documentprivacy.com"
        self._model = "llama3.1-8b"
        self._constitution = "HEALTH"
        self.verbose = verbose
    def set_model(self, model: str):
        """
        Set the extraction model.

        Args:
            model: Model identifier (e.g., 'llama3.1-8b')
        """
        self._model = model

    def set_constitution(self, constitution: str):
        """
        Set the document type/constitution.

        Available constitutions depend on the API. Common options:
        - 'HEALTH': Medical/healthcare documents
        - 'FINANCE': Financial documents
        - 'LEGAL': Legal documents

        Args:
            constitution: Document type identifier
        """
        self._constitution = constitution

    def get_available_models(self) -> List[str]:
        """
        Get the list of available models from the API.

        Returns:
            List of available model identifiers

        Raises:
            requests.RequestException: If the API call fails
        """
        url = f"{self.api_base}/models"
        headers = {
            "X-API-KEY": self.api_key
        }

        if self.verbose:
            print(f"[Tagger] GET {url}")

        # A timeout avoids hanging indefinitely on a stalled connection
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()

        result = response.json()

        if self.verbose:
            print(f"[Tagger] Response: {result}")

        return result

    def extract_private_phrases(self, document: str) -> List[str]:
        """
        Extract private phrases from a document using the API.

        This method sends the document to the Document Privacy API,
        which uses the configured model and constitution to identify
        sensitive information.

        Args:
            document: The text document to analyze

        Returns:
            List of detected private/sensitive phrases

        Raises:
            requests.RequestException: If the API call fails

        Example:
            >>> tagger = Tagger(api_key="sk_...")
            >>> phrases = tagger.extract_private_phrases(
            ...     "Patient John Smith, DOB 05/15/1980, was diagnosed with diabetes."
            ... )
            >>> print(phrases)
            ['John Smith', '05/15/1980', 'diabetes']
        """
        url = f"{self.api_base}/extract"
        headers = {
            "X-API-KEY": self.api_key,
            "Content-Type": "application/json"
        }
        payload = {
            "document": document,
            "model": self._model,
            "type": self._constitution
        }

        if self.verbose:
            print(f"[Tagger] POST {url}")
            print(f"[Tagger] Input document: {document[:200]}{'...' if len(document) > 200 else ''}")
            print(f"[Tagger] Model: {self._model}, Constitution: {self._constitution}")

        # A timeout avoids hanging indefinitely on a stalled connection
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()

        data = response.json()
        private_phrases = data.get("private_phrases", [])

        if self.verbose:
            print(f"[Tagger] Extracted phrases: {private_phrases}")

        return private_phrases
394  src/dp_fusion_lib/utils.py  Normal file

@@ -0,0 +1,394 @@
"""
Utility functions for DP-Fusion.

This module contains the core algorithmic components:
- Rényi divergence computation
- Lambda search for the privacy-utility tradeoff
- Token replacement for redaction
- Incremental DP-Fusion generation
"""

import math
from bisect import bisect_right
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F


# Default beta values for different entity types
DEFAULT_BETA_DICT = {
    "PERSON": 0.5,
    "CODE": 0.5,
    "LOC": 0.5,
    "ORG": 0.5,
    "DEM": 0.5,
    "DATETIME": 0.5,
    "QUANTITY": 0.5,
    "MISC": 0.5,
}

# Available entity types
ENTITY_TYPES = [
    "PERSON", "CODE", "LOC", "ORG", "DEM",
    "DATETIME", "QUANTITY", "MISC"
]

# Default placeholder token for redaction
PLACEHOLDER_TOKEN = "_"


def replace_sequences_with_placeholder_fast(
    text: str,
    word_offsets: List[List[int]],
    placeholder: str,
    tokenizer
) -> List[int]:
    """
    Replace tokens falling within the provided word offset ranges with placeholder tokens.

    Args:
        text: Original text string
        word_offsets: List of [start_char, end_char] offsets for words to replace
        placeholder: Placeholder token to use (e.g., "_")
        tokenizer: Tokenizer that returns 'input_ids' and 'offset_mapping'

    Returns:
        Token IDs with the specified words replaced by the placeholder token ID
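
    Example (sketch; `tokenizer` is assumed to be a HuggingFace fast tokenizer
    that supports offset mapping):
        >>> # the offsets below cover "John Doe" and "123-45-6789" in the text
        >>> ids = replace_sequences_with_placeholder_fast(
        ...     "John Doe's SSN is 123-45-6789.", [[0, 8], [18, 29]], "_", tokenizer)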
    """
    placeholder_id = tokenizer.convert_tokens_to_ids(placeholder)

    encoded = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    input_ids = encoded['input_ids']
    offsets = encoded['offset_mapping']

    word_offsets = sorted(word_offsets, key=lambda x: x[0])
    starts = [wo[0] for wo in word_offsets]
    ends = [wo[1] for wo in word_offsets]

    for i, (t_start, t_end) in enumerate(offsets):
        if t_start == t_end:
            # Zero-width mapping (e.g., special tokens): nothing to redact
            continue

        # Index of the first word span starting at or after this token's end
        idx = bisect_right(starts, t_end)

        # Walk backwards over candidate word spans and test for overlap
        while idx > 0:
            idx -= 1
            w_start, w_end = starts[idx], ends[idx]

            if w_end > t_start and w_start < t_end:
                # The token overlaps a redacted span: replace it with the placeholder
                input_ids[i] = placeholder_id
                break

    return input_ids
def compute_renyi_divergence_clipped_symmetric(
    p: torch.Tensor,
    q: torch.Tensor,
    alpha: float,
    eps: float = 1e-10
) -> torch.Tensor:
    """
    Compute the symmetric Rényi divergence D↔_α(p‖q) = max{D_α(p‖q), D_α(q‖p)}.

    Args:
        p: Probability vector (last dimension is the support)
        q: Probability vector (last dimension is the support)
        alpha: Rényi order (must be > 1)
        eps: Small constant for numerical stability

    Returns:
        D↔_α(p, q) with shape p.shape[:-1]
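
    Example (illustrative numbers only):
        >>> import torch
        >>> p = torch.tensor([0.5, 0.5])
        >>> q = torch.tensor([0.25, 0.75])
        >>> round(compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0).item(), 4)
        0.2877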
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")

    p = p.float().clamp_min(eps)
    q = q.float().clamp_min(eps)

    # Forward direction D_α(p‖q)
    term_pq = torch.sum(p.pow(alpha) * q.pow(1.0 - alpha), dim=-1).clamp_min(eps)
    div_pq = (1.0 / (alpha - 1.0)) * torch.log(term_pq)

    # Reverse direction D_α(q‖p)
    term_qp = torch.sum(q.pow(alpha) * p.pow(1.0 - alpha), dim=-1).clamp_min(eps)
    div_qp = (1.0 / (alpha - 1.0)) * torch.log(term_qp)

    return torch.maximum(div_pq, div_qp)
def find_lambda(
    p_priv: torch.Tensor,
    p_pub: torch.Tensor,
    alpha: float,
    beta: float,
    debug_mode: bool = False,
    max_iter: int = 20,
    tol: float = 1e-6
) -> Tuple[float, float]:
    """
    Binary search for the λ in [0, 1] that satisfies the divergence bound.

    Finds λ such that:
        D_α((1-λ)*p_pub + λ*p_priv || p_pub) <= beta

    Args:
        p_priv: Private distribution (already softmaxed & temperature-scaled)
        p_pub: Public distribution (already softmaxed & temperature-scaled)
        alpha: Rényi order (> 1)
        beta: Divergence threshold (>= 0)
        debug_mode: Whether to print debug information
        max_iter: Maximum binary search iterations
        tol: Tolerance for convergence

    Returns:
        Tuple of (lambda_value, divergence)
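
    Example (degenerate case for intuition: identical distributions need no damping):
        >>> import torch
        >>> p = torch.tensor([0.5, 0.5])
        >>> find_lambda(p, p, alpha=2.0, beta=0.1)[0]
        1.0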
    """
    if beta <= 0:
        return 0.0, 0.0

    # Divergence at λ = 1 (pure private distribution)
    div_at_1 = compute_renyi_divergence_clipped_symmetric(p_priv, p_pub, alpha)

    if div_at_1 <= beta:
        return 1.0, div_at_1.item() if hasattr(div_at_1, 'item') else div_at_1

    left, right = 0.0, 1.0
    for _ in range(max_iter):
        mid = 0.5 * (left + right)
        mixture = mid * p_priv + (1 - mid) * p_pub
        div = compute_renyi_divergence_clipped_symmetric(mixture, p_pub, alpha)

        if div > beta:
            right = mid
        else:
            left = mid

        if (right - left) < tol:
            break

    # `left` is the largest tested λ that still satisfies the bound
    final_lambda = left
    mixture = final_lambda * p_priv + (1 - final_lambda) * p_pub
    final_div = compute_renyi_divergence_clipped_symmetric(mixture, p_pub, alpha)

    return final_lambda, final_div.item() if hasattr(final_div, 'item') else final_div
def dp_fusion_groups_incremental(
    token_ids_groups: Dict[str, torch.Tensor],
    beta_dict: Dict[str, float],
    alpha: float,
    model,
    tokenizer,
    temperature: float = 1.0,
    max_new_tokens: int = 50,
    debug_mode: bool = False,
    device_map=None,
    batch_override=None
) -> Tuple[str, Dict[str, List[float]], Dict[str, List[float]]]:
    """
    DP-Fusion generation with incremental decoding using the KV-cache.

    Supports multi-group privacy where each group can have a different β threshold.

    Args:
        token_ids_groups: Dict mapping group names to token ID tensors.
            Must include a "PUBLIC" key for the redacted version.
        beta_dict: Mapping from group name to β threshold.
        alpha: Rényi divergence order (>1).
        model: HuggingFace CausalLM model.
        tokenizer: Corresponding tokenizer.
        temperature: Temperature for scaling logits.
        max_new_tokens: Maximum tokens to generate.
        debug_mode: Whether to print debug information.
        device_map: Optional device map.
        batch_override: Optional batch settings override.

    Returns:
        Tuple of (generated_text, lambdas_dict, divergences_dict)
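
    Example (sketch; assumes `model`, `tokenizer`, and aligned 1-D token tensors
    `priv_ids` / `pub_ids` prepared by the caller):
        >>> text, lambdas, divs = dp_fusion_groups_incremental(
        ...     {"PUBLIC": pub_ids, "PRIVATE": priv_ids}, {"PRIVATE": 0.2},
        ...     alpha=2.0, model=model, tokenizer=tokenizer, max_new_tokens=50)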
    """
    eos_id = tokenizer.eos_token_id

    going_lambdas: Dict[str, List[float]] = {}
    going_divergence: Dict[str, List[float]] = {}

    if "PUBLIC" not in token_ids_groups:
        raise ValueError("Must have a 'PUBLIC' key in token_ids_groups.")

    private_groups = [g for g in token_ids_groups if g != "PUBLIC"]
    if not private_groups:
        raise ValueError("No private groups besides 'PUBLIC' - need at least one for DP-Fusion.")

    if device_map:
        first_device = next(iter(device_map.values()))
        device = torch.device(f"cuda:{first_device}" if isinstance(first_device, int) else first_device)
    else:
        device = model.device

    for group, tokens in token_ids_groups.items():
        if not isinstance(tokens, torch.Tensor):
            tokens = torch.tensor(tokens, dtype=torch.long)
        token_ids_groups[group] = tokens.to(device)

    if debug_mode:
        print(f"[DP-Fusion] Starting generation. Private groups: {private_groups}")
        for g in token_ids_groups:
            print(f"[Initial] Prefix shape for group {g}: {token_ids_groups[g].shape}")

    group_order = list(token_ids_groups.keys())
    num_groups = len(group_order)

    # Initial pass: process each group's full prefix to build the KV-cache.
    # Note: the prefixes are expected to be token-aligned (same length, as
    # produced by run_tagger); with unequal lengths, right padding would make
    # the last-position logits of shorter sequences come from pad tokens.
    prefix_batches = [token_ids_groups[g] for g in group_order]
    input_batch = torch.nn.utils.rnn.pad_sequence(
        prefix_batches, batch_first=True, padding_value=tokenizer.pad_token_id
    )

    if debug_mode:
        print(f"[Initial] Input batch shape: {input_batch.shape}")

    with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
        outputs = model(input_ids=input_batch, use_cache=True, past_key_values=None)

    past = outputs.past_key_values
    last_logits = outputs.logits[:, input_batch.size(1) - 1, :]
    group_logits = {g: last_logits[i] for i, g in enumerate(group_order)}

    pub_scaled = group_logits["PUBLIC"] / temperature
    p_pub = F.softmax(pub_scaled, dim=-1)

    p_priv_dict = {}
    for pg in private_groups:
        priv_scaled = group_logits[pg] / temperature
        p_priv_dict[pg] = F.softmax(priv_scaled, dim=-1)

    # DP-Fusion: find lambdas and form the fused distribution
    lambdas = {}
    for pg in private_groups:
        beta_val = beta_dict.get(pg)
        lam_pg, got_div = find_lambda(p_priv_dict[pg], p_pub, alpha, beta_val, debug_mode=debug_mode)
        lambdas[pg] = lam_pg
        # Record the first step as well, so that the epsilon accounting
        # covers every sampled token (the loop below records steps 1..T-1)
        going_lambdas.setdefault(pg, []).append(lam_pg)
        going_divergence.setdefault(pg, []).append(got_div)
        if debug_mode:
            print(f"[Initial] Selected Lambda for group {pg}: {lam_pg}, Divergence: {got_div}")

    sum_out = torch.zeros_like(p_pub)
    for pg in private_groups:
        lam_g = lambdas[pg]
        mix_g = lam_g * p_priv_dict[pg] + (1 - lam_g) * p_pub
        sum_out += mix_g
    p_out_avg = sum_out / len(private_groups)

    next_token = torch.multinomial(p_out_avg, 1).item()

    if debug_mode:
        token_str = tokenizer.decode([next_token])
        print(f"[Initial] Sampled token '{token_str}' (ID={next_token})")

    for g in group_order:
        token_ids_groups[g] = torch.cat(
            [token_ids_groups[g], torch.tensor([next_token], device=device)], dim=0
        )

    # Incremental loop
    for step in range(1, max_new_tokens):
        new_tokens_batch = torch.tensor([[next_token]] * num_groups, device=device)

        with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
            outputs = model(input_ids=new_tokens_batch, past_key_values=past, use_cache=True)

        past = outputs.past_key_values
        last_logits = outputs.logits[:, -1, :]
        group_logits = {g: last_logits[i] for i, g in enumerate(group_order)}

        pub_scaled = group_logits["PUBLIC"] / temperature
        p_pub = F.softmax(pub_scaled, dim=-1)

        p_priv_dict = {}
        for pg in private_groups:
            priv_scaled = group_logits[pg] / temperature
            p_priv_dict[pg] = F.softmax(priv_scaled, dim=-1)

        lambdas = {}
        for pg in private_groups:
            beta_val = beta_dict.get(pg)
            lam_pg, div_got = find_lambda(p_priv_dict[pg], p_pub, alpha, beta_val, debug_mode=debug_mode)
            lambdas[pg] = lam_pg

            if debug_mode:
                print(f"[Step {step}] Selected Lambda for group {pg}: {lam_pg}, Divergence: {div_got}")

            going_lambdas.setdefault(pg, []).append(lam_pg)
            going_divergence.setdefault(pg, []).append(div_got)

        sum_out = torch.zeros_like(p_pub)
        for pg in private_groups:
            mix_g = lambdas[pg] * p_priv_dict[pg] + (1 - lambdas[pg]) * p_pub
            sum_out += mix_g
        p_out_avg = sum_out / len(private_groups)

        next_token = torch.multinomial(p_out_avg, 1).item()

        for g in group_order:
            token_ids_groups[g] = torch.cat(
                [token_ids_groups[g], torch.tensor([next_token], device=device)], dim=0
            )

        del outputs, last_logits, group_logits
        torch.cuda.empty_cache()

        if next_token == eos_id:
            break

    final_text = tokenizer.decode(token_ids_groups["PUBLIC"], skip_special_tokens=True)

    if debug_mode:
        print("[DP-Fusion] Generation complete.")

    torch.cuda.empty_cache()

    return final_text, going_lambdas, going_divergence
def format_prompt_new_template(tokenizer, private_doc: str, placeholder: str) -> str:
    """
    Create a prompt using the HuggingFace chat template.

    Args:
        tokenizer: HuggingFace tokenizer
        private_doc: Document to paraphrase (may contain placeholders)
        placeholder: Placeholder character (e.g., "_")

    Returns:
        Formatted prompt for the model
    """
    system_text = (
        "You are given a passage that may contain placeholders (underscores) or incomplete data. "
        "Your job is to produce a natural paraphrase. "
        "Do not use any underscores or placeholders in your output. "
        "If data is missing, just omit it or paraphrase gracefully. "
        "Do not output anything except the paraphrase. "
        "Make sure to retain all information from the source document."
    )

    user_task = (
        f"Document:\n\n{private_doc}\n\n"
        f"Paraphrase the above text. Whenever a placeholder, i.e. {placeholder}, exists, you must completely ignore that information, "
        f"as {placeholder} indicates redacted text. To ensure the generated text is as natural as possible, "
        f"you must never output the {placeholder} characters themselves."
    )

    messages = [
        {"role": "user", "content": f"{system_text}\n\n{user_task}"},
        {"role": "assistant", "content": "Sure. Here is the paraphrased document without underscores or placeholders:"},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    return prompt