diff --git a/src/api/SecureChatCompletion.ts b/src/api/SecureChatCompletion.ts index 53c92e2..dd36c67 100644 --- a/src/api/SecureChatCompletion.ts +++ b/src/api/SecureChatCompletion.ts @@ -10,6 +10,8 @@ import { ChatCompletionRequest, ChatCompletionResponse } from '../types/api'; export class SecureChatCompletion { private client: SecureCompletionClient; private apiKey?: string; + /** Stored config used to spin up a temporary per-request instance when base_url is overridden */ + private readonly _config: ChatCompletionConfig; constructor(config: ChatCompletionConfig = {}) { const { @@ -22,8 +24,11 @@ export class SecureChatCompletion { keyRotationInterval, keyRotationDir, keyRotationPassword, + maxRetries, + keyDir, } = config; + this._config = config; this.apiKey = apiKey; this.client = new SecureCompletionClient({ routerUrl: baseUrl, @@ -34,6 +39,8 @@ export class SecureChatCompletion { ...(keyRotationInterval !== undefined && { keyRotationInterval }), ...(keyRotationDir !== undefined && { keyRotationDir }), ...(keyRotationPassword !== undefined && { keyRotationPassword }), + ...(maxRetries !== undefined && { maxRetries }), + ...(keyDir !== undefined && { keyDir }), }); } @@ -43,18 +50,19 @@ export class SecureChatCompletion { * Supports additional NOMYO-specific fields: * - `security_tier`: "standard" | "high" | "maximum" — controls hardware routing * - `api_key`: per-request API key override (takes precedence over constructor key) + * - `base_url`: per-request router URL override (creates a temporary client for + * this single call, matching the Python SDK's `create(base_url=...)` behaviour) */ async create(request: ChatCompletionRequest): Promise { const payloadId = generateUUID(); // Extract NOMYO-specific fields that must not go into the encrypted payload - const { security_tier, api_key, ...payload } = request as ChatCompletionRequest & { + const { security_tier, api_key, base_url, ...payload } = request as ChatCompletionRequest & { security_tier?: string; api_key?: string; + base_url?: string; }; - const apiKey = api_key ?? this.apiKey; - if (!payload.model) { throw new Error('Missing required field: model'); } @@ -62,6 +70,29 @@ export class SecureChatCompletion { throw new Error('Missing or invalid required field: messages'); } + const apiKey = api_key ?? this.apiKey; + + // Per-request base_url: spin up a temporary client for this one call, + // inheriting all other config from the current instance. + if (base_url !== undefined) { + const tempInstance = new SecureChatCompletion({ + ...this._config, + baseUrl: base_url, + apiKey: this.apiKey, + }); + try { + const response = await tempInstance.client.sendSecureRequest( + payload, + payloadId, + apiKey, + security_tier + ); + return response as unknown as ChatCompletionResponse; + } finally { + tempInstance.dispose(); + } + } + const response = await this.client.sendSecureRequest( payload, payloadId, diff --git a/src/core/SecureCompletionClient.ts b/src/core/SecureCompletionClient.ts index de88dfa..1edb1cd 100644 --- a/src/core/SecureCompletionClient.ts +++ b/src/core/SecureCompletionClient.ts @@ -72,6 +72,8 @@ export class SecureCompletionClient { private keyRotationTimer?: ReturnType; private readonly keyRotationDir?: string; private readonly keyRotationPassword?: string; + private readonly maxRetries: number; + private readonly keyDir: string; private _isHttps: boolean = true; // Promise-based mutex: serialises concurrent ensureKeys() calls @@ -88,6 +90,8 @@ export class SecureCompletionClient { keyRotationInterval = 86400000, // 24 hours keyRotationDir, keyRotationPassword, + maxRetries = 2, + keyDir = 'client_keys', } = config; this.debugMode = debug; @@ -95,6 +99,8 @@ export class SecureCompletionClient { this.keyRotationInterval = keyRotationInterval; this.keyRotationDir = keyRotationDir; this.keyRotationPassword = keyRotationPassword; + this.maxRetries = maxRetries; + this.keyDir = keyDir; this.keySize = keySize; this.allowHttp = allowHttp; this.secureMemory = secureMemory; @@ -243,29 +249,29 @@ export class SecureCompletionClient { private async _doEnsureKeys(): Promise { if (this.keyManager.hasKeys()) return; - // Try to load keys from default location (Node.js only) + // Try to load keys from the configured directory (Node.js only) if (typeof window === 'undefined') { try { const fs = require('fs').promises as { access: (p: string) => Promise }; const path = require('path') as { join: (...p: string[]) => string }; - const privateKeyPath = path.join('client_keys', 'private_key.pem'); - const publicKeyPath = path.join('client_keys', 'public_key.pem'); + const privateKeyPath = path.join(this.keyDir, 'private_key.pem'); + const publicKeyPath = path.join(this.keyDir, 'public_key.pem'); await fs.access(privateKeyPath); await fs.access(publicKeyPath); await this.loadKeys(privateKeyPath, publicKeyPath); - if (this.debugMode) console.log('Loaded existing keys from client_keys/'); + if (this.debugMode) console.log(`Loaded existing keys from ${this.keyDir}/`); return; } catch (_error) { - if (this.debugMode) console.log('No existing keys found, generating new keys...'); + if (this.debugMode) console.log(`No existing keys found in ${this.keyDir}/, generating new keys...`); } } await this.generateKeys({ saveToFile: typeof window === 'undefined', - keyDir: 'client_keys', + keyDir: this.keyDir, }); } @@ -439,6 +445,22 @@ export class SecureCompletionClient { } } + // Validate version and algorithm to prevent downgrade attacks + const SUPPORTED_VERSION = '1.0'; + const SUPPORTED_ALGORITHM = 'hybrid-aes256-rsa4096'; + if (packageData.version !== SUPPORTED_VERSION) { + throw new Error( + `Unsupported protocol version: '${String(packageData.version)}'. ` + + `Expected: '${SUPPORTED_VERSION}'` + ); + } + if (packageData.algorithm !== SUPPORTED_ALGORITHM) { + throw new Error( + `Unsupported encryption algorithm: '${String(packageData.algorithm)}'. ` + + `Expected: '${SUPPORTED_ALGORITHM}'` + ); + } + const encryptedPayload = packageData.encrypted_payload as Record; if (typeof encryptedPayload !== 'object' || encryptedPayload === null) { throw new Error('Invalid encrypted_payload: must be an object'); @@ -515,6 +537,9 @@ export class SecureCompletionClient { /** * Send a secure chat completion request to the router. * + * Retries on transient errors (429, 500, 502, 503, 504, network errors) + * with exponential backoff matching the Python SDK's `max_retries` behaviour. + * * @param securityTier Optional routing tier: "standard" | "high" | "maximum" */ async sendSecureRequest( @@ -545,8 +570,6 @@ export class SecureCompletionClient { await this.ensureKeys(); - const encryptedPayload = await this.encryptPayload(payload); - const publicKeyPem = await this.keyManager.getPublicKeyPEM(); const headers: Record = { 'X-Payload-ID': payloadId, @@ -565,31 +588,86 @@ export class SecureCompletionClient { const url = `${this.routerUrl}/v1/chat/secure_completion`; if (this.debugMode) console.log(`Target URL: ${url}`); - let response: { statusCode: number; body: ArrayBuffer }; - try { - response = await this.httpClient.post(url, { - headers, - body: encryptedPayload, - timeout: this.requestTimeout, - }); - } catch (error) { - if (error instanceof Error) { - if (error.message === 'Request timeout') { - throw new APIConnectionError('Connection to server timed out'); + // Retry loop — mirrors Python SDK's max_retries + exponential backoff. + // The payload is re-encrypted on every attempt so each attempt gets a + // fresh AES key and nonce (the HTTP client zeros the buffer after write). + let lastError: Error = new APIConnectionError('Request failed'); + + for (let attempt = 0; attempt <= this.maxRetries; attempt++) { + if (attempt > 0) { + const delaySec = Math.pow(2, attempt - 1); // 1 s, 2 s, 4 s, … + if (this.debugMode) { + console.warn( + `Retrying request (attempt ${attempt}/${this.maxRetries}) ` + + `after ${delaySec}s...` + ); } - throw new APIConnectionError(`Failed to connect to router: ${error.message}`); + await new Promise(resolve => setTimeout(resolve, delaySec * 1000)); } - throw error; + + // Re-encrypt each attempt (throws non-retryable errors like SecurityError + // or DisposedError — let those propagate immediately) + const encryptedPayload = await this.encryptPayload(payload); + + let response: { statusCode: number; body: ArrayBuffer }; + try { + response = await this.httpClient.post(url, { + headers, + body: encryptedPayload, + timeout: this.requestTimeout, + }); + } catch (error) { + // Network / timeout errors from the HTTP client + let connError: APIConnectionError; + if (error instanceof Error) { + connError = error.message === 'Request timeout' + ? new APIConnectionError('Connection to server timed out') + : new APIConnectionError(`Failed to connect to router: ${error.message}`); + } else { + connError = new APIConnectionError('Failed to connect to router: unknown error'); + } + lastError = connError; + if (attempt < this.maxRetries) { + if (this.debugMode) console.warn(`Network error on attempt ${attempt}: ${connError.message}`); + continue; + } + throw lastError; + } + + if (this.debugMode) console.log(`HTTP Status: ${response.statusCode}`); + + if (response.statusCode === 200) { + return await this.decryptResponse(response.body, payloadId); + } + + const err = this.buildErrorFromResponse(response); + + if (this.isRetryableError(err) && attempt < this.maxRetries) { + if (this.debugMode) { + console.warn(`Got retryable status ${response.statusCode}: retrying...`); + } + lastError = err; + continue; + } + + throw err; } - if (this.debugMode) console.log(`HTTP Status: ${response.statusCode}`); + throw lastError; + } - if (response.statusCode === 200) { - return await this.decryptResponse(response.body, payloadId); - } - - // Map HTTP error status codes to typed errors - throw this.buildErrorFromResponse(response); + /** + * Return true for errors that warrant a retry (transient failures). + * Non-retryable errors (auth, bad request, forbidden, etc.) propagate immediately. + */ + private isRetryableError(error: Error): boolean { + if (error instanceof APIConnectionError) return true; + if (error instanceof RateLimitError) return true; + if (error instanceof ServerError) return true; + if (error instanceof ServiceUnavailableError) return true; + // 502 Bad Gateway and 504 Gateway Timeout fall through as generic APIError + if (error instanceof APIError && (error.statusCode === 502 || error.statusCode === 504)) return true; + return false; } /** diff --git a/src/core/memory/secure.ts b/src/core/memory/secure.ts index 6e1cd15..9d6e0d7 100644 --- a/src/core/memory/secure.ts +++ b/src/core/memory/secure.ts @@ -1,15 +1,46 @@ /** - * Secure memory interface and context manager - * + * Secure memory interface, context manager, and public API. + * * IMPORTANT: This is a pure JavaScript implementation that provides memory zeroing only. * OS-level memory locking (mlock) is NOT implemented in this version. - * + * * For production use, consider implementing a native addon for true memory locking. * See SECURITY.md for details on memory protection limitations. */ import { ProtectionInfo } from '../../types/crypto'; +// ─── Global secure-memory state ────────────────────────────────────────────── + +/** Module-level flag, mirrors Python's global _secure_memory.enabled. */ +let _globalSecureMemoryEnabled = true; + +/** + * Disable secure memory operations globally. + * Affects new SecureByteContext instances created without an explicit `useSecure` argument. + * Existing client instances are unaffected (they pass `useSecure` explicitly). + * Mirrors Python's `disable_secure_memory()`. + */ +export function disableSecureMemory(): void { + _globalSecureMemoryEnabled = false; +} + +/** + * Re-enable secure memory operations globally. + * Mirrors Python's `enable_secure_memory()`. + */ +export function enableSecureMemory(): void { + _globalSecureMemoryEnabled = true; +} + +/** + * Return information about the memory protection capabilities available on this + * platform/runtime. Mirrors Python's `get_memory_protection_info()`. + */ +export function getMemoryProtectionInfo(): ProtectionInfo { + return createSecureMemory().getProtectionInfo(); +} + export interface SecureMemory { /** * Zero memory (fill with zeros) @@ -24,15 +55,19 @@ export interface SecureMemory { } /** - * Secure byte context manager - * Ensures memory is zeroed even if an exception occurs (similar to Python's context manager) + * Secure byte context manager. + * Ensures memory is zeroed even if an exception occurs (analogous to Python's + * `secure_bytearray()` context manager and `SecureBuffer` class). + * + * When `useSecure` is omitted, the module-level global flag set by + * `disableSecureMemory()` / `enableSecureMemory()` is consulted. */ export class SecureByteContext { private data: ArrayBuffer; private secureMemory: SecureMemory; private useSecure: boolean; - constructor(data: ArrayBuffer, useSecure: boolean = true) { + constructor(data: ArrayBuffer, useSecure: boolean = _globalSecureMemoryEnabled) { this.data = data; this.useSecure = useSecure; this.secureMemory = createSecureMemory(); diff --git a/src/index.ts b/src/index.ts index ccc92c7..7baa07a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -13,3 +13,13 @@ export * from './types/crypto'; // Export errors export * from './errors'; + +// Secure memory public API — mirrors Python's get_memory_protection_info(), +// disable_secure_memory(), enable_secure_memory(), and SecureBuffer/secure_bytearray() +export { + getMemoryProtectionInfo, + disableSecureMemory, + enableSecureMemory, + SecureByteContext, + createSecureMemory, +} from './core/memory/secure'; diff --git a/src/types/client.ts b/src/types/client.ts index 973c59e..342e6f4 100644 --- a/src/types/client.ts +++ b/src/types/client.ts @@ -32,6 +32,20 @@ export interface ClientConfig { /** Password to encrypt rotated private key files */ keyRotationPassword?: string; + + /** + * Directory to load/save RSA keys on startup. + * If the directory contains an existing key pair it is loaded; otherwise a + * new pair is generated and saved there. Default: 'client_keys'. + * Matches the Python SDK's `key_dir` constructor parameter. + */ + keyDir?: string; + + /** + * Maximum number of retries on retryable errors (429, 500, 502, 503, 504, + * network errors). Uses exponential backoff (1 s, 2 s, 4 s, …). Default: 2. + */ + maxRetries?: number; } export interface KeyGenOptions { @@ -83,4 +97,19 @@ export interface ChatCompletionConfig { /** Password to encrypt rotated private key files */ keyRotationPassword?: string; + + /** + * Directory to load/save RSA keys on startup. + * If the directory contains an existing key pair it is loaded; otherwise a + * new pair is generated and saved there. + * Omit (or set to undefined) to use the default 'client_keys/' directory. + * Matches the Python SDK's `key_dir` constructor parameter. + */ + keyDir?: string; + + /** + * Maximum number of retries on retryable errors (429, 500, 502, 503, 504, + * network errors). Uses exponential backoff (1 s, 2 s, 4 s, …). Default: 2. + */ + maxRetries?: number; }