From ebe3a4fca0cc51f98a005477c47bdf78560f3544 Mon Sep 17 00:00:00 2001 From: pitboss Date: Fri, 22 May 2026 00:22:08 -0500 Subject: [PATCH] [pitboss/grind] deferred session-0002 (20260522T043516Z-29b8) --- src/dynamic/framework/adapters/rust_actix.rs | 19 +- src/dynamic/framework/adapters/rust_axum.rs | 18 +- src/dynamic/framework/adapters/rust_rocket.rs | 19 +- src/dynamic/framework/adapters/rust_routes.rs | 228 +++++++++++++++++- src/dynamic/framework/adapters/rust_warp.rs | 18 +- 5 files changed, 289 insertions(+), 13 deletions(-) diff --git a/src/dynamic/framework/adapters/rust_actix.rs b/src/dynamic/framework/adapters/rust_actix.rs index f7cf5e0e..41a564a3 100644 --- a/src/dynamic/framework/adapters/rust_actix.rs +++ b/src/dynamic/framework/adapters/rust_actix.rs @@ -19,8 +19,8 @@ use crate::symbol::Lang; use tree_sitter::Node; use super::rust_routes::{ - bind_rust_path_params, find_actix_route_chain, find_method_attribute, find_rust_function, - rust_formal_names, source_imports_actix, + bind_rust_path_params, collect_rust_middleware, find_actix_route_chain, find_method_attribute, + find_rust_function, rust_formal_names, source_imports_actix, }; pub struct RustActixAdapter; @@ -50,13 +50,14 @@ impl FrameworkAdapter for RustActixAdapter { .or_else(|| find_actix_route_chain(ast, file_bytes, &summary.name))?; let formals = rust_formal_names(func, file_bytes); let request_params = bind_rust_path_params(&formals, &path); + let middleware = collect_rust_middleware(ast, file_bytes); Some(FrameworkBinding { adapter: ADAPTER_NAME.to_owned(), kind: EntryKind::HttpRoute, route: Some(RouteShape { method, path }), request_params, response_writer: None, - middleware: Vec::new(), + middleware, }) } } @@ -167,6 +168,18 @@ mod tests { assert_eq!(route.path, "/save"); } + #[test] + fn populates_middleware_from_wrap_call() { + let src: &[u8] = b"use actix_web::{App, web};\n\ + fn build() -> App<()> { App::new().wrap(HttpAuthentication::bearer(validator)).route(\"/u\", web::get().to(show)) }\n\ + async fn show() -> String { String::new() }\n"; + let tree = parse(src); + let binding = RustActixAdapter + .detect(&summary("show"), tree.root_node(), src) + .expect("binding"); + assert!(binding.middleware.iter().any(|m| m.name.contains("HttpAuthentication"))); + } + #[test] fn chained_builder_requires_handler_match() { let src: &[u8] = b"use actix_web::{App, web};\n\ diff --git a/src/dynamic/framework/adapters/rust_axum.rs b/src/dynamic/framework/adapters/rust_axum.rs index a09efc48..f09f8153 100644 --- a/src/dynamic/framework/adapters/rust_axum.rs +++ b/src/dynamic/framework/adapters/rust_axum.rs @@ -19,8 +19,8 @@ use crate::symbol::Lang; use tree_sitter::Node; use super::rust_routes::{ - bind_rust_path_params, find_axum_route, find_rust_function, rust_formal_names, - source_imports_axum, + bind_rust_path_params, collect_rust_middleware, find_axum_route, find_rust_function, + rust_formal_names, source_imports_axum, }; pub struct RustAxumAdapter; @@ -52,13 +52,14 @@ impl FrameworkAdapter for RustAxumAdapter { bind_rust_path_params(&formals, &path) }) .unwrap_or_default(); + let middleware = collect_rust_middleware(ast, file_bytes); Some(FrameworkBinding { adapter: ADAPTER_NAME.to_owned(), kind: EntryKind::HttpRoute, route: Some(RouteShape { method, path }), request_params, response_writer: None, - middleware: Vec::new(), + middleware, }) } } @@ -123,6 +124,17 @@ mod tests { ); } + #[test] + fn populates_middleware_from_layer_calls() { + let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().route(\"/u/{id}\", get(show)).layer(AuthLayer) }\nfn show(id: String) -> String { id }\n"; + let tree = parse(src); + let binding = RustAxumAdapter + .detect(&summary("show"), tree.root_node(), src) + .expect("binding"); + assert_eq!(binding.middleware.len(), 1); + assert_eq!(binding.middleware[0].name, "AuthLayer"); + } + #[test] fn skips_when_route_does_not_reference_function() { let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().route(\"/u\", get(show)) }\nfn helper() {}\n"; diff --git a/src/dynamic/framework/adapters/rust_rocket.rs b/src/dynamic/framework/adapters/rust_rocket.rs index a2ecd43d..4742837c 100644 --- a/src/dynamic/framework/adapters/rust_rocket.rs +++ b/src/dynamic/framework/adapters/rust_rocket.rs @@ -23,8 +23,8 @@ use crate::symbol::Lang; use tree_sitter::Node; use super::rust_routes::{ - bind_rust_path_params, find_method_attribute, find_rust_function, rust_formal_names, - source_imports_rocket, + bind_rust_path_params, collect_rust_middleware, find_method_attribute, find_rust_function, + rust_formal_names, source_imports_rocket, }; pub struct RustRocketAdapter; @@ -53,13 +53,14 @@ impl FrameworkAdapter for RustRocketAdapter { let (method, path) = find_method_attribute(func, file_bytes)?; let formals = rust_formal_names(func, file_bytes); let request_params = bind_rust_path_params(&formals, &path); + let middleware = collect_rust_middleware(ast, file_bytes); Some(FrameworkBinding { adapter: ADAPTER_NAME.to_owned(), kind: EntryKind::HttpRoute, route: Some(RouteShape { method, path }), request_params, response_writer: None, - middleware: Vec::new(), + middleware, }) } } @@ -115,6 +116,18 @@ mod tests { assert_eq!(binding.route.unwrap().method, HttpMethod::POST); } + #[test] + fn populates_middleware_from_attach_fairing() { + let src: &[u8] = b"use rocket::get;\n#[get(\"/u\")]\nfn show() -> &'static str { \"ok\" }\n\ + #[launch]\nfn rocket() -> _ { rocket::build().attach(CsrfLayer).mount(\"/\", routes![show]) }\n"; + let tree = parse(src); + let binding = RustRocketAdapter + .detect(&summary("show"), tree.root_node(), src) + .expect("binding"); + assert_eq!(binding.middleware.len(), 1); + assert_eq!(binding.middleware[0].name, "CsrfLayer"); + } + #[test] fn skips_when_rocket_not_imported() { let src: &[u8] = b"#[get(\"/u\")]\nfn show() {}\n"; diff --git a/src/dynamic/framework/adapters/rust_routes.rs b/src/dynamic/framework/adapters/rust_routes.rs index d53a933c..b4c0cf7c 100644 --- a/src/dynamic/framework/adapters/rust_routes.rs +++ b/src/dynamic/framework/adapters/rust_routes.rs @@ -13,7 +13,9 @@ //! paradigm; the warp adapter binds formals positionally rather //! than by name. -use crate::dynamic::framework::{HttpMethod, ParamBinding, ParamSource}; +use crate::dynamic::framework::auth_markers; +use crate::dynamic::framework::{HttpMethod, MiddlewareShape, ParamBinding, ParamSource}; +use crate::symbol::Lang; use tree_sitter::Node; /// True when `bytes` carries any of the well-known axum markers. @@ -287,6 +289,153 @@ pub fn verb_from_ident(ident: &str) -> Option { } } +/// Walk every method-chain call in the file whose field name is one +/// of the known middleware-attach verbs and collect argument +/// expressions whose names match a known Rust middleware marker (see +/// [`crate::dynamic::framework::auth_markers::is_protective`]). +/// +/// Per-framework attach verbs: +/// - axum: `.layer(...)`, `.route_layer(...)` +/// - actix: `.wrap(...)`, `.wrap_fn(...)` +/// - rocket: `.attach(...)` (fairings) +/// - warp: `.and(filter)` filter composition +/// +/// Argument rendering: +/// - bare identifier (`.layer(AuthLayer)`) → `"AuthLayer"` +/// - scoped identifier (`.wrap(middleware::Logger::default())`'s +/// receiver path) — the call-form below covers it via callee text +/// - call expression (`.layer(AuthLayer::new())`) → +/// `"AuthLayer::new"` (callee text, args dropped) +/// - turbofish call expression (`.layer(Service::::new())`) → +/// callee stripped of generics +/// +/// De-duplicates within a single file; preserves declaration order. +/// Names the registry does not recognise are dropped silently — the +/// caller can re-walk with a wider predicate if it needs broader +/// inclusion. +pub fn collect_rust_middleware(root: Node<'_>, bytes: &[u8]) -> Vec { + let mut raw: Vec = Vec::new(); + walk_attach_calls(root, bytes, &mut raw); + let mut out: Vec = Vec::new(); + for name in raw { + if auth_markers::is_protective(Lang::Rust, &name) + && !out.iter().any(|m| m.name == name) + { + out.push(MiddlewareShape { name }); + } + } + out +} + +fn walk_attach_calls(node: Node<'_>, bytes: &[u8], out: &mut Vec) { + if node.kind() == "call_expression" { + try_collect_attach_call(node, bytes, out); + } + let mut cur = node.walk(); + for child in node.children(&mut cur) { + walk_attach_calls(child, bytes, out); + } +} + +fn try_collect_attach_call(call: Node<'_>, bytes: &[u8], out: &mut Vec) { + let Some(callee) = call.child_by_field_name("function") else { + return; + }; + if callee.kind() != "field_expression" { + return; + } + let Some(field) = callee.child_by_field_name("field") else { + return; + }; + let Ok(verb) = field.utf8_text(bytes) else { + return; + }; + if !matches!( + verb, + "layer" | "route_layer" | "wrap" | "wrap_fn" | "attach" | "and" + ) { + return; + } + let Some(args) = call.child_by_field_name("arguments") else { + return; + }; + let mut cur = args.walk(); + for arg in args.named_children(&mut cur) { + if matches!(arg.kind(), "line_comment" | "block_comment") { + continue; + } + push_middleware_candidates(arg, bytes, out); + } +} + +fn push_middleware_candidates(node: Node<'_>, bytes: &[u8], out: &mut Vec) { + let Some(primary) = middleware_arg_name(node, bytes) else { + return; + }; + out.push(primary.clone()); + // Also push the leading path segment so a scoped callee like + // `HttpAuthentication::bearer(validator)` matches the marker + // `HttpAuthentication` in the auth-markers table. + if let Some((head, _)) = primary.split_once("::") { + let head = head.trim(); + if !head.is_empty() && head != primary { + out.push(head.to_owned()); + } + } +} + +fn middleware_arg_name(node: Node<'_>, bytes: &[u8]) -> Option { + match node.kind() { + "identifier" | "scoped_identifier" => { + node.utf8_text(bytes).ok().map(|s| s.trim().to_owned()) + } + "call_expression" => { + let callee = node.child_by_field_name("function")?; + let raw = callee.utf8_text(bytes).ok()?.trim().to_owned(); + // Strip turbofish generics: `Service::::new` → `Service::new`. + Some(strip_turbofish(&raw)) + } + "generic_function" => { + let callee = node.child_by_field_name("function")?; + callee.utf8_text(bytes).ok().map(|s| s.trim().to_owned()) + } + _ => None, + } +} + +fn strip_turbofish(raw: &str) -> String { + let mut out = String::with_capacity(raw.len()); + let mut depth: i32 = 0; + let bytes = raw.as_bytes(); + let mut i = 0; + while i < bytes.len() { + if depth == 0 && i + 1 < bytes.len() && bytes[i] == b':' && bytes[i + 1] == b':' { + // peek for `<` + let mut j = i + 2; + while j < bytes.len() && bytes[j].is_ascii_whitespace() { + j += 1; + } + if j < bytes.len() && bytes[j] == b'<' { + depth += 1; + i = j + 1; + continue; + } + } + if depth > 0 { + match bytes[i] { + b'<' => depth += 1, + b'>' => depth -= 1, + _ => {} + } + i += 1; + continue; + } + out.push(bytes[i] as char); + i += 1; + } + out +} + /// Read the content of a Rust `string_literal` node, stripping the /// surrounding `"` quotes. Returns `None` if `node` is not a string /// literal. @@ -921,4 +1070,81 @@ mod tests { assert!(matches!(bindings[0].source, ParamSource::Implicit)); assert!(matches!(bindings[1].source, ParamSource::PathSegment(_))); } + + #[test] + fn collect_rust_middleware_picks_axum_layer_bare_ident() { + let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().route(\"/x\", get(show)).layer(AuthLayer) }\nfn show() {}\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert_eq!(mw.len(), 1); + assert_eq!(mw[0].name, "AuthLayer"); + } + + #[test] + fn collect_rust_middleware_picks_axum_route_layer() { + let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().route(\"/x\", get(show)).route_layer(CsrfLayer) }\nfn show() {}\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert_eq!(mw.len(), 1); + assert_eq!(mw[0].name, "CsrfLayer"); + } + + #[test] + fn collect_rust_middleware_picks_actix_wrap_call() { + let src: &[u8] = b"use actix_web::App;\nfn build() -> App<()> { App::new().wrap(HttpAuthentication::bearer(validator)) }\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert!(mw.iter().any(|m| m.name.contains("HttpAuthentication"))); + } + + #[test] + fn collect_rust_middleware_picks_rocket_attach_fairing() { + let src: &[u8] = b"use rocket::Rocket;\nfn build() { rocket::build().attach(CsrfLayer) }\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert_eq!(mw.len(), 1); + assert_eq!(mw[0].name, "CsrfLayer"); + } + + #[test] + fn collect_rust_middleware_picks_warp_and_filter() { + let src: &[u8] = b"use warp::Filter;\nfn build() { let r = warp::path!(\"x\").and(BearerAuth).map(show); }\nfn show() {}\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert_eq!(mw.len(), 1); + assert_eq!(mw[0].name, "BearerAuth"); + } + + #[test] + fn collect_rust_middleware_drops_unknown_names() { + let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().layer(LoggingLayer) }\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert!(mw.is_empty(), "LoggingLayer is not a recognised marker"); + } + + #[test] + fn collect_rust_middleware_dedupes_and_preserves_order() { + let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().layer(AuthLayer).route_layer(CsrfLayer).layer(AuthLayer) }\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + let names: Vec<&str> = mw.iter().map(|m| m.name.as_str()).collect(); + assert_eq!(names, vec!["AuthLayer", "CsrfLayer"]); + } + + #[test] + fn collect_rust_middleware_returns_empty_when_no_attach() { + let src: &[u8] = b"use axum::Router;\nfn build() -> Router { Router::new().route(\"/x\", get(show)) }\nfn show() {}\n"; + let tree = parse(src); + let mw = collect_rust_middleware(tree.root_node(), src); + assert!(mw.is_empty()); + } + + #[test] + fn strip_turbofish_removes_generic_args() { + assert_eq!(strip_turbofish("Foo::::new"), "Foo::new"); + assert_eq!(strip_turbofish("Foo::new"), "Foo::new"); + assert_eq!(strip_turbofish("foo"), "foo"); + assert_eq!(strip_turbofish("Foo::::bar"), "Foo::bar"); + } } diff --git a/src/dynamic/framework/adapters/rust_warp.rs b/src/dynamic/framework/adapters/rust_warp.rs index bc3d60bc..2d55a0dd 100644 --- a/src/dynamic/framework/adapters/rust_warp.rs +++ b/src/dynamic/framework/adapters/rust_warp.rs @@ -21,8 +21,8 @@ use crate::symbol::Lang; use tree_sitter::Node; use super::rust_routes::{ - bind_rust_path_params, find_rust_function, find_warp_route, rust_formal_names, - source_imports_warp, + bind_rust_path_params, collect_rust_middleware, find_rust_function, find_warp_route, + rust_formal_names, source_imports_warp, }; pub struct RustWarpAdapter; @@ -54,13 +54,14 @@ impl FrameworkAdapter for RustWarpAdapter { bind_rust_path_params(&formals, &path) }) .unwrap_or_default(); + let middleware = collect_rust_middleware(ast, file_bytes); Some(FrameworkBinding { adapter: ADAPTER_NAME.to_owned(), kind: EntryKind::HttpRoute, route: Some(RouteShape { method, path }), request_params, response_writer: None, - middleware: Vec::new(), + middleware, }) } } @@ -108,6 +109,17 @@ mod tests { assert!(binding.route.unwrap().path.contains("x")); } + #[test] + fn populates_middleware_from_and_filter() { + let src: &[u8] = b"use warp::Filter;\nfn build() { let r = warp::path!(\"x\" / u32).and(BearerAuth).map(show); }\nfn show(id: u32) -> String { String::new() }\n"; + let tree = parse(src); + let binding = RustWarpAdapter + .detect(&summary("show"), tree.root_node(), src) + .expect("binding"); + assert_eq!(binding.middleware.len(), 1); + assert_eq!(binding.middleware[0].name, "BearerAuth"); + } + #[test] fn skips_when_warp_not_imported() { let src: &[u8] = b"fn show() {}\n";