Finish functionality

This commit is contained in:
Oracle 2026-04-26 18:21:05 +02:00
parent b6af1c9792
commit 89d5282b0f
Signed by: Oracle
SSH key fingerprint: SHA256:x4/RtnjUyuHkdvmwNDsWSfcfF1V5PNr3OpriZqOvCX8
9 changed files with 583 additions and 133 deletions

View file

@ -4,6 +4,7 @@ import ai.nomyo.errors.*;
import ai.nomyo.util.PEMConverter;
import ai.nomyo.util.Pass2Key;
import lombok.Getter;
import lombok.Setter;
import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
@ -27,6 +28,10 @@ import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.util.concurrent.locks.ReentrantLock;
/**
@ -113,7 +118,7 @@ public class SecureCompletionClient {
/**
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
*/
public void generateKeys(boolean saveToFile, String keyDir, String password) {
public void generateKeys(boolean saveToFile, String keyDir, String password) throws SecurityError {
try {
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
generator.initialize(new RSAKeyGenParameterSpec(Constants.RSA_KEY_SIZE, BigInteger.valueOf(Constants.RSA_PUBLIC_EXPONENT)));
@ -145,8 +150,8 @@ public class SecureCompletionClient {
try {
privatePem = Pass2Key.encrypt("AES/GCM/NoPadding", privatePem, password);
} catch (NoSuchPaddingException | IllegalBlockSizeException | BadPaddingException |
InvalidKeyException e) {
throw new RuntimeException(e);
InvalidKeyException | SecurityError e) {
throw new SecurityError("Failed to encrypt private key with password: " + e.getMessage(), e);
}
}
writer.write(privatePem);
@ -167,10 +172,12 @@ public class SecureCompletionClient {
this.privateKey = pair.getPrivate();
this.publicPemKey = publicPem;
} catch (SecurityError e) {
throw e;
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("RSA not available: " + e.getMessage(), e);
throw new SecurityError("RSA algorithm not available: " + e.getMessage(), e);
} catch (InvalidAlgorithmParameterException e) {
throw new RuntimeException(e);
throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e);
} catch (IOException e) {
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
}
@ -179,7 +186,7 @@ public class SecureCompletionClient {
/**
* Generates a 4096-bit RSA key pair and saves to the default directory.
*/
public void generateKeys(boolean saveToFile) {
public void generateKeys(boolean saveToFile) throws SecurityError {
generateKeys(saveToFile, Constants.DEFAULT_KEY_DIR, null);
}
@ -190,11 +197,12 @@ public class SecureCompletionClient {
* @param privateKeyPath private key PEM path
* @param publicPemKeyPath optional public key PEM path
* @param password optional password for encrypted private key
* @throws SecurityError if key file not found, unreadable, or decryption fails
*/
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) {
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) throws SecurityError {
Path keyPath = Path.of(privateKeyPath);
if (!Files.exists(keyPath)) {
throw new RuntimeException("Private key file not found: " + privateKeyPath);
throw new SecurityError("Private key file not found: " + privateKeyPath);
}
String keyContent;
@ -202,35 +210,36 @@ public class SecureCompletionClient {
try {
keyContent = readFileContent(privateKeyPath);
} catch (IOException e) {
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
}
try {
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException e) {
System.out.println("Wrong password!");
return;
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) {
throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e);
}
} else {
try {
keyContent = readFileContent(privateKeyPath);
} catch (IOException e) {
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
}
}
try {
this.privateKey = Pass2Key.convertStringToPrivateKey(keyContent);
} catch (Exception e) {
throw new RuntimeException("Failed to load private key: " + e.getMessage(), e);
throw new SecurityError("Failed to load private key: " + e.getMessage(), e);
}
}
/**
* Loads RSA private key from disk, deriving public key.
*
* @throws SecurityError if key file not found, unreadable, or decryption fails
*/
public void loadKeys(String privateKeyPath, String password) {
public void loadKeys(String privateKeyPath, String password) throws SecurityError {
loadKeys(privateKeyPath, null, password);
}
@ -250,12 +259,12 @@ public class SecureCompletionClient {
URI url;
try {
url = new URI(this.routerUrl + "/pki/public_key");
url = new URI(this.routerUrl + Constants.PKI_PUBLIC_KEY_PATH);
} catch (URISyntaxException e) {
return CompletableFuture.failedFuture(new CompletionException("Invalid URI: " + e.getMessage(), e));
return CompletableFuture.failedFuture(new CompletionException(new APIConnectionError("Invalid URI: " + e.getMessage(), e)));
}
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.of(60, ChronoUnit.SECONDS)).GET().build();
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).GET().build();
return this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> {
if (response.statusCode() != 200) {
@ -264,7 +273,7 @@ public class SecureCompletionClient {
return response.body();
}).thenApply(body -> {
if (!PEMConverter.validatePEM(body)) {
throw new CompletionException(new InvalidKeyException("PEM key had invalid format"));
throw new CompletionException(new SecurityError("Server returned invalid PEM key format (possible MITM attack)"));
}
return body;
});
@ -278,7 +287,37 @@ public class SecureCompletionClient {
* @throws SecurityError if encryption fails or keys not loaded
*/
public CompletableFuture<byte[]> encryptPayload(Map<String, Object> payload) {
throw new UnsupportedOperationException("Not yet implemented");
return CompletableFuture.supplyAsync(() -> {
try {
ensureKeys(null);
} catch (SecurityError e) {
throw new CompletionException(new SecurityError("Failed to ensure keys are initialized: " + e.getMessage(), e));
}
if (this.privateKey == null) {
throw new CompletionException(new SecurityError("Private key not available for encryption"));
}
// Generate AES key
KeyGenerator keyGen;
try {
keyGen = KeyGenerator.getInstance("AES");
keyGen.init(Constants.AES_KEY_SIZE * 8);
} catch (NoSuchAlgorithmException e) {
throw new CompletionException(new SecurityError("AES key generation not available: " + e.getMessage(), e));
}
Key aesKey = keyGen.generateKey();
// Serialize payload to JSON
Gson gson = new Gson();
String payloadJson = gson.toJson(payload);
byte[] payloadBytes = payloadJson.getBytes(StandardCharsets.UTF_8);
// Encrypt
return doEncrypt(payloadBytes, aesKey).join();
});
}
/**
@ -287,16 +326,16 @@ public class SecureCompletionClient {
public CompletableFuture<byte[]> doEncrypt(byte[] payloadBytes, Key aesKey) {
return CompletableFuture.supplyAsync(() -> {
SecureRandom random = new SecureRandom();
byte[] nonce = new byte[12];
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
random.nextBytes(nonce);
Cipher cipher = null;
try {
cipher = Cipher.getInstance("AES/GCM/NoPadding");
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(128, nonce));
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, nonce));
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException |
InvalidKeyException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("AES-GCM cipher initialization failed: " + e.getMessage(), e));
}
byte[] ciphertext;
@ -304,66 +343,63 @@ public class SecureCompletionClient {
try {
ciphertext = cipher.doFinal(payloadBytes);
} catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e));
}
String serverPEM;
try {
serverPEM = fetchServerPublicKey().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e));
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof SecurityError) {
throw new RuntimeException((SecurityError) cause);
}
throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause));
}
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM).getBytes());
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM));
PublicKey serverPublicKey;
try {
serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec);
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e));
}
Cipher rsa;
byte[] enryptedAESKey = aesKey.getEncoded();
try {
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e));
}
byte[] encryptedAESKey;
try {
rsa.doFinal(enryptedAESKey);
encryptedAESKey = rsa.doFinal(aesKey.getEncoded());
} catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e));
}
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - (128 / Byte.SIZE), ciphertext.length);
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length);
EncryptedRequest request = new EncryptedRequest();
request.setVersion("1.0");
request.setAlgorithm("hybrid-aes256-rsa4096");
request.setVersion(Constants.PROTOCOL_VERSION);
request.setAlgorithm(Constants.HYBRID_ALGORITHM);
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(enryptedAESKey));
request.setKeyAlgorithm("RSA-OAEP-SHA256");
request.setPayloadAlgorithm("AES-256-GCM");
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey));
request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM);
request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM);
return request.toJson().getBytes(StandardCharsets.UTF_8);
});
}
/**
* Decrypts server response.
*/
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
throw new UnsupportedOperationException("Not yet implemented");
}
/**
* encrypt POST {routerUrl}/v1/chat/secure_completion retry decrypt return.
* <p>Headers: Content-Type=octet-stream, X-Payload-ID, X-Public-Key, Authorization (Bearer), X-Security-Tier.
@ -388,7 +424,210 @@ public class SecureCompletionClient {
* @throws APIError other errors
*/
public CompletableFuture<Map<String, Object>> sendSecureRequest(Map<String, Object> payload, String payloadId, String apiKey, String securityTier) {
throw new UnsupportedOperationException("Not yet implemented");
return CompletableFuture.supplyAsync(() -> {
// Validate security tier if provided
if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) {
throw new CompletionException(new ValueError(
"Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
}
try {
ensureKeys(null);
} catch (SecurityError e) {
throw new CompletionException(new SecurityError("Failed to ensure keys: " + e.getMessage(), e));
}
// Step 1: Encrypt payload
byte[] encryptedPayload;
try {
encryptedPayload = encryptPayload(payload).get();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CompletionException(new APIConnectionError("Encryption interrupted"));
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof SecurityError) {
throw new CompletionException((SecurityError) cause);
}
throw new CompletionException(new SecurityError("Encryption failed: " + cause.getMessage(), cause));
}
// Step 2: Prepare headers
URI url;
try {
url = new URI(this.routerUrl + Constants.SECURE_COMPLETION_PATH);
} catch (URISyntaxException e) {
throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e));
}
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url)
.timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS))
.header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM)
.header(Constants.HEADER_PAYLOAD_ID, payloadId)
.header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey))
.POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
if (apiKey != null && !apiKey.isEmpty()) {
requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey);
}
if (securityTier != null) {
requestBuilder.header(Constants.HEADER_SECURITY_TIER, securityTier);
}
HttpRequest request = requestBuilder.build();
// Step 3: Send request with retry
Exception lastExc = new APIConnectionError("Request failed");
int retryableCodes = 0;
for (int attempt = 0; attempt <= this.maxRetries; attempt++) {
if (attempt > 0) {
long delay = (long) Math.pow(2, attempt - 1);
try {
Thread.sleep(delay * 1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CompletionException(new APIConnectionError("Retry interrupted"));
}
}
try {
HttpResponse<byte[]> response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray());
int statusCode = response.statusCode();
if (statusCode == 200) {
// Step 4: Decrypt response
return decryptResponse(response.body(), payloadId).get();
} else if (statusCode == 400) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error";
throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails));
} else if (statusCode == 401) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed";
throw new CompletionException(new AuthenticationError(detail, 401, errorDetails));
} else if (statusCode == 403) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier";
throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails));
} else if (statusCode == 404) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled";
throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails));
} else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> error = Map.of();
String detailMsg = "unknown";
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) {
error = parsed;
if (error.containsKey("detail")) detailMsg = error.get("detail").toString();
}
} catch (Exception ignored) {
}
if (statusCode == 429) {
lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error);
} else if (statusCode == 500) {
lastExc = new ServerError("Server error: " + detailMsg, 500, error);
} else if (statusCode == 503) {
lastExc = new ServiceUnavailableError("Service unavailable: " + detailMsg, 503, error);
} else {
lastExc = new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, error);
}
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
} else {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
String detailMsg = "unknown";
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) {
errorDetails = parsed;
if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString();
}
} catch (Exception ignored) {
}
throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails));
}
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError ||
cause instanceof ForbiddenError || cause instanceof SecurityError) {
throw e;
}
if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) {
lastExc = (Exception) cause;
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
}
if (cause instanceof APIError apiError) {
if (Constants.RETRYABLE_STATUS_CODES.contains(apiError.getStatusCode())) {
lastExc = apiError;
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
}
throw e;
}
lastExc = new APIConnectionError("Failed to connect to router: " + e.getMessage());
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CompletionException(new APIConnectionError("Request interrupted"));
} catch (ExecutionException e) {
throw new CompletionException(new APIConnectionError("Request failed: " + e.getMessage(), e));
} catch (IOException e) {
lastExc = new APIConnectionError("IO error: " + e.getMessage(), e);
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
}
}
throw new CompletionException(lastExc);
});
}
/**
@ -410,7 +649,7 @@ public class SecureCompletionClient {
*
* @param keyDir key directory or {@code null} for ephemeral
*/
public void ensureKeys(String keyDir) {
public void ensureKeys(String keyDir) throws SecurityError {
if (keysInitialized) return;
keyInitLock.lock();
try {
@ -458,21 +697,19 @@ public class SecureCompletionClient {
* Maps HTTP status code to exception (200null).
*/
public Exception mapHttpStatus(int statusCode, String responseBody) {
String message = responseBody != null ? responseBody : "no body";
Map<String, Object> errorDetails = responseBody != null ? Map.of("response", responseBody) : Map.of();
return switch (statusCode) {
case 200 -> null;
case 400 ->
new InvalidRequestError("Invalid request: " + (responseBody != null ? responseBody : "no body"));
case 401 ->
new AuthenticationError("Authentication failed: " + (responseBody != null ? responseBody : "no body"));
case 403 -> new ForbiddenError("Access forbidden: " + (responseBody != null ? responseBody : "no body"));
case 404 -> new APIError("Not found: " + (responseBody != null ? responseBody : "no body"));
case 429 -> new RateLimitError("Rate limit exceeded: " + (responseBody != null ? responseBody : "no body"));
case 500 -> new ServerError("Internal server error: " + (responseBody != null ? responseBody : "no body"));
case 503 ->
new ServiceUnavailableError("Service unavailable: " + (responseBody != null ? responseBody : "no body"));
case 502, 504 -> new APIError("Gateway error: " + (responseBody != null ? responseBody : "no body"));
default ->
new APIError("Unexpected status " + statusCode + ": " + (responseBody != null ? responseBody : "no body"));
case 400 -> new InvalidRequestError("Invalid request: " + message, statusCode, errorDetails);
case 401 -> new AuthenticationError("Authentication failed: " + message, statusCode, errorDetails);
case 403 -> new ForbiddenError("Access forbidden: " + message, statusCode, errorDetails);
case 404 -> new APIError("Not found: " + message, statusCode, errorDetails);
case 429 -> new RateLimitError("Rate limit exceeded: " + message, statusCode, errorDetails);
case 500 -> new ServerError("Internal server error: " + message, statusCode, errorDetails);
case 503 -> new ServiceUnavailableError("Service unavailable: " + message, statusCode, errorDetails);
case 502, 504 -> new APIError("Gateway error: " + message, statusCode, errorDetails);
default -> new APIError("Unexpected status " + statusCode + ": " + message, statusCode, errorDetails);
};
}
@ -484,9 +721,138 @@ public class SecureCompletionClient {
}
/**
* Delegates to resource cleanup (stub).
* Closes the HTTP client and clears keys from memory.
*/
public void close() {
throw new UnsupportedOperationException("Not yet implemented");
this.httpClient.close();
this.privateKey = null;
this.publicPemKey = null;
this.keysInitialized = false;
}
/**
* Decrypts server response.
*/
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
return CompletableFuture.supplyAsync(() -> {
if (encryptedResponse == null || encryptedResponse.length == 0) {
throw new CompletionException(new ValueError("Empty encrypted response"));
}
String jsonResponse;
try {
jsonResponse = new String(encryptedResponse, StandardCharsets.UTF_8);
} catch (Exception e) {
throw new CompletionException(new ValueError("Invalid encrypted package format: not valid UTF-8"));
}
Gson gson = new Gson();
JsonObject packageJson;
try {
packageJson = JsonParser.parseString(jsonResponse).getAsJsonObject();
} catch (Exception e) {
throw new CompletionException(new ValueError("Invalid encrypted package format: malformed JSON"));
}
// Validate required fields
String[] requiredFields = {"version", "algorithm", "encrypted_payload", "encrypted_aes_key"};
for (String field : requiredFields) {
if (!packageJson.has(field)) {
throw new CompletionException(new ValueError("Missing required fields in encrypted package: " + field));
}
}
// Validate version and algorithm
String version = packageJson.get("version").getAsString();
String algorithm = packageJson.get("algorithm").getAsString();
if (!Constants.PROTOCOL_VERSION.equals(version)) {
throw new CompletionException(new ValueError(
"Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
}
if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) {
throw new CompletionException(new ValueError(
"Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
}
// Validate encrypted_payload structure
JsonObject encryptedPayload = packageJson.get("encrypted_payload").getAsJsonObject();
String[] payloadRequired = {"ciphertext", "nonce", "tag"};
for (String field : payloadRequired) {
if (!encryptedPayload.has(field)) {
throw new CompletionException(new ValueError("Missing fields in encrypted_payload: " + field));
}
}
// Guard: private key must be initialized
if (this.privateKey == null) {
throw new CompletionException(new SecurityError("Private key not initialized. Call generateKeys() or loadKeys() first."));
}
try {
// Decrypt AES key with private key
byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString());
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey);
SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES");
// Decrypt payload
byte[] ciphertext = Base64.getDecoder().decode(encryptedPayload.get("ciphertext").getAsString());
byte[] nonce = Base64.getDecoder().decode(encryptedPayload.get("nonce").getAsString());
byte[] tag = Base64.getDecoder().decode(encryptedPayload.get("tag").getAsString());
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
// Combine ciphertext (without tag) and tag for decryption
byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length];
System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length);
System.arraycopy(tag, 0, ciphertextWithTag, ciphertext.length, tag.length);
byte[] plaintextBytes = aesCipher.doFinal(ciphertextWithTag);
// Parse JSON response
Map<String, Object> response;
try {
Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class);
@SuppressWarnings("unchecked")
Map<String, Object> resultMap = (Map<String, Object>) parsed;
response = resultMap != null ? resultMap : new HashMap<>();
} catch (Exception e) {
throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage()));
}
// Add metadata
if (!response.containsKey("_metadata")) {
response.put("_metadata", new HashMap<String, Object>());
}
@SuppressWarnings("unchecked")
Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
metadata.put("payload_id", payloadId);
metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null);
metadata.put("is_encrypted", true);
metadata.put("encryption_algorithm", Constants.HYBRID_ALGORITHM);
return response;
} catch (Exception e) {
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed"));
}
});
}
/**
* Error class for invalid argument/value errors (maps to Python ValueError).
*/
public static class ValueError extends Exception {
public ValueError(String message) {
super(message);
}
public ValueError(String message, Throwable cause) {
super(message, cause);
}
}
}