Mirror of https://github.com/rushil-thareja/dp-fusion-lib.git (synced 2026-04-24 12:06:23 +02:00)
- Token-level differential privacy for LLMs
- Integration with Document Privacy API
- Comprehensive test suite and documentation
- Examples and Jupyter notebook included
128 lines · 4.1 KiB · Python
"""Tests for utility functions."""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from dp_fusion_lib import (
|
|
compute_renyi_divergence_clipped_symmetric,
|
|
find_lambda,
|
|
)
|
|
|
|
|
|
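# Background for these tests: for discrete distributions p and q and order
# alpha > 1, the Renyi divergence is
#     D_alpha(p || q) = (1 / (alpha - 1)) * log(sum_i p_i**alpha * q_i**(1 - alpha)),
# which is zero iff p == q and non-negative otherwise. The "clipped symmetric"
# variant under test is assumed (not confirmed by this file) to clip
# probabilities away from zero and to symmetrise over both argument orders.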
class TestRenyiDivergence:
    """Tests for Renyi divergence computation."""

    def test_identical_distributions(self):
        """Test that identical distributions have zero divergence."""
        p = torch.tensor([0.25, 0.25, 0.25, 0.25])
        q = torch.tensor([0.25, 0.25, 0.25, 0.25])

        div = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)

        assert div.item() < 1e-6

    def test_symmetric(self):
        """Test that symmetric divergence gives the same result for (p, q) and (q, p)."""
        p = torch.tensor([0.4, 0.3, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.3, 0.4])

        div_pq = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)
        div_qp = compute_renyi_divergence_clipped_symmetric(q, p, alpha=2.0)

        # Symmetric divergence should agree in both argument orders
        assert abs(div_pq.item() - div_qp.item()) < 1e-6

    def test_positive_divergence(self):
        """Test that divergence is non-negative."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.7])

        div = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)

        assert div.item() >= 0

    def test_higher_alpha_different_result(self):
        """Test that different alpha values give different results."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.3, 0.4, 0.3])

        div_2 = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)
        div_5 = compute_renyi_divergence_clipped_symmetric(p, q, alpha=5.0)

        # Results should differ for different alpha
        assert abs(div_2.item() - div_5.item()) > 1e-6

    def test_alpha_validation(self):
        """Test that alpha <= 1 raises an error."""
        p = torch.tensor([0.5, 0.5])
        q = torch.tensor([0.5, 0.5])

        with pytest.raises(ValueError, match="alpha must be > 1"):
            compute_renyi_divergence_clipped_symmetric(p, q, alpha=1.0)

        with pytest.raises(ValueError, match="alpha must be > 1"):
            compute_renyi_divergence_clipped_symmetric(p, q, alpha=0.5)

    def test_batch_computation(self):
        """Test that batch computation works."""
        p = torch.tensor([[0.5, 0.5], [0.7, 0.3]])
        q = torch.tensor([[0.5, 0.5], [0.3, 0.7]])

        div = compute_renyi_divergence_clipped_symmetric(p, q, alpha=2.0)

        assert div.shape == (2,)
        assert div[0].item() < 1e-6  # identical rows
        assert div[1].item() > 0  # different rows
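# For reference, a minimal sketch of a clipped symmetric Renyi divergence that
# would satisfy the tests above. This is illustrative only, not the library's
# implementation; the clipping floor `min_prob` and symmetrisation via
# torch.maximum are assumptions.
def _reference_renyi_divergence_clipped_symmetric(p, q, alpha, min_prob=1e-12):
    if alpha <= 1:
        raise ValueError("alpha must be > 1")
    # Clip probabilities away from zero so the power sums stay finite.
    p = torch.clamp(p, min=min_prob)
    q = torch.clamp(q, min=min_prob)
    # D_alpha(p || q) = log(sum p**alpha * q**(1 - alpha)) / (alpha - 1),
    # reduced over the last dim so batched inputs yield one value per row.
    d_pq = torch.log(torch.sum(p**alpha * q ** (1 - alpha), dim=-1)) / (alpha - 1)
    d_qp = torch.log(torch.sum(q**alpha * p ** (1 - alpha), dim=-1)) / (alpha - 1)
    # Symmetrise by taking the elementwise maximum of the two directions.
    return torch.maximum(d_pq, d_qp)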
class TestFindLambda:
    """Tests for the lambda search function."""

    def test_identical_distributions_lambda_1(self):
        """Test that identical distributions give lambda=1."""
        p = torch.tensor([0.25, 0.25, 0.25, 0.25])
        q = torch.tensor([0.25, 0.25, 0.25, 0.25])

        lam, div = find_lambda(p, q, alpha=2.0, beta=0.1)

        assert lam == 1.0
        assert div < 1e-6

    def test_beta_zero_lambda_zero(self):
        """Test that beta=0 gives lambda=0."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.7])

        lam, div = find_lambda(p, q, alpha=2.0, beta=0.0)

        assert lam == 0.0
        assert div == 0.0

    def test_lambda_in_range(self):
        """Test that lambda is in [0, 1]."""
        p = torch.tensor([0.7, 0.2, 0.1])
        q = torch.tensor([0.1, 0.2, 0.7])

        lam, div = find_lambda(p, q, alpha=2.0, beta=0.5)

        assert 0.0 <= lam <= 1.0

    def test_divergence_respects_bound(self):
        """Test that the returned divergence is <= beta."""
        p = torch.tensor([0.6, 0.3, 0.1])
        q = torch.tensor([0.2, 0.3, 0.5])
        beta = 0.3

        lam, div = find_lambda(p, q, alpha=2.0, beta=beta)

        assert div <= beta + 1e-6  # allow small numerical error

    def test_higher_beta_higher_lambda(self):
        """Test that a higher beta allows a higher lambda."""
        p = torch.tensor([0.8, 0.15, 0.05])
        q = torch.tensor([0.1, 0.2, 0.7])

        lam_low, _ = find_lambda(p, q, alpha=2.0, beta=0.1)
        lam_high, _ = find_lambda(p, q, alpha=2.0, beta=0.5)

        assert lam_low <= lam_high
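# Likewise, a minimal sketch (not the library's implementation) of a lambda
# search consistent with TestFindLambda: bisect for the largest mixing weight
# lam in [0, 1] such that the divergence of the mixture lam*p + (1-lam)*q from
# q stays within the budget beta. The mixture form, the monotonicity
# assumption, and the iteration count are all assumptions for illustration.
def _reference_find_lambda(p, q, alpha, beta, iters=50):
    if beta <= 0:
        return 0.0, 0.0  # no budget: fall back entirely to q

    def div_at(lam):
        mix = lam * p + (1 - lam) * q
        return compute_renyi_divergence_clipped_symmetric(mix, q, alpha=alpha).item()

    # If even lam = 1 (pure p) respects the bound, no search is needed.
    if div_at(1.0) <= beta:
        return 1.0, div_at(1.0)
    # Bisection, assuming div_at grows monotonically from 0 at lam = 0.
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if div_at(mid) <= beta:
            lo = mid
        else:
            hi = mid
    return lo, div_at(lo)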