import random
import time
import uuid
from typing import Optional

import redis.asyncio as aioredis
from loguru import logger

from api.constants import REDIS_URL


class RateLimiter:
    """Sliding window rate limiter to enforce strict per-second limits and concurrent call limits.

    All state lives in Redis (sorted sets and hashes), keyed by organization
    id, workflow run id, or (organization id, telephony configuration id).
    Every public method catches Redis errors, logs them, and returns a
    conservative fallback value instead of raising.
    """

    def __init__(self):
        # Created lazily on first use by _get_redis().
        self.redis_client: Optional[aioredis.Redis] = None
        # Entries older than this are treated as abandoned and reclaimed.
        self.stale_call_timeout = 1200  # 20 minutes in seconds

    async def _get_redis(self) -> aioredis.Redis:
        """Get or create the Redis connection (responses decoded to str)."""
        if self.redis_client is None:
            self.redis_client = await aioredis.from_url(
                REDIS_URL, decode_responses=True
            )
        return self.redis_client

    async def acquire_token(self, organization_id: int, rate_limit: int = 1) -> bool:
        """
        Enforce a strict rate limit: max ``rate_limit`` calls per rolling
        one-second window for the organization.

        Returns True if the call is allowed, False if rate limited or if
        Redis raised (errors are logged and treated as a denial).
        """
        redis_client = await self._get_redis()
        key = f"rate_limit:{organization_id}"

        now = time.time()
        window_start = now - 1.0  # 1 second sliding window
        # The sorted-set member must be unique per request. Using the bare
        # timestamp as the member lets two requests that observe the same
        # time.time() value collapse into a single entry, which under-counts
        # the window and admits more calls than the limit allows.
        member = f"{now}:{uuid.uuid4().hex}"

        # Lua script for atomic sliding window operation
        lua_script = """
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local window_start = tonumber(ARGV[2])
        local max_requests = tonumber(ARGV[3])
        local member = ARGV[4]

        -- Remove timestamps older than window
        redis.call('ZREMRANGEBYSCORE', key, 0, window_start)

        -- Count requests in current window
        local current_requests = redis.call('ZCARD', key)

        if current_requests < max_requests then
            -- Record this request under a unique member, scored by time
            redis.call('ZADD', key, now, member)
            redis.call('EXPIRE', key, 2)  -- Expire after 2 seconds
            return 1
        else
            return 0
        end
        """

        try:
            result = await redis_client.eval(
                lua_script, 1, key, now, window_start, rate_limit, member
            )
            return bool(result)
        except Exception as e:
            logger.error(f"Rate limiter error: {e}")
            # On error, be conservative and deny
            return False

    async def get_next_available_slot(
        self, organization_id: int, rate_limit: int = 1
    ) -> float:
        """
        Return the number of seconds until the next slot is available.

        Useful for implementing retry with backoff. Returns 0.0 when fewer
        than ``rate_limit`` requests are inside the current one-second
        window, and 1.0 as a default wait if Redis raised.
        """
        redis_client = await self._get_redis()
        key = f"rate_limit:{organization_id}"

        try:
            now = time.time()
            window_start = now - 1.0

            # Only entries strictly newer than window_start still count:
            # acquire_token removes everything with score <= window_start.
            active = await redis_client.zrangebyscore(
                key, f"({window_start}", "+inf", withscores=True
            )

            if not active or len(active) < rate_limit:
                return 0.0  # Can call immediately

            # A slot opens once enough of the oldest entries age out of the
            # window to bring the count below rate_limit. Entries are sorted
            # oldest-first, so that is the entry at index (count - limit).
            # The index is clamped so a non-positive rate_limit cannot
            # run past the end of the list.
            blocking_index = min(len(active) - rate_limit, len(active) - 1)
            blocking_time = active[blocking_index][1]
            return max(0.0, blocking_time + 1.0 - now)
        except Exception as e:
            logger.error(f"Rate limiter get_next_available_slot error: {e}")
            return 1.0  # Default wait time on error

    async def try_acquire_concurrent_slot(
        self, organization_id: int, max_concurrent: int = 20
    ) -> Optional[str]:
        """
        Try to acquire a concurrent call slot.

        Entries older than ``stale_call_timeout`` are purged first, then a
        new slot is added if the remaining count is below ``max_concurrent``.

        Returns a unique slot_id if successful, None if the limit is reached
        or Redis raised.
        """
        redis_client = await self._get_redis()
        concurrent_key = f"concurrent_calls:{organization_id}"

        now = time.time()
        stale_cutoff = now - self.stale_call_timeout

        # Lua script for atomic operation
        lua_script = """
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local max_concurrent = tonumber(ARGV[2])
        local stale_cutoff = tonumber(ARGV[3])
        local slot_id = ARGV[4]

        -- Remove stale entries (older than the stale-call timeout)
        redis.call('ZREMRANGEBYSCORE', key, 0, stale_cutoff)

        -- Get current count
        local current_count = redis.call('ZCARD', key)

        if current_count < max_concurrent then
            -- Add new slot
            redis.call('ZADD', key, now, slot_id)
            redis.call('EXPIRE', key, 3600)  -- Expire after 1 hour
            return slot_id
        else
            return nil
        end
        """

        # Generate unique slot ID (timestamp + random component)
        slot_id = f"{int(now * 1000)}_{uuid.uuid4().hex[:8]}"

        try:
            result = await redis_client.eval(
                lua_script,
                1,
                concurrent_key,
                now,
                max_concurrent,
                stale_cutoff,
                slot_id,
            )
            return result
        except Exception as e:
            logger.error(f"Concurrent limiter error: {e}")
            return None

    async def release_concurrent_slot(self, organization_id: int, slot_id: str) -> bool:
        """
        Release a concurrent call slot.

        Returns True if the slot was removed, False if ``slot_id`` is empty,
        the slot was not present, or Redis raised.
        """
        if not slot_id:
            return False

        redis_client = await self._get_redis()
        concurrent_key = f"concurrent_calls:{organization_id}"

        try:
            removed = await redis_client.zrem(concurrent_key, slot_id)
            if removed:
                logger.debug(
                    f"Released concurrent slot {slot_id} for org {organization_id}"
                )
            return bool(removed)
        except Exception as e:
            logger.error(f"Error releasing concurrent slot: {e}")
            return False

    async def get_concurrent_count(self, organization_id: int) -> int:
        """
        Get the current number of active concurrent calls for an organization.

        Stale entries (older than ``stale_call_timeout``) are removed before
        counting. Returns 0 if Redis raised.
        """
        redis_client = await self._get_redis()
        concurrent_key = f"concurrent_calls:{organization_id}"

        try:
            # Clean up stale entries first
            stale_cutoff = time.time() - self.stale_call_timeout
            await redis_client.zremrangebyscore(concurrent_key, 0, stale_cutoff)

            # Get current count
            count = await redis_client.zcard(concurrent_key)
            return count
        except Exception as e:
            logger.error(f"Error getting concurrent count: {e}")
            return 0

    async def store_workflow_slot_mapping(
        self, workflow_run_id: int, organization_id: int, slot_id: str
    ) -> bool:
        """
        Store the mapping between workflow_run_id and its concurrent slot.

        Used for cleanup when calls complete. The mapping expires after
        ``stale_call_timeout`` seconds. Returns True on success, False if
        Redis raised.
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_slot_mapping:{workflow_run_id}"

        try:
            # Store as a hash with TTL
            await redis_client.hset(
                mapping_key, mapping={"org_id": organization_id, "slot_id": slot_id}
            )
            # Set expiry to match stale timeout
            await redis_client.expire(mapping_key, self.stale_call_timeout)
            return True
        except Exception as e:
            logger.error(f"Error storing workflow slot mapping: {e}")
            return False

    async def get_workflow_slot_mapping(
        self, workflow_run_id: int
    ) -> Optional[tuple[int, str]]:
        """
        Get the concurrent slot mapping for a workflow run.

        Returns an (organization_id, slot_id) tuple, or None if the mapping
        is missing, incomplete, or Redis raised.
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_slot_mapping:{workflow_run_id}"

        try:
            mapping = await redis_client.hgetall(mapping_key)
            if mapping and "org_id" in mapping and "slot_id" in mapping:
                return (int(mapping["org_id"]), mapping["slot_id"])
            return None
        except Exception as e:
            logger.error(f"Error getting workflow slot mapping: {e}")
            return None

    async def delete_workflow_slot_mapping(self, workflow_run_id: int) -> bool:
        """
        Delete the workflow slot mapping after releasing the slot.

        Returns True if a key was deleted, False otherwise (including on
        Redis errors).
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_slot_mapping:{workflow_run_id}"

        try:
            deleted = await redis_client.delete(mapping_key)
            return bool(deleted)
        except Exception as e:
            logger.error(f"Error deleting workflow slot mapping: {e}")
            return False

    # ======== FROM NUMBER POOL METHODS ========

    @staticmethod
    def _from_number_pool_key(
        organization_id: int, telephony_configuration_id: int | None
    ) -> str:
        """Build the Redis key of the pool for one (org, telephony config) pair."""
        return f"from_number_pool:{organization_id}:{telephony_configuration_id}"

    async def initialize_from_number_pool(
        self,
        organization_id: int,
        from_numbers: list[str],
        telephony_configuration_id: int | None,
    ) -> bool:
        """
        Initialize the from_number pool for an organization + telephony config.

        Uses ZADD NX so it won't overwrite numbers that are already in use.

        Pools are scoped per (organization_id, telephony_configuration_id) so
        that orgs with multiple telephony configurations do not leak caller
        IDs across configs.

        Returns True on success, False if ``from_numbers`` is empty or Redis
        raised.
        """
        if not from_numbers:
            return False

        redis_client = await self._get_redis()
        key = self._from_number_pool_key(organization_id, telephony_configuration_id)

        try:
            # ZADD NX: only add members that don't already exist (preserves in-use scores)
            members = {number: 0 for number in from_numbers}
            await redis_client.zadd(key, members, nx=True)
            await redis_client.expire(key, 3600)  # 1 hour TTL
            return True
        except Exception as e:
            logger.error(f"Error initializing from_number pool: {e}")
            return False

    async def acquire_from_number(
        self, organization_id: int, telephony_configuration_id: int | None
    ) -> Optional[str]:
        """
        Atomically acquire an available from_number from the pool for the
        given (organization_id, telephony_configuration_id).

        A score of 0 means "available"; an in-use number carries the
        timestamp at which it was acquired. Numbers whose timestamp is older
        than ``stale_call_timeout`` are reset to available before picking.

        Returns the phone number if available, None if all numbers are in
        use or Redis raised.
        """
        redis_client = await self._get_redis()
        key = self._from_number_pool_key(organization_id, telephony_configuration_id)

        now = time.time()
        stale_cutoff = now - self.stale_call_timeout
        # The random pick is drawn here rather than via math.random in Lua:
        # Redis has historically given the script PRNG the same seed on every
        # script execution, which would make the "random" choice repeat.
        pick_fraction = random.random()  # in [0.0, 1.0)

        lua_script = """
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local stale_cutoff = tonumber(ARGV[2])
        local pick_fraction = tonumber(ARGV[3])

        -- Clean stale entries: members with score > 0 and score <= stale_cutoff
        local stale = redis.call('ZRANGEBYSCORE', key, 1, stale_cutoff)
        for i, member in ipairs(stale) do
            redis.call('ZADD', key, 0, member)
        end

        -- Find all available numbers (score == 0)
        local available = redis.call('ZRANGEBYSCORE', key, 0, 0)
        if #available == 0 then
            return nil
        end

        -- Pick a number using the caller-supplied random fraction
        local idx = math.floor(pick_fraction * #available) + 1
        if idx > #available then
            idx = #available
        end
        local chosen = available[idx]

        -- Mark as in-use with current timestamp
        redis.call('ZADD', key, now, chosen)
        return chosen
        """

        try:
            result = await redis_client.eval(
                lua_script, 1, key, now, stale_cutoff, pick_fraction
            )
            if result:
                logger.debug(f"Acquired from_number {result} for org {organization_id}")
            return result
        except Exception as e:
            logger.error(f"Error acquiring from_number: {e}")
            return None

    async def release_from_number(
        self,
        organization_id: int,
        from_number: str,
        telephony_configuration_id: int | None,
    ) -> bool:
        """
        Release a from_number back to its (org, telephony config) pool by
        setting its score to 0. Harmless if already released (score already 0).

        Returns True if the number is a member of the pool, False if
        ``from_number`` is empty, the number is not in the pool, or Redis
        raised.
        """
        if not from_number:
            return False

        redis_client = await self._get_redis()
        key = self._from_number_pool_key(organization_id, telephony_configuration_id)

        # Only reset numbers that belong to the pool, so a release never
        # adds a new member to it.
        lua_script = """
        local key = KEYS[1]
        local from_number = ARGV[1]

        local score = redis.call('ZSCORE', key, from_number)
        if score then
            redis.call('ZADD', key, 0, from_number)
            return 1
        end
        return 0
        """

        try:
            result = await redis_client.eval(lua_script, 1, key, from_number)
            if result:
                logger.debug(
                    f"Released from_number {from_number} for org {organization_id}"
                )
            return bool(result)
        except Exception as e:
            logger.error(f"Error releasing from_number: {e}")
            return False

    async def store_workflow_from_number_mapping(
        self,
        workflow_run_id: int,
        organization_id: int,
        from_number: str,
        telephony_configuration_id: int | None,
    ) -> bool:
        """
        Store the mapping between workflow_run_id and its from_number, plus
        the telephony_configuration_id so cleanup can release back to the
        correct pool.

        The mapping expires after 30 minutes. Returns True on success, False
        if Redis raised.
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_from_number:{workflow_run_id}"

        try:
            # Redis hashes can't store None — use empty string sentinel for legacy
            # campaigns whose telephony_configuration_id has not been backfilled.
            tcid_value = (
                "" if telephony_configuration_id is None else telephony_configuration_id
            )
            await redis_client.hset(
                mapping_key,
                mapping={
                    "org_id": organization_id,
                    "from_number": from_number,
                    "telephony_configuration_id": tcid_value,
                },
            )
            await redis_client.expire(mapping_key, 1800)  # 30 min TTL
            return True
        except Exception as e:
            logger.error(f"Error storing workflow from_number mapping: {e}")
            return False

    async def get_workflow_from_number_mapping(
        self, workflow_run_id: int
    ) -> Optional[tuple[int, str, int | None]]:
        """
        Get the from_number mapping for a workflow run.

        Returns (organization_id, from_number, telephony_configuration_id),
        or None if the mapping is missing, incomplete, or Redis raised.
        telephony_configuration_id is None for legacy entries (stored as an
        empty string or absent).
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_from_number:{workflow_run_id}"

        try:
            mapping = await redis_client.hgetall(mapping_key)
            if mapping and "org_id" in mapping and "from_number" in mapping:
                raw_tcid = mapping.get("telephony_configuration_id", "")
                tcid = int(raw_tcid) if raw_tcid not in (None, "") else None
                return (int(mapping["org_id"]), mapping["from_number"], tcid)
            return None
        except Exception as e:
            logger.error(f"Error getting workflow from_number mapping: {e}")
            return None

    async def delete_workflow_from_number_mapping(self, workflow_run_id: int) -> bool:
        """
        Delete the workflow from_number mapping after releasing the number.

        Returns True if a key was deleted, False otherwise (including on
        Redis errors).
        """
        redis_client = await self._get_redis()
        mapping_key = f"workflow_from_number:{workflow_run_id}"

        try:
            deleted = await redis_client.delete(mapping_key)
            return bool(deleted)
        except Exception as e:
            logger.error(f"Error deleting workflow from_number mapping: {e}")
            return False

    async def close(self):
        """Close the Redis connection, if one was opened, and forget it."""
        if self.redis_client:
            await self.redis_client.close()
            self.redis_client = None


# Global rate limiter instance
rate_limiter = RateLimiter()