Address PR feedback: error on stream=false for ChatGPT, fix auth file permissions

This commit is contained in:
Spherrrical 2026-04-20 13:18:44 -07:00
parent c0cc226b74
commit 5af3199f5a
4 changed files with 45 additions and 12 deletions

View file

@@ -51,7 +51,8 @@ def load_auth() -> Optional[Dict[str, Any]]:
 def save_auth(data: Dict[str, Any]):
     """Save auth data to disk."""
     _ensure_auth_dir()
-    with open(CHATGPT_AUTH_FILE, "w") as f:
+    fd = os.open(CHATGPT_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+    with os.fdopen(fd, "w") as f:
         json.dump(data, f, indent=2)

View file

@@ -228,7 +228,15 @@ async fn llm_chat_inner(
     if let Some(ref client_api_kind) = client_api {
         let upstream_api =
             provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
-        client_request.normalize_for_upstream(provider_id, &upstream_api);
+        if let Err(e) = client_request.normalize_for_upstream(provider_id, &upstream_api) {
+            warn!(
+                "request_id={}: normalize_for_upstream failed: {}",
+                request_id, e
+            );
+            let mut bad_request = Response::new(full(e.message));
+            *bad_request.status_mut() = StatusCode::BAD_REQUEST;
+            return Ok(bad_request);
+        }
     }
     // --- Phase 2: Resolve conversation state (v1/responses API) ---

View file

@@ -77,7 +77,7 @@ impl ProviderRequestType {
         &mut self,
         provider_id: ProviderId,
         upstream_api: &SupportedUpstreamAPIs,
-    ) {
+    ) -> Result<(), ProviderRequestError> {
         if provider_id == ProviderId::XAI
             && matches!(
                 upstream_api,
@@ -110,6 +110,12 @@ impl ProviderRequestType {
                 }
             }
             req.store = Some(false);
+            if req.stream == Some(false) {
+                return Err(ProviderRequestError {
+                    message: "Non-streaming requests are not supported for the ChatGPT Codex provider. Set stream=true or omit the stream field.".to_string(),
+                    source: None,
+                });
+            }
             req.stream = Some(true);
             // ChatGPT backend requires input to be a list, not a plain string
@@ -124,6 +130,7 @@ impl ProviderRequestType {
                 }
             }
         }
+        Ok(())
     }
 }
@@ -859,10 +866,12 @@ mod tests {
             ..Default::default()
         });
-        request.normalize_for_upstream(
-            ProviderId::XAI,
-            &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
-        );
+        request
+            .normalize_for_upstream(
+                ProviderId::XAI,
+                &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
+            )
+            .unwrap();
         let ProviderRequestType::ChatCompletionsRequest(req) = request else {
             panic!("expected chat request");
@@ -887,10 +896,12 @@ mod tests {
             ..Default::default()
         });
-        request.normalize_for_upstream(
-            ProviderId::OpenAI,
-            &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
-        );
+        request
+            .normalize_for_upstream(
+                ProviderId::OpenAI,
+                &SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
+            )
+            .unwrap();
         let ProviderRequestType::ChatCompletionsRequest(req) = request else {
             panic!("expected chat request");

View file

@@ -1056,7 +1056,20 @@ impl HttpContext for StreamContext {
         match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
             Ok(mut request) => {
-                request.normalize_for_upstream(self.get_provider_id(), upstream);
+                if let Err(e) =
+                    request.normalize_for_upstream(self.get_provider_id(), upstream)
+                {
+                    warn!(
+                        "request_id={}: normalize_for_upstream failed: {}",
+                        self.request_identifier(),
+                        e
+                    );
+                    self.send_server_error(
+                        ServerError::LogicError(e.message),
+                        Some(StatusCode::BAD_REQUEST),
+                    );
+                    return Action::Pause;
+                }
                 debug!(
                     "request_id={}: upstream request payload: {}",
                     self.request_identifier(),