diff --git a/src/main/java/ai/nomyo/SecureCompletionClient.java b/src/main/java/ai/nomyo/SecureCompletionClient.java index 486afef..92303b9 100644 --- a/src/main/java/ai/nomyo/SecureCompletionClient.java +++ b/src/main/java/ai/nomyo/SecureCompletionClient.java @@ -297,7 +297,7 @@ class SecureCompletionClient { * @throws SecurityError if encryption fails or keys not loaded */ @SuppressWarnings("JavadocDeclaration") - public CompletableFuture encryptPayload(Map payload) { + public CompletableFuture encryptPayload(Map payload) { return CompletableFuture.supplyAsync(() -> { try { ensureKeys(null); @@ -323,10 +323,9 @@ class SecureCompletionClient { // Serialize payload to JSON Gson gson = new Gson(); - String payloadJson = gson.toJson(payload); - byte[] payloadBytes = payloadJson.getBytes(StandardCharsets.UTF_8); + byte[] payloadJson = gson.toJson(payload).getBytes(StandardCharsets.UTF_8); - try (SecureBuffer securePayload = SecureMemory.secureByteArray(payloadBytes)) { + try (SecureBuffer securePayload = SecureMemory.secureByteArray(payloadJson)) { return doEncrypt(securePayload, aesKey).join(); } }); @@ -335,14 +334,51 @@ class SecureCompletionClient { /** * Core hybrid encryption: AES-256-GCM encrypts {@code securePayload} with {@code aesKey}. */ - public CompletableFuture doEncrypt(SecureBuffer securePayload, Key aesKey) { + public CompletableFuture doEncrypt(SecureBuffer securePayload, Key aesKey) { return CompletableFuture.supplyAsync(() -> { + String serverPEM; + + try { + serverPEM = fetchServerPublicKey().get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e)); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SecurityError) { + throw new RuntimeException(cause); + } + throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause)); + } + + X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM)); + + PublicKey serverPublicKey; + + try { + serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec); + } catch (InvalidKeySpecException | NoSuchAlgorithmException e) { + throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e)); + } + + Cipher rsa; + try { + OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 + PSource.PSpecified.DEFAULT); + + rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); + rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams); + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | + InvalidAlgorithmParameterException e) { + throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e)); + } + SecureRandom random = new SecureRandom(); byte[] nonce = new byte[Constants.GCM_NONCE_SIZE]; random.nextBytes(nonce); - try (SecureBuffer secureNonce = SecureMemory.secureByteArray(nonce)) { - Cipher cipher; + Cipher cipher; + try (SecureBuffer secureNonce = SecureMemory.secureByteArray(nonce)) { try { cipher = Cipher.getInstance("AES/GCM/NoPadding"); cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, secureNonce.getAsByteArray())); @@ -359,69 +395,31 @@ class SecureCompletionClient { throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e)); } - String serverPEM; + byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length); - try { - serverPEM = fetchServerPublicKey().get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e)); - } catch (ExecutionException e) { - Throwable cause = e.getCause(); - if (cause instanceof SecurityError) { - throw new RuntimeException(cause); + byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE); + + try (SecureBuffer secureEncryptedCipherText = SecureMemory.secureByteArray(ciphertext); SecureBuffer secureTag = SecureMemory.secureByteArray(tag); SecureBuffer secureActualCipherText = SecureMemory.secureByteArray(actualCiphertext)) { + byte[] encryptedAESKey; + + try (SecureBuffer secureAesKeyEncoded = SecureMemory.secureByteArray(aesKey.getEncoded())) { + encryptedAESKey = rsa.doFinal(secureAesKeyEncoded.getAsByteArray()); + } catch (IllegalBlockSizeException | BadPaddingException e) { + throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e)); } - throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause)); - } - X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM)); + try (SecureBuffer secureEncryptedAESKey = SecureMemory.secureByteArray(encryptedAESKey)) { + EncryptedRequest request = new EncryptedRequest(); - PublicKey serverPublicKey; + request.setVersion(Constants.PROTOCOL_VERSION); + request.setAlgorithm(Constants.HYBRID_ALGORITHM); + request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(secureActualCipherText.getAsByteArray()), Base64.getEncoder().encodeToString(secureNonce.getAsByteArray()), Base64.getEncoder().encodeToString(secureTag.getAsByteArray()))); + request.setEncryptedAESKey(Base64.getEncoder().encodeToString(secureEncryptedAESKey.getAsByteArray())); + request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM); + request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM); - try { - serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec); - } catch (InvalidKeySpecException | NoSuchAlgorithmException e) { - throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e)); - } - - Cipher rsa; - try { - OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 - PSource.PSpecified.DEFAULT); - - rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); - rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams); - } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | - InvalidAlgorithmParameterException e) { - throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e)); - } - - byte[] encryptedAESKey; - - try (SecureBuffer secureAesKeyEncoded = SecureMemory.secureByteArray(aesKey.getEncoded())) { - encryptedAESKey = rsa.doFinal(secureAesKeyEncoded.getAsByteArray()); - } catch (IllegalBlockSizeException | BadPaddingException e) { - throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e)); - } - - try (SecureBuffer secureEncryptedAESKey = SecureMemory.secureByteArray(encryptedAESKey)) { - - Arrays.fill(encryptedAESKey, (byte) 0); - - byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length); - - byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE); - - EncryptedRequest request = new EncryptedRequest(); - - request.setVersion(Constants.PROTOCOL_VERSION); - request.setAlgorithm(Constants.HYBRID_ALGORITHM); - request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(actualCiphertext), Base64.getEncoder().encodeToString(secureNonce.getAsByteArray()), Base64.getEncoder().encodeToString(tag))); - request.setEncryptedAESKey(Base64.getEncoder().encodeToString(secureEncryptedAESKey.getAsByteArray())); - request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM); - request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM); - - return request.toJson().getBytes(StandardCharsets.UTF_8); + return SecureMemory.secureByteArray(request.toJson().getBytes(StandardCharsets.UTF_8)); + } } } }); @@ -465,7 +463,8 @@ class SecureCompletionClient { } // Step 1: Encrypt payload - byte[] encryptedPayload; + SecureBuffer encryptedPayload; + try { encryptedPayload = encryptPayload(payload).get(); } catch (InterruptedException e) { @@ -487,7 +486,9 @@ class SecureCompletionClient { throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e)); } - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM).header(Constants.HEADER_PAYLOAD_ID, payloadId).header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey)).POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload)); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM).header(Constants.HEADER_PAYLOAD_ID, payloadId).header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey)).POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload.getAsByteArray())); + + encryptedPayload.close(); if (apiKey != null && !apiKey.isEmpty()) { requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey); @@ -791,11 +792,11 @@ class SecureCompletionClient { Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding"); aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, secureNonce.getAsByteArray())); - // Combine ciphertext (without tag) and tag for decryption using SecureBuffer - try (SecureBuffer secureCiphertextWithTag = SecureMemory.secureByteArray(new byte[ciphertext.length + tag.length])) { - ByteBuffer combinedBuf = secureCiphertextWithTag.getData().asByteBuffer(); - combinedBuf.put(secureCiphertext.getAsByteArray()); - combinedBuf.put(secureTag.getAsByteArray()); + // Combine ciphertext (without tag) and tag for decryption using SecureBuffer + try (SecureBuffer secureCiphertextWithTag = SecureMemory.secureByteArray(new byte[ciphertext.length + tag.length])) { + ByteBuffer combinedBuf = secureCiphertextWithTag.getData().asByteBuffer(); + combinedBuf.put(secureCiphertext.getAsByteArray()); + combinedBuf.put(secureTag.getAsByteArray()); byte[] plaintextBytes = aesCipher.doFinal(secureCiphertextWithTag.getAsByteArray());