Complete functionality
This commit is contained in:
parent
675418f411
commit
9b5fa56215
5 changed files with 75 additions and 119 deletions
|
|
@ -7,6 +7,8 @@ import lombok.Getter;
|
|||
|
||||
import javax.crypto.*;
|
||||
import javax.crypto.spec.GCMParameterSpec;
|
||||
import javax.crypto.spec.OAEPParameterSpec;
|
||||
import javax.crypto.spec.PSource;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
|
|
@ -26,6 +28,7 @@ import java.util.*;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
|
|
@ -106,13 +109,22 @@ public class SecureCompletionClient {
|
|||
this.useSecureMemory = secureMemory;
|
||||
this.keySize = Constants.RSA_KEY_SIZE;
|
||||
this.maxRetries = maxRetries;
|
||||
this.httpClient = HttpClient.newHttpClient();
|
||||
this.httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build();
|
||||
}
|
||||
|
||||
private static String readFileContent(String filePath) throws IOException {
|
||||
return Files.readString(Path.of(filePath));
|
||||
}
|
||||
|
||||
private static Map<String, Object> parseErrorBody(String body) {
|
||||
try {
|
||||
@SuppressWarnings("unchecked") Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
return parsed != null ? parsed : Map.of();
|
||||
} catch (Exception e) {
|
||||
return Map.of();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
|
||||
*/
|
||||
|
|
@ -175,7 +187,7 @@ public class SecureCompletionClient {
|
|||
} catch (InvalidAlgorithmParameterException e) {
|
||||
throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
|
||||
throw new SecurityError("Failed to save keys: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -212,7 +224,8 @@ public class SecureCompletionClient {
|
|||
try {
|
||||
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
|
||||
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
|
||||
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) {
|
||||
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException |
|
||||
SecurityError e) {
|
||||
throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -369,12 +382,18 @@ public class SecureCompletionClient {
|
|||
|
||||
Cipher rsa;
|
||||
try {
|
||||
OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1
|
||||
PSource.PSpecified.DEFAULT);
|
||||
|
||||
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
||||
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
|
||||
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams);
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||
InvalidAlgorithmParameterException e) {
|
||||
throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
byte[] encryptedAESKey;
|
||||
|
||||
try {
|
||||
encryptedAESKey = rsa.doFinal(aesKey.getEncoded());
|
||||
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
||||
|
|
@ -383,11 +402,13 @@ public class SecureCompletionClient {
|
|||
|
||||
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length);
|
||||
|
||||
byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE);
|
||||
|
||||
EncryptedRequest request = new EncryptedRequest();
|
||||
|
||||
request.setVersion(Constants.PROTOCOL_VERSION);
|
||||
request.setAlgorithm(Constants.HYBRID_ALGORITHM);
|
||||
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
|
||||
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(actualCiphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
|
||||
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey));
|
||||
request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM);
|
||||
request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM);
|
||||
|
|
@ -424,8 +445,7 @@ public class SecureCompletionClient {
|
|||
return CompletableFuture.supplyAsync(() -> {
|
||||
// Validate security tier if provided
|
||||
if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) {
|
||||
throw new CompletionException(new ValueError(
|
||||
"Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
|
||||
throw new CompletionException(new ValueError("Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
@ -457,12 +477,7 @@ public class SecureCompletionClient {
|
|||
throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url)
|
||||
.timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS))
|
||||
.header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM)
|
||||
.header(Constants.HEADER_PAYLOAD_ID, payloadId)
|
||||
.header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey))
|
||||
.POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
|
||||
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM).header(Constants.HEADER_PAYLOAD_ID, payloadId).header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey)).POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
|
||||
|
||||
if (apiKey != null && !apiKey.isEmpty()) {
|
||||
requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey);
|
||||
|
|
@ -496,61 +511,28 @@ public class SecureCompletionClient {
|
|||
return decryptResponse(response.body(), payloadId).get();
|
||||
} else if (statusCode == 400) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
Map<String, Object> errorDetails = parseErrorBody(body);
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error";
|
||||
throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails));
|
||||
} else if (statusCode == 401) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
Map<String, Object> errorDetails = parseErrorBody(body);
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed";
|
||||
throw new CompletionException(new AuthenticationError(detail, 401, errorDetails));
|
||||
} else if (statusCode == 403) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
Map<String, Object> errorDetails = parseErrorBody(body);
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier";
|
||||
throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails));
|
||||
} else if (statusCode == 404) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
Map<String, Object> errorDetails = parseErrorBody(body);
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled";
|
||||
throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails));
|
||||
} else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> error = Map.of();
|
||||
String detailMsg = "unknown";
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) {
|
||||
error = parsed;
|
||||
if (error.containsKey("detail")) detailMsg = error.get("detail").toString();
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
Map<String, Object> error = parseErrorBody(body);
|
||||
String detailMsg = error.containsKey("detail") ? error.get("detail").toString() : "unknown";
|
||||
|
||||
if (statusCode == 429) {
|
||||
lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error);
|
||||
|
|
@ -568,23 +550,13 @@ public class SecureCompletionClient {
|
|||
throw new CompletionException(lastExc);
|
||||
} else {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
String detailMsg = "unknown";
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) {
|
||||
errorDetails = parsed;
|
||||
if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString();
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
Map<String, Object> errorDetails = parseErrorBody(body);
|
||||
String detailMsg = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "unknown";
|
||||
throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails));
|
||||
}
|
||||
} catch (CompletionException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError ||
|
||||
cause instanceof ForbiddenError || cause instanceof SecurityError) {
|
||||
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError || cause instanceof ForbiddenError || cause instanceof SecurityError) {
|
||||
throw e;
|
||||
}
|
||||
if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) {
|
||||
|
|
@ -764,12 +736,10 @@ public class SecureCompletionClient {
|
|||
String algorithm = packageJson.get("algorithm").getAsString();
|
||||
|
||||
if (!Constants.PROTOCOL_VERSION.equals(version)) {
|
||||
throw new CompletionException(new ValueError(
|
||||
"Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
|
||||
throw new CompletionException(new ValueError("Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
|
||||
}
|
||||
if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) {
|
||||
throw new CompletionException(new ValueError(
|
||||
"Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
|
||||
throw new CompletionException(new ValueError("Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
|
||||
}
|
||||
|
||||
// Validate encrypted_payload structure
|
||||
|
|
@ -789,9 +759,12 @@ public class SecureCompletionClient {
|
|||
try {
|
||||
// Decrypt AES key with private key
|
||||
byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString());
|
||||
|
||||
|
||||
OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1
|
||||
PSource.PSpecified.DEFAULT);
|
||||
|
||||
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
||||
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
|
||||
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey, oaepParams);
|
||||
byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey);
|
||||
SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES");
|
||||
|
||||
|
|
@ -802,7 +775,7 @@ public class SecureCompletionClient {
|
|||
|
||||
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
|
||||
aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
|
||||
|
||||
|
||||
// Combine ciphertext (without tag) and tag for decryption
|
||||
byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length];
|
||||
System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length);
|
||||
|
|
@ -814,8 +787,7 @@ public class SecureCompletionClient {
|
|||
Map<String, Object> response;
|
||||
try {
|
||||
Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultMap = (Map<String, Object>) parsed;
|
||||
@SuppressWarnings("unchecked") Map<String, Object> resultMap = (Map<String, Object>) parsed;
|
||||
response = resultMap != null ? resultMap : new HashMap<>();
|
||||
} catch (Exception e) {
|
||||
throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage()));
|
||||
|
|
@ -825,8 +797,7 @@ public class SecureCompletionClient {
|
|||
if (!response.containsKey("_metadata")) {
|
||||
response.put("_metadata", new HashMap<String, Object>());
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
|
||||
@SuppressWarnings("unchecked") Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
|
||||
metadata.put("payload_id", payloadId);
|
||||
metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null);
|
||||
metadata.put("is_encrypted", true);
|
||||
|
|
@ -835,7 +806,7 @@ public class SecureCompletionClient {
|
|||
return response;
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed"));
|
||||
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed: " + e.getMessage(), e));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue