dograh/api/tests/test_usage_concurrency.py
Abhishek Kumar 4f2a629340 Initial Commit 🚀 🚀
2025-09-09 14:37:32 +05:30

152 lines
4.9 KiB
Python

#!/usr/bin/env python
"""
Test script to verify atomic operations in organization_usage_client.py
This simulates concurrent access from multiple processes.
"""
import asyncio
import os
from concurrent.futures import ProcessPoolExecutor
# Guarantee DATABASE_URL is present in the environment before the project
# imports below run; an already-set value is left alone, otherwise it becomes "".
os.environ.setdefault("DATABASE_URL", "")
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from api.db.organization_usage_client import OrganizationUsageClient
async def reserve_quota_process(org_id: int, tokens: int, process_id: int):
    """Repeatedly try to reserve quota for one organization from a worker.

    Builds a private async engine from ``DATABASE_URL``, calls
    ``OrganizationUsageClient.check_and_reserve_quota`` five times with a
    short pause between attempts, and always disposes of the engine before
    returning or propagating an error.

    Args:
        org_id: Organization whose quota is being reserved.
        tokens: Token amount requested on every attempt.
        process_id: Label for this worker, echoed back in the results.

    Returns:
        A list of ``(process_id, attempt_index, result)`` tuples, one per
        attempt, where ``result`` is whatever ``check_and_reserve_quota``
        returned.
    """
    engine = create_async_engine(os.environ["DATABASE_URL"])
    try:
        async_session = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        client = OrganizationUsageClient(async_session)

        results = []
        for i in range(5):
            result = await client.check_and_reserve_quota(org_id, tokens)
            results.append((process_id, i, result))
            await asyncio.sleep(0.01)  # Small delay to increase contention
        return results
    finally:
        # Release pooled connections even when a reservation call raises.
        await engine.dispose()
async def update_usage_process(org_id: int, tokens: int, process_id: int):
    """Apply a fixed series of usage updates for one organization from a worker.

    Builds a private async engine from ``DATABASE_URL``, calls
    ``OrganizationUsageClient.update_usage_after_run`` five times (each with
    ``duration_seconds=10``) with a short pause between calls, and always
    disposes of the engine before returning or propagating an error.

    Args:
        org_id: Organization whose usage is updated.
        tokens: Token amount passed to every update call.
        process_id: Label for this worker, used in the returned message.

    Returns:
        The string ``"Process <process_id> completed updates"``.
    """
    engine = create_async_engine(os.environ["DATABASE_URL"])
    try:
        async_session = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        client = OrganizationUsageClient(async_session)

        for _ in range(5):
            await client.update_usage_after_run(org_id, tokens, duration_seconds=10)
            await asyncio.sleep(0.01)  # Small delay to increase contention
        return f"Process {process_id} completed updates"
    finally:
        # Release pooled connections even when an update call raises.
        await engine.dispose()
def run_reserve_quota(args):
    """Process-pool entry point: run ``reserve_quota_process`` to completion.

    Args:
        args: A 3-item sequence ``(org_id, tokens, process_id)``; packed into
            one argument so it can be handed to ``executor.submit`` as a unit.

    Returns:
        Whatever ``reserve_quota_process`` returns.
    """
    organization, token_count, worker = args
    job = reserve_quota_process(organization, token_count, worker)
    # Each worker process drives its own event loop.
    return asyncio.run(job)
def run_update_usage(args):
    """Process-pool entry point: run ``update_usage_process`` to completion.

    Args:
        args: A 3-item sequence ``(org_id, tokens, process_id)``; packed into
            one argument so it can be handed to ``executor.submit`` as a unit.

    Returns:
        Whatever ``update_usage_process`` returns.
    """
    organization, token_count, worker = args
    job = update_usage_process(organization, token_count, worker)
    # Each worker process drives its own event loop.
    return asyncio.run(job)
async def test_concurrent_quota_reservation():
    """Run three worker processes that reserve quota at once and report results.

    Prints every ``(process_id, attempt, result)`` tuple gathered from the
    workers, followed by the number of attempts whose result was truthy.
    Nothing is asserted; the output is meant to be inspected by hand.
    """
    print("Testing concurrent quota reservations...")

    # Assuming org_id 1 exists with quota enabled
    target_org = 1
    request_size = 100

    # Fan the reservation attempts out across separate processes.
    with ProcessPoolExecutor(max_workers=3) as pool:
        pending = [
            pool.submit(run_reserve_quota, (target_org, request_size, worker))
            for worker in range(3)
        ]
        # Gather in submission order; each job yields a list of attempt tuples.
        outcomes = [entry for job in pending for entry in job.result()]

    print(f"Reservation results: {outcomes}")

    # An attempt counts as successful when its result (third item) is truthy.
    granted = len([entry for entry in outcomes if entry[2]])
    print(f"Successful reservations: {granted}")
async def test_concurrent_usage_updates():
    """Check that usage updates issued from several processes all land.

    Reads ``used_dograh_tokens`` for org 1, runs three worker processes that
    each apply a series of updates via ``run_update_usage``, reads the value
    again, and prints whether the increase equals the expected total. The
    engine used for the two reads is disposed of even if a read or a worker
    raises.
    """
    print("\nTesting concurrent usage updates...")

    org_id = 1
    tokens_per_update = 50
    process_count = 3
    # Must match the number of update calls made per worker in
    # update_usage_process.
    updates_per_process = 5

    engine = create_async_engine(os.environ["DATABASE_URL"])
    try:
        async_session = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        client = OrganizationUsageClient(async_session)

        # Get initial usage
        initial_usage = await client.get_current_usage(org_id)
        initial_tokens = initial_usage["used_dograh_tokens"]
        print(f"Initial tokens: {initial_tokens}")

        # Run multiple processes updating usage simultaneously
        with ProcessPoolExecutor(max_workers=process_count) as executor:
            futures = [
                executor.submit(run_update_usage, (org_id, tokens_per_update, i))
                for i in range(process_count)
            ]
            for future in futures:
                print(future.result())

        # Check final usage
        final_usage = await client.get_current_usage(org_id)
        final_tokens = final_usage["used_dograh_tokens"]
        expected_tokens = initial_tokens + (
            process_count * updates_per_process * tokens_per_update
        )  # 3 processes * 5 updates * 50 tokens

        print(f"Final tokens: {final_tokens}")
        print(f"Expected tokens: {expected_tokens}")
        print(f"Difference: {final_tokens - expected_tokens}")
    finally:
        # Release pooled connections on both the success and the error path.
        await engine.dispose()

    if final_tokens == expected_tokens:
        print("✅ All updates were applied atomically!")
    else:
        print("❌ Some updates were lost due to race conditions!")
async def main():
    """Run every concurrency check in order.

    Any exception raised by a check is caught here, reported on stdout along
    with its traceback, and not re-raised; later checks are skipped once one
    fails.
    """
    checks = (
        test_concurrent_quota_reservation,
        test_concurrent_usage_updates,
    )
    try:
        for check in checks:
            await check()
    except Exception as exc:
        print(f"Error during testing: {exc}")
        from traceback import print_exc

        print_exc()
if __name__ == "__main__":
    # Script entry point: announce the run, show which database is targeted,
    # then drive the async test suite to completion.
    print("Starting organization usage concurrency tests...")
    # The module-level setdefault above always defines DATABASE_URL (as "" when
    # it was not supplied), so a plain .get() default would never be used;
    # treat an empty value as unset so the message stays meaningful.
    # NOTE(review): this echoes the full URL, which may include credentials.
    print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL') or 'NOT SET'}")
    asyncio.run(main())