From 9b5fa56215b7b7fba96e3af916542ee653587cff Mon Sep 17 00:00:00 2001 From: Oracle Date: Wed, 29 Apr 2026 15:14:24 +0200 Subject: [PATCH] Complete functionality --- src/main/java/ai/nomyo/Main.java | 3 +- .../java/ai/nomyo/SecureChatCompletion.java | 20 +-- .../java/ai/nomyo/SecureCompletionClient.java | 127 +++++++----------- src/main/java/ai/nomyo/util/Pass2Key.java | 42 +++--- src/test/java/ai/nomyo/EnsureKeysTest.java | 2 +- 5 files changed, 75 insertions(+), 119 deletions(-) diff --git a/src/main/java/ai/nomyo/Main.java b/src/main/java/ai/nomyo/Main.java index 7564fb0..ab87c9d 100644 --- a/src/main/java/ai/nomyo/Main.java +++ b/src/main/java/ai/nomyo/Main.java @@ -9,7 +9,7 @@ import java.util.Map; public class Main { static void main() { - SecureChatCompletion secureChatCompletion = new SecureChatCompletion( Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE"); + SecureChatCompletion secureChatCompletion = new SecureChatCompletion(Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE"); List> messages = List.of( Map.of("role", "user", "content", "Hello! How are you today?") ); @@ -26,5 +26,4 @@ public class Main { System.out.println(response.toString()); } - } diff --git a/src/main/java/ai/nomyo/SecureChatCompletion.java b/src/main/java/ai/nomyo/SecureChatCompletion.java index 4e8292d..5d4e9e1 100644 --- a/src/main/java/ai/nomyo/SecureChatCompletion.java +++ b/src/main/java/ai/nomyo/SecureChatCompletion.java @@ -30,6 +30,10 @@ public class SecureChatCompletion { this(baseUrl, false, apiKey, true, null, Constants.DEFAULT_MAX_RETRIES); } + public SecureChatCompletion(String baseUrl, String apiKey, boolean allowHttp) { + this(baseUrl, allowHttp, apiKey, true, null, Constants.DEFAULT_MAX_RETRIES); + } + public SecureChatCompletion(String baseUrl) { this(baseUrl, false, null, true, null, Constants.DEFAULT_MAX_RETRIES); } @@ -132,20 +136,10 @@ public class SecureChatCompletion { throw new RuntimeException("Request interrupted", e); } catch (ExecutionException e) { Throwable cause = e.getCause(); - switch (cause) { - case SecurityError securityError -> throw new RuntimeException(cause); - case InvalidRequestError invalidRequestError -> throw new RuntimeException(cause); - case AuthenticationError authenticationError -> throw new RuntimeException(cause); - case ForbiddenError forbiddenError -> throw new RuntimeException(cause); - case RateLimitError rateLimitError -> throw new RuntimeException(cause); - case ServerError serverError -> throw new RuntimeException(cause); - case ServiceUnavailableError serviceUnavailableError -> throw new RuntimeException(cause); - case APIError apiError -> throw new RuntimeException(cause); - case APIConnectionError apiConnectionError -> throw new RuntimeException(cause); - case SecureCompletionClient.ValueError valueError -> throw new IllegalArgumentException(cause); - default -> - throw new RuntimeException("Request failed: " + cause.getMessage(), cause); + if (cause instanceof SecureCompletionClient.ValueError) { + throw new IllegalArgumentException(cause); } + throw new RuntimeException(cause); } } diff --git a/src/main/java/ai/nomyo/SecureCompletionClient.java b/src/main/java/ai/nomyo/SecureCompletionClient.java index 8b8e17a..69d2a4e 100644 --- a/src/main/java/ai/nomyo/SecureCompletionClient.java +++ b/src/main/java/ai/nomyo/SecureCompletionClient.java @@ -7,6 +7,8 @@ import lombok.Getter; import javax.crypto.*; import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.OAEPParameterSpec; +import javax.crypto.spec.PSource; import javax.crypto.spec.SecretKeySpec; import java.io.IOException; import java.math.BigInteger; @@ -26,6 +28,7 @@ import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; + import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.gson.JsonParser; @@ -106,13 +109,22 @@ public class SecureCompletionClient { this.useSecureMemory = secureMemory; this.keySize = Constants.RSA_KEY_SIZE; this.maxRetries = maxRetries; - this.httpClient = HttpClient.newHttpClient(); + this.httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(); } private static String readFileContent(String filePath) throws IOException { return Files.readString(Path.of(filePath)); } + private static Map parseErrorBody(String body) { + try { + @SuppressWarnings("unchecked") Map parsed = (Map) new Gson().fromJson(body, Object.class); + return parsed != null ? parsed : Map.of(); + } catch (Exception e) { + return Map.of(); + } + } + /** * Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}. */ @@ -175,7 +187,7 @@ public class SecureCompletionClient { } catch (InvalidAlgorithmParameterException e) { throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e); } catch (IOException e) { - throw new RuntimeException("Failed to save keys: " + e.getMessage(), e); + throw new SecurityError("Failed to save keys: " + e.getMessage(), e); } } @@ -212,7 +224,8 @@ public class SecureCompletionClient { try { keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password); } catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException | - IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) { + IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | + SecurityError e) { throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e); } } else { @@ -369,12 +382,18 @@ public class SecureCompletionClient { Cipher rsa; try { + OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 + PSource.PSpecified.DEFAULT); + rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); - rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey); - } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) { + rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams); + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | + InvalidAlgorithmParameterException e) { throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e)); } + byte[] encryptedAESKey; + try { encryptedAESKey = rsa.doFinal(aesKey.getEncoded()); } catch (IllegalBlockSizeException | BadPaddingException e) { @@ -383,11 +402,13 @@ public class SecureCompletionClient { byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length); + byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE); + EncryptedRequest request = new EncryptedRequest(); request.setVersion(Constants.PROTOCOL_VERSION); request.setAlgorithm(Constants.HYBRID_ALGORITHM); - request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag))); + request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(actualCiphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag))); request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey)); request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM); request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM); @@ -424,8 +445,7 @@ public class SecureCompletionClient { return CompletableFuture.supplyAsync(() -> { // Validate security tier if provided if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) { - throw new CompletionException(new ValueError( - "Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS)); + throw new CompletionException(new ValueError("Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS)); } try { @@ -457,12 +477,7 @@ public class SecureCompletionClient { throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e)); } - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url) - .timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)) - .header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM) - .header(Constants.HEADER_PAYLOAD_ID, payloadId) - .header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey)) - .POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload)); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM).header(Constants.HEADER_PAYLOAD_ID, payloadId).header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey)).POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload)); if (apiKey != null && !apiKey.isEmpty()) { requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey); @@ -496,61 +511,28 @@ public class SecureCompletionClient { return decryptResponse(response.body(), payloadId).get(); } else if (statusCode == 400) { String body = new String(response.body(), StandardCharsets.UTF_8); - Map errorDetails = Map.of(); - try { - @SuppressWarnings("unchecked") - Map parsed = (Map) new Gson().fromJson(body, Object.class); - if (parsed != null) errorDetails = parsed; - } catch (Exception ignored) { - } + Map errorDetails = parseErrorBody(body); String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error"; throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails)); } else if (statusCode == 401) { String body = new String(response.body(), StandardCharsets.UTF_8); - Map errorDetails = Map.of(); - try { - @SuppressWarnings("unchecked") - Map parsed = (Map) new Gson().fromJson(body, Object.class); - if (parsed != null) errorDetails = parsed; - } catch (Exception ignored) { - } + Map errorDetails = parseErrorBody(body); String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed"; throw new CompletionException(new AuthenticationError(detail, 401, errorDetails)); } else if (statusCode == 403) { String body = new String(response.body(), StandardCharsets.UTF_8); - Map errorDetails = Map.of(); - try { - @SuppressWarnings("unchecked") - Map parsed = (Map) new Gson().fromJson(body, Object.class); - if (parsed != null) errorDetails = parsed; - } catch (Exception ignored) { - } + Map errorDetails = parseErrorBody(body); String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier"; throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails)); } else if (statusCode == 404) { String body = new String(response.body(), StandardCharsets.UTF_8); - Map errorDetails = Map.of(); - try { - @SuppressWarnings("unchecked") - Map parsed = (Map) new Gson().fromJson(body, Object.class); - if (parsed != null) errorDetails = parsed; - } catch (Exception ignored) { - } + Map errorDetails = parseErrorBody(body); String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled"; throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails)); } else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) { String body = new String(response.body(), StandardCharsets.UTF_8); - Map error = Map.of(); - String detailMsg = "unknown"; - try { - @SuppressWarnings("unchecked") - Map parsed = (Map) new Gson().fromJson(body, Object.class); - if (parsed != null) { - error = parsed; - if (error.containsKey("detail")) detailMsg = error.get("detail").toString(); - } - } catch (Exception ignored) { - } + Map error = parseErrorBody(body); + String detailMsg = error.containsKey("detail") ? error.get("detail").toString() : "unknown"; if (statusCode == 429) { lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error); @@ -568,23 +550,13 @@ public class SecureCompletionClient { throw new CompletionException(lastExc); } else { String body = new String(response.body(), StandardCharsets.UTF_8); - Map errorDetails = Map.of(); - String detailMsg = "unknown"; - try { - @SuppressWarnings("unchecked") - Map parsed = (Map) new Gson().fromJson(body, Object.class); - if (parsed != null) { - errorDetails = parsed; - if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString(); - } - } catch (Exception ignored) { - } + Map errorDetails = parseErrorBody(body); + String detailMsg = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "unknown"; throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails)); } } catch (CompletionException e) { Throwable cause = e.getCause(); - if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError || - cause instanceof ForbiddenError || cause instanceof SecurityError) { + if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError || cause instanceof ForbiddenError || cause instanceof SecurityError) { throw e; } if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) { @@ -764,12 +736,10 @@ public class SecureCompletionClient { String algorithm = packageJson.get("algorithm").getAsString(); if (!Constants.PROTOCOL_VERSION.equals(version)) { - throw new CompletionException(new ValueError( - "Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'")); + throw new CompletionException(new ValueError("Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'")); } if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) { - throw new CompletionException(new ValueError( - "Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'")); + throw new CompletionException(new ValueError("Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'")); } // Validate encrypted_payload structure @@ -789,9 +759,12 @@ public class SecureCompletionClient { try { // Decrypt AES key with private key byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString()); - + + OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 + PSource.PSpecified.DEFAULT); + Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); - rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey); + rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey, oaepParams); byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey); SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES"); @@ -802,7 +775,7 @@ public class SecureCompletionClient { Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding"); aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce)); - + // Combine ciphertext (without tag) and tag for decryption byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length]; System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length); @@ -814,8 +787,7 @@ public class SecureCompletionClient { Map response; try { Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class); - @SuppressWarnings("unchecked") - Map resultMap = (Map) parsed; + @SuppressWarnings("unchecked") Map resultMap = (Map) parsed; response = resultMap != null ? resultMap : new HashMap<>(); } catch (Exception e) { throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage())); @@ -825,8 +797,7 @@ public class SecureCompletionClient { if (!response.containsKey("_metadata")) { response.put("_metadata", new HashMap()); } - @SuppressWarnings("unchecked") - Map metadata = (Map) response.get("_metadata"); + @SuppressWarnings("unchecked") Map metadata = (Map) response.get("_metadata"); metadata.put("payload_id", payloadId); metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null); metadata.put("is_encrypted", true); @@ -835,7 +806,7 @@ public class SecureCompletionClient { return response; } catch (Exception e) { - throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed")); + throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed: " + e.getMessage(), e)); } }); } diff --git a/src/main/java/ai/nomyo/util/Pass2Key.java b/src/main/java/ai/nomyo/util/Pass2Key.java index f6d8b71..6e7cc7c 100644 --- a/src/main/java/ai/nomyo/util/Pass2Key.java +++ b/src/main/java/ai/nomyo/util/Pass2Key.java @@ -10,7 +10,6 @@ import javax.crypto.SecretKeyFactory; import javax.crypto.spec.GCMParameterSpec; import javax.crypto.spec.PBEKeySpec; import javax.crypto.spec.SecretKeySpec; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.*; import java.security.spec.InvalidKeySpecException; @@ -42,12 +41,14 @@ public final class Pass2Key { */ public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError { - byte[] salt = generateRandomBytes(SALT_LENGTH); + byte[] salt = new byte[SALT_LENGTH]; + RANDOM.nextBytes(salt); SecretKey key = deriveKey(password, salt); byte[] payload; if (isGcmMode(algorithm)) { - byte[] iv = generateRandomBytes(GCM_IV_LENGTH); + byte[] iv = new byte[GCM_IV_LENGTH]; + RANDOM.nextBytes(iv); GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH, iv); byte[] ciphertext = encryptWithCipher(algorithm, key, spec, input); payload = assemblePayloadGcm(salt, iv, ciphertext); @@ -71,18 +72,16 @@ public final class Pass2Key { byte[] decoded = Base64.getDecoder().decode(cipherText); - byte[] salt = new byte[SALT_LENGTH]; - System.arraycopy(decoded, 0, salt, 0, SALT_LENGTH); + byte[] salt = java.util.Arrays.copyOfRange(decoded, 0, SALT_LENGTH); SecretKey key = deriveKey(password, salt); if (isGcmMode(algorithm)) { - byte[] iv = new byte[GCM_IV_LENGTH]; - System.arraycopy(decoded, SALT_LENGTH, iv, 0, GCM_IV_LENGTH); + byte[] iv = java.util.Arrays.copyOfRange(decoded, SALT_LENGTH, SALT_LENGTH + GCM_IV_LENGTH); GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH, iv); - byte[] ciphertext = copyFrom(decoded, SALT_LENGTH + GCM_IV_LENGTH); + byte[] ciphertext = java.util.Arrays.copyOfRange(decoded, SALT_LENGTH + GCM_IV_LENGTH, decoded.length); return decryptWithCipher(algorithm, key, spec, ciphertext); } else { - byte[] ciphertext = copyFrom(decoded, SALT_LENGTH); + byte[] ciphertext = java.util.Arrays.copyOfRange(decoded, SALT_LENGTH, decoded.length); return decryptWithCipher(algorithm, key, ciphertext); } } @@ -127,26 +126,19 @@ public final class Pass2Key { return algorithm.contains("GCM"); } - private static byte[] generateRandomBytes(int length) { - byte[] bytes = new byte[length]; - RANDOM.nextBytes(bytes); - return bytes; - } - - private static byte[] copyFrom(byte[] source, int offset) { - return java.util.Arrays.copyOfRange(source, offset, source.length); - } - private static byte[] assemblePayloadGcm(byte[] salt, byte[] iv, byte[] ciphertext) { - ByteBuffer buffer = ByteBuffer.allocate(salt.length + iv.length + ciphertext.length); - buffer.put(salt).put(iv).put(ciphertext); - return buffer.array(); + byte[] payload = new byte[salt.length + iv.length + ciphertext.length]; + System.arraycopy(salt, 0, payload, 0, salt.length); + System.arraycopy(iv, 0, payload, salt.length, iv.length); + System.arraycopy(ciphertext, 0, payload, salt.length + iv.length, ciphertext.length); + return payload; } private static byte[] assemblePayloadSalt(byte[] salt, byte[] ciphertext) { - ByteBuffer buffer = ByteBuffer.allocate(salt.length + ciphertext.length); - buffer.put(salt).put(ciphertext); - return buffer.array(); + byte[] payload = new byte[salt.length + ciphertext.length]; + System.arraycopy(salt, 0, payload, 0, salt.length); + System.arraycopy(ciphertext, 0, payload, salt.length, ciphertext.length); + return payload; } public static PrivateKey convertStringToPrivateKey(String privateKeyString) throws Exception { diff --git a/src/test/java/ai/nomyo/EnsureKeysTest.java b/src/test/java/ai/nomyo/EnsureKeysTest.java index 2fa4606..ebc75ea 100644 --- a/src/test/java/ai/nomyo/EnsureKeysTest.java +++ b/src/test/java/ai/nomyo/EnsureKeysTest.java @@ -62,7 +62,7 @@ class EnsureKeysTest { @Test @Execution(ExecutionMode.SAME_THREAD) @DisplayName("ensureKeys should be thread-safe with concurrent calls") - void ensureKeys_concurrentCalls_shouldBeThreadSafe() throws SecurityError, InterruptedException { + void ensureKeys_concurrentCalls_shouldBeThreadSafe() throws InterruptedException { SecureCompletionClient client = new SecureCompletionClient(); Thread[] threads = new Thread[5];