Add X-Request-Id middleware

Per-request ULID minted at the edge, exposed in request extensions and
on the response header. Caller-supplied X-Request-Id is echoed when
well-formed (1..=128 ASCII printable characters); otherwise rejected
and replaced with a fresh ULID so the value is always safe to log.

Companion to the TypeScript SDK redesign — clients now correlate logs
across the wire by reading X-Request-Id from response headers (and the
SDK already surfaces it on every OmnigraphError as `requestId`).

No spec change required; the header is a transport-layer concern.

Tests:
- mint a ULID when no header is provided
- echo a valid caller-supplied id
- reject overlong header (200 chars), mint a fresh ULID

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Ragnor Comerford 2026-04-25 22:56:17 +02:00
parent 7ea868485e
commit 284c9377c2
No known key found for this signature in database
5 changed files with 142 additions and 0 deletions

1
Cargo.lock generated
View file

@ -4700,6 +4700,7 @@ dependencies = [
"tower-http",
"tracing",
"tracing-subscriber",
"ulid",
"utoipa",
]

View file

@ -31,6 +31,7 @@ serde_yaml = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
tower-http = { workspace = true }
ulid = { workspace = true }
utoipa = { workspace = true }
cedar-policy = { workspace = true }
futures = { workspace = true }

View file

@ -2,6 +2,7 @@ pub mod api;
pub mod auth;
pub mod config;
pub mod policy;
pub mod request_id;
use std::collections::{HashMap, HashSet};
use std::fs;
@ -460,6 +461,7 @@ pub fn build_app(state: AppState) -> Router {
.merge(protected)
.layer(DefaultBodyLimit::max(DEFAULT_REQUEST_BODY_LIMIT_BYTES))
.layer(TraceLayer::new_for_http())
.layer(middleware::from_fn(request_id::request_id_middleware))
.with_state(state)
}

View file

@ -0,0 +1,59 @@
//! `X-Request-Id` middleware.
//!
//! Mints a ULID per inbound request, or echoes a caller-supplied
//! `X-Request-Id` header if it's well-formed. Stores the value in request
//! extensions so handlers can include it in error bodies, log lines, or
//! audit records, and surfaces it on the response header so SDK clients
//! can correlate logs across the wire.
use axum::{
body::Body,
extract::Request,
http::{HeaderName, HeaderValue, header},
middleware::Next,
response::Response,
};
use ulid::Ulid;
pub const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// Wraps a request id pulled out of (or minted into) request extensions.
#[derive(Clone, Debug)]
pub struct RequestId(pub String);
impl RequestId {
pub fn as_str(&self) -> &str {
&self.0
}
}
/// Acceptable inbound `X-Request-Id` shape: 1..=128 ASCII printable chars.
/// Rejecting wider input keeps the value safe to log and emit verbatim.
fn is_valid_inbound(raw: &str) -> bool {
!raw.is_empty()
&& raw.len() <= 128
&& raw
.bytes()
.all(|b| b.is_ascii_graphic() || b == b' ' || b == b'-' || b == b'_')
}
pub async fn request_id_middleware(mut req: Request<Body>, next: Next) -> Response {
let inbound = req
.headers()
.get(header::HeaderName::from_static("x-request-id"))
.and_then(|v| v.to_str().ok())
.filter(|raw| is_valid_inbound(raw));
let id = match inbound {
Some(raw) => raw.to_owned(),
None => Ulid::new().to_string(),
};
req.extensions_mut().insert(RequestId(id.clone()));
let mut response = next.run(req).await;
if let Ok(value) = HeaderValue::from_str(&id) {
response.headers_mut().insert(X_REQUEST_ID, value);
}
response
}

View file

@ -675,6 +675,85 @@ async fn healthz_succeeds_after_startup() {
}
}
#[tokio::test(flavor = "multi_thread")]
async fn request_id_minted_when_absent() {
let (_temp, app) = app_for_loaded_repo().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/healthz")
.method(Method::GET)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let id = response
.headers()
.get("x-request-id")
.expect("X-Request-Id missing")
.to_str()
.unwrap()
.to_owned();
// ULIDs are 26 chars Crockford base32.
assert_eq!(id.len(), 26);
assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
}
#[tokio::test(flavor = "multi_thread")]
async fn request_id_echoed_when_caller_supplies_valid_value() {
let (_temp, app) = app_for_loaded_repo().await;
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/healthz")
.method(Method::GET)
.header("X-Request-Id", "trace-abc123")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response
.headers()
.get("x-request-id")
.unwrap()
.to_str()
.unwrap(),
"trace-abc123"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn request_id_minted_when_caller_supplies_invalid_value() {
let (_temp, app) = app_for_loaded_repo().await;
// 200-char string is a valid HeaderValue but exceeds the inbound length cap.
let too_long = "a".repeat(200);
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/healthz")
.method(Method::GET)
.header("X-Request-Id", &too_long)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let id = response
.headers()
.get("x-request-id")
.unwrap()
.to_str()
.unwrap();
assert_ne!(id, too_long);
assert_eq!(id.len(), 26);
}
#[tokio::test(flavor = "multi_thread")]
async fn schema_drift_returns_conflict_for_snapshot_read_and_change() {
let (temp, app) = app_for_loaded_repo().await;