Complete functionality

This commit is contained in:
Oracle 2026-04-29 15:14:24 +02:00
parent 675418f411
commit 9b5fa56215
Signed by: Oracle
SSH key fingerprint: SHA256:x4/RtnjUyuHkdvmwNDsWSfcfF1V5PNr3OpriZqOvCX8
5 changed files with 75 additions and 119 deletions

View file

@ -7,6 +7,8 @@ import lombok.Getter;
import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.OAEPParameterSpec;
import javax.crypto.spec.PSource;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.math.BigInteger;
@ -26,6 +28,7 @@ import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
@ -106,13 +109,22 @@ public class SecureCompletionClient {
this.useSecureMemory = secureMemory;
this.keySize = Constants.RSA_KEY_SIZE;
this.maxRetries = maxRetries;
this.httpClient = HttpClient.newHttpClient();
this.httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build();
}
private static String readFileContent(String filePath) throws IOException {
return Files.readString(Path.of(filePath));
}
private static Map<String, Object> parseErrorBody(String body) {
try {
@SuppressWarnings("unchecked") Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
return parsed != null ? parsed : Map.of();
} catch (Exception e) {
return Map.of();
}
}
/**
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
*/
@ -175,7 +187,7 @@ public class SecureCompletionClient {
} catch (InvalidAlgorithmParameterException e) {
throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e);
} catch (IOException e) {
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
throw new SecurityError("Failed to save keys: " + e.getMessage(), e);
}
}
@ -212,7 +224,8 @@ public class SecureCompletionClient {
try {
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) {
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException |
SecurityError e) {
throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e);
}
} else {
@ -369,12 +382,18 @@ public class SecureCompletionClient {
Cipher rsa;
try {
OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1
PSource.PSpecified.DEFAULT);
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException e) {
throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e));
}
byte[] encryptedAESKey;
try {
encryptedAESKey = rsa.doFinal(aesKey.getEncoded());
} catch (IllegalBlockSizeException | BadPaddingException e) {
@ -383,11 +402,13 @@ public class SecureCompletionClient {
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length);
byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE);
EncryptedRequest request = new EncryptedRequest();
request.setVersion(Constants.PROTOCOL_VERSION);
request.setAlgorithm(Constants.HYBRID_ALGORITHM);
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(actualCiphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey));
request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM);
request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM);
@ -424,8 +445,7 @@ public class SecureCompletionClient {
return CompletableFuture.supplyAsync(() -> {
// Validate security tier if provided
if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) {
throw new CompletionException(new ValueError(
"Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
throw new CompletionException(new ValueError("Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
}
try {
@ -457,12 +477,7 @@ public class SecureCompletionClient {
throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e));
}
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url)
.timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS))
.header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM)
.header(Constants.HEADER_PAYLOAD_ID, payloadId)
.header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey))
.POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM).header(Constants.HEADER_PAYLOAD_ID, payloadId).header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey)).POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
if (apiKey != null && !apiKey.isEmpty()) {
requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey);
@ -496,61 +511,28 @@ public class SecureCompletionClient {
return decryptResponse(response.body(), payloadId).get();
} else if (statusCode == 400) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
Map<String, Object> errorDetails = parseErrorBody(body);
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error";
throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails));
} else if (statusCode == 401) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
Map<String, Object> errorDetails = parseErrorBody(body);
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed";
throw new CompletionException(new AuthenticationError(detail, 401, errorDetails));
} else if (statusCode == 403) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
Map<String, Object> errorDetails = parseErrorBody(body);
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier";
throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails));
} else if (statusCode == 404) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
Map<String, Object> errorDetails = parseErrorBody(body);
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled";
throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails));
} else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> error = Map.of();
String detailMsg = "unknown";
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) {
error = parsed;
if (error.containsKey("detail")) detailMsg = error.get("detail").toString();
}
} catch (Exception ignored) {
}
Map<String, Object> error = parseErrorBody(body);
String detailMsg = error.containsKey("detail") ? error.get("detail").toString() : "unknown";
if (statusCode == 429) {
lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error);
@ -568,23 +550,13 @@ public class SecureCompletionClient {
throw new CompletionException(lastExc);
} else {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
String detailMsg = "unknown";
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) {
errorDetails = parsed;
if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString();
}
} catch (Exception ignored) {
}
Map<String, Object> errorDetails = parseErrorBody(body);
String detailMsg = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "unknown";
throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails));
}
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError ||
cause instanceof ForbiddenError || cause instanceof SecurityError) {
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError || cause instanceof ForbiddenError || cause instanceof SecurityError) {
throw e;
}
if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) {
@ -764,12 +736,10 @@ public class SecureCompletionClient {
String algorithm = packageJson.get("algorithm").getAsString();
if (!Constants.PROTOCOL_VERSION.equals(version)) {
throw new CompletionException(new ValueError(
"Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
throw new CompletionException(new ValueError("Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
}
if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) {
throw new CompletionException(new ValueError(
"Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
throw new CompletionException(new ValueError("Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
}
// Validate encrypted_payload structure
@ -789,9 +759,12 @@ public class SecureCompletionClient {
try {
// Decrypt AES key with private key
byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString());
OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1
PSource.PSpecified.DEFAULT);
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey, oaepParams);
byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey);
SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES");
@ -802,7 +775,7 @@ public class SecureCompletionClient {
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
// Combine ciphertext (without tag) and tag for decryption
byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length];
System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length);
@ -814,8 +787,7 @@ public class SecureCompletionClient {
Map<String, Object> response;
try {
Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class);
@SuppressWarnings("unchecked")
Map<String, Object> resultMap = (Map<String, Object>) parsed;
@SuppressWarnings("unchecked") Map<String, Object> resultMap = (Map<String, Object>) parsed;
response = resultMap != null ? resultMap : new HashMap<>();
} catch (Exception e) {
throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage()));
@ -825,8 +797,7 @@ public class SecureCompletionClient {
if (!response.containsKey("_metadata")) {
response.put("_metadata", new HashMap<String, Object>());
}
@SuppressWarnings("unchecked")
Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
@SuppressWarnings("unchecked") Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
metadata.put("payload_id", payloadId);
metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null);
metadata.put("is_encrypted", true);
@ -835,7 +806,7 @@ public class SecureCompletionClient {
return response;
} catch (Exception e) {
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed"));
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed: " + e.getMessage(), e));
}
});
}