mirror of
https://github.com/0xMassi/webclaw.git
synced 2026-06-14 23:25:12 +02:00
Some checks are pending
CI / Test (push) Waiting to run
CI / Lint (push) Waiting to run
CI / Docs (push) Waiting to run
Three P3 items from the 2026-04-16 audit. Bump to 0.3.17. webclaw-fetch/sitemap.rs: parse_robots_txt used trimmed[..8] slice plus eq_ignore_ascii_case for the directive test. That was fragile: "Sitemap :" (space before colon) fell through silently, inline "# ..." comments leaked into the URL, and a line with no URL at all returned an empty string. Rewritten to split on the first colon, match any-case "sitemap" as the directive name, strip comments, and require `://` in the value. +7 unit tests cover case variants, space-before-colon, comments, empty values, non-URL values, and non-sitemap directives. webclaw-fetch/crawler.rs: is_cancelled uses Ordering::Acquire instead of Relaxed. Behaviourally equivalent on current hardware for single-word atomic loads, but the explicit ordering documents intent for readers + compilers. webclaw-mcp/server.rs: add lazy OnceLock cache for the Firefox FetchClient. Tool calls that repeatedly request the firefox profile without cookies used to build a fresh reqwest pool + TLS stack per call. Chrome (default) already used the long-lived field; Random is per-call by design; cookie-bearing requests still build ad-hoc since the cookie header is part of the client shape. Tests: 85 webclaw-fetch (was 78, +7 new sitemap), 272 webclaw-core, 43 webclaw-llm, 11 CLI — all green. Clippy clean across workspace. Refs: docs/AUDIT-2026-04-16.md P3 section Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
858 lines
30 KiB
Rust
858 lines
30 KiB
Rust
//! Recursive web crawler built on top of [`FetchClient`].
//!
//! Starts from a seed URL, extracts content, discovers links, and follows
//! them breadth-first up to a configurable depth/page limit. Uses a semaphore
//! for bounded concurrency and per-request delays for politeness.
//!
//! Scope control: by default only same-origin links are followed. Enable
//! `allow_subdomains` to include sibling/child subdomains of the seed host,
//! or `allow_external_links` to follow links to any domain.
//!
//! When `use_sitemap` is enabled, the crawler first discovers URLs from the
//! site's sitemaps and seeds the BFS frontier before crawling.
|
||
use std::collections::HashSet;
|
||
use std::path::Path;
|
||
use std::sync::Arc;
|
||
use std::sync::atomic::{AtomicBool, Ordering};
|
||
use std::time::{Duration, Instant};
|
||
|
||
use serde::{Deserialize, Serialize};
|
||
use tokio::sync::Semaphore;
|
||
use tracing::{debug, info, warn};
|
||
use url::Url;
|
||
|
||
use crate::client::{FetchClient, FetchConfig};
|
||
use crate::error::FetchError;
|
||
use crate::sitemap;
|
||
|
||
/// Controls crawl scope, depth, concurrency, and politeness.
///
/// Scope filters (`path_prefix`, include/exclude globs, subdomain/external flags) are applied
/// to discovered links and sitemap entries; the seed URL itself is always fetched.
#[derive(Debug, Clone)]
pub struct CrawlConfig {
    /// Fetch configuration (browser profile, proxy, timeout, etc.)
    pub fetch: FetchConfig,
    /// How deep to follow links. 1 = only immediate links from seed page.
    /// 0 = fetch only the initial frontier (seed, plus sitemap entries if enabled).
    pub max_depth: usize,
    /// Hard cap on total pages fetched (including the seed).
    pub max_pages: usize,
    /// Max concurrent in-flight requests. Should be at least 1.
    pub concurrency: usize,
    /// Politeness delay: each page task sleeps this long after acquiring its
    /// concurrency permit and before issuing the request.
    pub delay: Duration,
    /// Only follow URLs whose path starts with this prefix (e.g. "/docs/").
    pub path_prefix: Option<String>,
    /// Seed BFS frontier from sitemap discovery before crawling.
    pub use_sitemap: bool,
    /// Glob patterns for paths to include. If non-empty, only matching URLs are crawled.
    /// E.g. `["/api/*", "/guides/*"]` -- matched against the URL path.
    pub include_patterns: Vec<String>,
    /// Glob patterns for paths to exclude. Checked after include_patterns.
    /// E.g. `["/changelog/*", "/blog/*"]` -- matching URLs are skipped.
    pub exclude_patterns: Vec<String>,
    /// Follow links on subdomains of the seed domain (e.g. blog.example.com
    /// when crawling example.com). Default: false (same-origin only).
    /// "Same domain" is decided by a heuristic root-domain comparison, not a public-suffix list.
    pub allow_subdomains: bool,
    /// Follow links to entirely different domains. Default: false.
    /// When true, the crawler becomes cross-origin. Use with caution.
    pub allow_external_links: bool,
    /// Optional channel sender for streaming per-page results as they complete.
    /// When set, each `PageResult` is sent on this channel immediately after extraction.
    /// Send failures (e.g. no active receivers) are ignored.
    pub progress_tx: Option<tokio::sync::broadcast::Sender<PageResult>>,
    /// When set to `true`, the crawler breaks out of the main loop early.
    /// Callers (e.g. a Ctrl+C handler) can flip this to request graceful cancellation.
    /// The flag is polled before each BFS level and between collected page results.
    pub cancel_flag: Option<Arc<AtomicBool>>,
}
|
||
|
||
impl Default for CrawlConfig {
|
||
fn default() -> Self {
|
||
Self {
|
||
fetch: FetchConfig::default(),
|
||
max_depth: 1,
|
||
max_pages: 50,
|
||
concurrency: 5,
|
||
delay: Duration::from_millis(100),
|
||
path_prefix: None,
|
||
use_sitemap: false,
|
||
include_patterns: Vec::new(),
|
||
exclude_patterns: Vec::new(),
|
||
allow_subdomains: false,
|
||
allow_external_links: false,
|
||
progress_tx: None,
|
||
cancel_flag: None,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Aggregated results from a crawl run.
///
/// `total == pages.len()` and `ok + errors == total`, where `ok` counts pages that
/// carry an extraction.
#[derive(Debug, Serialize, Deserialize)]
pub struct CrawlResult {
    /// One entry per page processed, in the order results were collected.
    pub pages: Vec<PageResult>,
    /// Number of entries in `pages`.
    pub total: usize,
    /// Pages whose `extraction` is `Some`.
    pub ok: usize,
    /// Pages whose `extraction` is `None` (`total - ok`).
    pub errors: usize,
    /// Wall-clock duration of the whole crawl, in seconds.
    pub elapsed_secs: f64,
    /// URLs visited during this crawl (for resume state).
    #[serde(skip)]
    pub visited: HashSet<String>,
    /// Remaining frontier when crawl was cancelled (for resume state).
    /// Entries are `(normalized_url, depth)` pairs.
    #[serde(skip)]
    pub remaining_frontier: Vec<(String, usize)>,
}
|
||
|
||
/// Outcome of extracting a single page during the crawl.
///
/// Exactly one of `extraction` / `error` is expected to be `Some`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageResult {
    /// Normalized URL that was requested.
    pub url: String,
    /// BFS depth at which the URL was queued (0 = seed / sitemap entries).
    pub depth: usize,
    /// Extracted content on success.
    pub extraction: Option<webclaw_core::ExtractionResult>,
    /// Human-readable failure description on error.
    pub error: Option<String>,
    /// Time measured by the page task. It starts before the task waits for a concurrency
    /// permit, so it includes that wait and the politeness delay. Not serialized.
    #[serde(skip)]
    pub elapsed: Duration,
}
|
||
|
||
/// Serializable crawl state for resume after Ctrl+C cancellation.
///
/// Written by `Crawler::save_state` and read back by `Crawler::load_state`. When passed
/// to `Crawler::crawl`, only `visited` and `frontier` are consumed; the remaining fields
/// are recorded for the caller's bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlState {
    /// URL the original crawl started from.
    pub seed_url: String,
    /// Normalized URLs already visited (order unspecified: collected from a `HashSet`).
    pub visited: Vec<String>,
    /// Pending `(normalized_url, depth)` pairs to crawl next.
    pub frontier: Vec<(String, usize)>,
    /// Pages completed before the state was saved.
    pub completed_pages: usize,
    /// `max_pages` setting of the original crawl.
    pub max_pages: usize,
    /// `max_depth` setting of the original crawl.
    pub max_depth: usize,
}
|
||
|
||
/// Recursive crawler that wraps a shared [`FetchClient`].
///
/// Construct with [`Crawler::new`], then run [`Crawler::crawl`]. The client sits behind an
/// `Arc` so every spawned page task shares one connection pool.
pub struct Crawler {
    /// Shared HTTP client; `Arc::clone`d into each spawned page task.
    client: Arc<FetchClient>,
    /// Settings passed to `Crawler::new`.
    config: CrawlConfig,
    /// `origin_key` of the seed URL; links with another origin are out of scope
    /// unless subdomains or external links are allowed.
    seed_origin: String,
    /// Root domain of the seed URL for subdomain matching (e.g. "example.com").
    seed_root_domain: String,
}
|
||
|
||
impl Crawler {
|
||
/// Build a new crawler from a seed URL and config.
|
||
/// Constructs the underlying `FetchClient` from `config.fetch`.
|
||
pub fn new(seed_url: &str, config: CrawlConfig) -> Result<Self, FetchError> {
|
||
let seed = Url::parse(seed_url).map_err(|_| FetchError::InvalidUrl(seed_url.into()))?;
|
||
let seed_origin = origin_key(&seed);
|
||
let seed_root_domain = root_domain(&seed);
|
||
|
||
// Reject pathological user-supplied glob patterns before they can
|
||
// exercise the recursive `**` handler in glob_match_inner. The
|
||
// matcher is a straight backtracking implementation; a deeply
|
||
// nested `**/**/**/...` pattern against a long path can degrade
|
||
// to exponential time per link checked, per page crawled.
|
||
for pat in config
|
||
.include_patterns
|
||
.iter()
|
||
.chain(config.exclude_patterns.iter())
|
||
{
|
||
validate_glob(pat)?;
|
||
}
|
||
|
||
let client = FetchClient::new(config.fetch.clone())?;
|
||
|
||
Ok(Self {
|
||
client: Arc::new(client),
|
||
config,
|
||
seed_origin,
|
||
seed_root_domain,
|
||
})
|
||
}
|
||
|
||
/// Save current crawl state to a JSON file for later resume.
|
||
pub fn save_state(
|
||
path: &Path,
|
||
seed_url: &str,
|
||
visited: &HashSet<String>,
|
||
frontier: &[(String, usize)],
|
||
completed_pages: usize,
|
||
max_pages: usize,
|
||
max_depth: usize,
|
||
) -> Result<(), String> {
|
||
let state = CrawlState {
|
||
seed_url: seed_url.to_string(),
|
||
visited: visited.iter().cloned().collect(),
|
||
frontier: frontier.to_vec(),
|
||
completed_pages,
|
||
max_pages,
|
||
max_depth,
|
||
};
|
||
let json =
|
||
serde_json::to_string_pretty(&state).map_err(|e| format!("serialize state: {e}"))?;
|
||
std::fs::write(path, json).map_err(|e| format!("write state to {}: {e}", path.display()))
|
||
}
|
||
|
||
/// Load crawl state from a JSON file. Returns `None` if file doesn't exist.
|
||
pub fn load_state(path: &Path) -> Option<CrawlState> {
|
||
let content = std::fs::read_to_string(path).ok()?;
|
||
serde_json::from_str(&content).ok()
|
||
}
|
||
|
||
/// Returns true if the cancel flag has been set.
|
||
///
|
||
/// Uses `Acquire` load to pair with a `Release` store on the cancel
|
||
/// path. `Relaxed` was technically fine in practice (x86/arm64 give
|
||
/// release semantics for free on single-word stores) but `Acquire`
|
||
/// makes the ordering explicit so the compiler and future readers
|
||
/// don't need to reason about the memory model.
|
||
fn is_cancelled(&self) -> bool {
|
||
self.config
|
||
.cancel_flag
|
||
.as_ref()
|
||
.is_some_and(|f| f.load(Ordering::Acquire))
|
||
}
|
||
|
||
/// Crawl starting from `start_url`, returning results for every page visited.
|
||
///
|
||
/// Uses breadth-first traversal: all pages at depth N are fetched (concurrently,
|
||
/// bounded by `config.concurrency`) before moving to depth N+1.
|
||
///
|
||
/// When `config.use_sitemap` is true, sitemap URLs are discovered first and
|
||
/// added to the initial frontier at depth 0 alongside the seed URL.
|
||
///
|
||
/// If `resume_state` is provided, the crawl resumes from the saved state
|
||
/// (pre-populated visited set and frontier) instead of starting fresh.
|
||
pub async fn crawl(&self, start_url: &str, resume_state: Option<CrawlState>) -> CrawlResult {
|
||
let start = Instant::now();
|
||
|
||
let seed = match Url::parse(start_url) {
|
||
Ok(u) => u,
|
||
Err(_) => {
|
||
return CrawlResult {
|
||
pages: vec![PageResult {
|
||
url: start_url.to_string(),
|
||
depth: 0,
|
||
extraction: None,
|
||
error: Some(format!("invalid URL: {start_url}")),
|
||
elapsed: Duration::ZERO,
|
||
}],
|
||
total: 1,
|
||
ok: 0,
|
||
errors: 1,
|
||
elapsed_secs: 0.0,
|
||
visited: HashSet::new(),
|
||
remaining_frontier: Vec::new(),
|
||
};
|
||
}
|
||
};
|
||
|
||
let semaphore = Arc::new(Semaphore::new(self.config.concurrency));
|
||
let mut visited: HashSet<String>;
|
||
let mut pages: Vec<PageResult> = Vec::new();
|
||
let mut frontier: Vec<(String, usize)>;
|
||
|
||
// Resume from saved state or start fresh
|
||
if let Some(state) = resume_state {
|
||
visited = state.visited.into_iter().collect();
|
||
frontier = state.frontier;
|
||
info!(
|
||
visited = visited.len(),
|
||
frontier = frontier.len(),
|
||
"resuming crawl from saved state"
|
||
);
|
||
} else {
|
||
visited = HashSet::new();
|
||
frontier = vec![(normalize(&seed), 0)];
|
||
|
||
// Seed frontier from sitemap if enabled
|
||
if self.config.use_sitemap {
|
||
let base_url = format!("{}://{}", seed.scheme(), seed.host_str().unwrap_or(""));
|
||
match sitemap::discover(&self.client, &base_url).await {
|
||
Ok(entries) => {
|
||
let before = frontier.len();
|
||
for entry in entries {
|
||
if self.qualify_link(&entry.url, &visited).is_some() {
|
||
let parsed = match Url::parse(&entry.url) {
|
||
Ok(u) => u,
|
||
Err(_) => continue,
|
||
};
|
||
let norm = normalize(&parsed);
|
||
frontier.push((norm, 0));
|
||
}
|
||
}
|
||
let added = frontier.len() - before;
|
||
info!(
|
||
sitemap_urls = added,
|
||
"seeded frontier from sitemap discovery"
|
||
);
|
||
}
|
||
Err(e) => {
|
||
warn!(error = %e, "sitemap discovery failed, continuing with seed URL only");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
while !frontier.is_empty() && pages.len() < self.config.max_pages {
|
||
// Check cancel flag before processing each batch
|
||
if self.is_cancelled() {
|
||
info!("crawl cancelled by user");
|
||
break;
|
||
}
|
||
|
||
// Dedup this level's frontier against the visited set and page cap
|
||
let batch: Vec<(String, usize)> = frontier
|
||
.drain(..)
|
||
.filter(|(url, _)| visited.insert(url.clone()))
|
||
.take(self.config.max_pages.saturating_sub(pages.len()))
|
||
.collect();
|
||
|
||
if batch.is_empty() {
|
||
break;
|
||
}
|
||
|
||
// Spawn one task per URL, bounded by semaphore
|
||
let mut handles = Vec::with_capacity(batch.len());
|
||
|
||
for (url, depth) in &batch {
|
||
let permit = Arc::clone(&semaphore);
|
||
let client = Arc::clone(&self.client);
|
||
let url = url.clone();
|
||
let depth = *depth;
|
||
let delay = self.config.delay;
|
||
|
||
handles.push(tokio::spawn(async move {
|
||
// Acquire permit -- blocks if concurrency limit reached.
|
||
// Surface semaphore-closed as a failed PageResult rather
|
||
// than panicking the spawned task and silently dropping
|
||
// it from the batch.
|
||
let page_start = Instant::now();
|
||
let result = match permit.acquire().await {
|
||
Ok(_permit) => {
|
||
tokio::time::sleep(delay).await;
|
||
client.fetch_and_extract(&url).await
|
||
}
|
||
Err(_) => {
|
||
warn!(url = %url, depth, "semaphore closed before acquire");
|
||
return PageResult {
|
||
url,
|
||
depth,
|
||
extraction: None,
|
||
error: Some("semaphore closed before acquire".into()),
|
||
elapsed: page_start.elapsed(),
|
||
};
|
||
}
|
||
};
|
||
let elapsed = page_start.elapsed();
|
||
|
||
match result {
|
||
Ok(extraction) => {
|
||
debug!(
|
||
url = %url, depth,
|
||
elapsed_ms = %elapsed.as_millis(),
|
||
"page extracted"
|
||
);
|
||
PageResult {
|
||
url,
|
||
depth,
|
||
extraction: Some(extraction),
|
||
error: None,
|
||
elapsed,
|
||
}
|
||
}
|
||
Err(e) => {
|
||
warn!(url = %url, depth, error = %e, "page failed");
|
||
PageResult {
|
||
url,
|
||
depth,
|
||
extraction: None,
|
||
error: Some(e.to_string()),
|
||
elapsed,
|
||
}
|
||
}
|
||
}
|
||
}));
|
||
}
|
||
|
||
// Collect results and harvest links for the next depth level
|
||
let mut next_frontier: Vec<(String, usize)> = Vec::new();
|
||
|
||
for handle in handles {
|
||
let page = match handle.await {
|
||
Ok(page) => page,
|
||
Err(e) => {
|
||
warn!(error = %e, "crawl task panicked");
|
||
continue;
|
||
}
|
||
};
|
||
let depth = page.depth;
|
||
|
||
if depth < self.config.max_depth
|
||
&& let Some(ref extraction) = page.extraction
|
||
{
|
||
for link in &extraction.content.links {
|
||
if let Some(candidate) = self.qualify_link(&link.href, &visited) {
|
||
next_frontier.push((candidate, depth + 1));
|
||
}
|
||
}
|
||
}
|
||
|
||
// Stream progress if a channel is configured
|
||
if let Some(tx) = &self.config.progress_tx {
|
||
let _ = tx.send(page.clone());
|
||
}
|
||
|
||
pages.push(page);
|
||
|
||
if pages.len() >= self.config.max_pages {
|
||
break;
|
||
}
|
||
|
||
// Check cancel flag between page results
|
||
if self.is_cancelled() {
|
||
info!("crawl cancelled by user (mid-batch)");
|
||
break;
|
||
}
|
||
}
|
||
|
||
// Cap frontier size independently of max_pages. Pages like
|
||
// search-result listings or tag clouds can emit thousands of
|
||
// links per page; without this a single dense page could push
|
||
// the frontier into the tens of thousands of entries and keep
|
||
// String allocations alive even after max_pages halts crawling.
|
||
// Trim aggressively once we exceed 10× max_pages, keeping the
|
||
// most recently discovered entries which are still on-topic
|
||
// (breadth-first = siblings of the last page we saw).
|
||
let frontier_cap = self.config.max_pages.saturating_mul(10).max(100);
|
||
if next_frontier.len() > frontier_cap {
|
||
let keep = self.config.max_pages.saturating_mul(5).max(50);
|
||
warn!(
|
||
frontier = next_frontier.len(),
|
||
cap = frontier_cap,
|
||
trimmed_to = keep,
|
||
"frontier exceeded cap, truncating"
|
||
);
|
||
next_frontier.truncate(keep);
|
||
}
|
||
|
||
frontier = next_frontier;
|
||
}
|
||
|
||
let total_elapsed = start.elapsed();
|
||
let ok_count = pages.iter().filter(|p| p.extraction.is_some()).count();
|
||
let err_count = pages.len() - ok_count;
|
||
info!(
|
||
total = pages.len(),
|
||
ok = ok_count,
|
||
errors = err_count,
|
||
elapsed_ms = %total_elapsed.as_millis(),
|
||
"crawl complete"
|
||
);
|
||
|
||
CrawlResult {
|
||
total: pages.len(),
|
||
ok: ok_count,
|
||
errors: err_count,
|
||
elapsed_secs: total_elapsed.as_secs_f64(),
|
||
remaining_frontier: frontier,
|
||
visited,
|
||
pages,
|
||
}
|
||
}
|
||
|
||
/// Check if a discovered link should be added to the frontier.
|
||
/// Returns `Some(normalized_url)` if it passes all filters, `None` otherwise.
|
||
fn qualify_link(&self, href: &str, visited: &HashSet<String>) -> Option<String> {
|
||
let parsed = Url::parse(href).ok()?;
|
||
|
||
// Only http(s) schemes
|
||
match parsed.scheme() {
|
||
"http" | "https" => {}
|
||
_ => return None,
|
||
}
|
||
|
||
// Scope check: same-origin, subdomain, or external
|
||
if !self.config.allow_external_links {
|
||
let link_origin = origin_key(&parsed);
|
||
if link_origin != self.seed_origin {
|
||
// Not same-origin. Check if subdomain crawling is allowed.
|
||
if self.config.allow_subdomains {
|
||
let link_root = root_domain(&parsed);
|
||
if link_root != self.seed_root_domain {
|
||
return None;
|
||
}
|
||
} else {
|
||
return None;
|
||
}
|
||
}
|
||
}
|
||
|
||
// Path prefix filter
|
||
if let Some(ref prefix) = self.config.path_prefix
|
||
&& !parsed.path().starts_with(prefix.as_str())
|
||
{
|
||
return None;
|
||
}
|
||
|
||
// Include patterns: if any are set, path must match at least one
|
||
let path = parsed.path();
|
||
if !self.config.include_patterns.is_empty()
|
||
&& !self
|
||
.config
|
||
.include_patterns
|
||
.iter()
|
||
.any(|pat| glob_match(pat, path))
|
||
{
|
||
return None;
|
||
}
|
||
|
||
// Exclude patterns: if path matches any, skip
|
||
if self
|
||
.config
|
||
.exclude_patterns
|
||
.iter()
|
||
.any(|pat| glob_match(pat, path))
|
||
{
|
||
return None;
|
||
}
|
||
|
||
// Skip common non-page file extensions
|
||
const SKIP_EXTENSIONS: &[&str] = &[
|
||
".pdf", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".webp", ".ico", ".css", ".js",
|
||
".zip", ".tar", ".gz", ".xml", ".rss", ".mp3", ".mp4", ".avi", ".mov", ".woff",
|
||
".woff2", ".ttf", ".eot",
|
||
];
|
||
if SKIP_EXTENSIONS.iter().any(|ext| path.ends_with(ext)) {
|
||
return None;
|
||
}
|
||
|
||
let normalized = normalize(&parsed);
|
||
|
||
if visited.contains(&normalized) {
|
||
return None;
|
||
}
|
||
|
||
Some(normalized)
|
||
}
|
||
}
|
||
|
||
/// Canonical origin string for comparing same-origin: "scheme://host[:port]".
|
||
fn origin_key(url: &Url) -> String {
|
||
let port_suffix = match url.port() {
|
||
Some(p) => format!(":{p}"),
|
||
None => String::new(),
|
||
};
|
||
let host = url.host_str().unwrap_or("");
|
||
let host = host.strip_prefix("www.").unwrap_or(host);
|
||
format!("{}://{}{}", url.scheme(), host, port_suffix)
|
||
}
|
||
|
||
/// Extract the root domain from a URL for subdomain comparison.
|
||
/// "blog.docs.example.com" -> "example.com", "example.co.uk" -> "example.co.uk" (best-effort).
|
||
///
|
||
/// Uses a simple heuristic: take the last two labels, or three if the second-to-last
|
||
/// is short (<=3 chars, likely a country SLD like "co.uk", "com.au").
|
||
fn root_domain(url: &Url) -> String {
|
||
let host = url.host_str().unwrap_or("");
|
||
let host = host.strip_prefix("www.").unwrap_or(host);
|
||
let labels: Vec<&str> = host.split('.').collect();
|
||
|
||
if labels.len() <= 2 {
|
||
return host.to_ascii_lowercase();
|
||
}
|
||
|
||
// Heuristic for two-part TLDs (co.uk, com.au, org.br, etc.)
|
||
let sld = labels[labels.len() - 2];
|
||
if labels.len() >= 3 && sld.len() <= 3 {
|
||
labels[labels.len() - 3..].join(".").to_ascii_lowercase()
|
||
} else {
|
||
labels[labels.len() - 2..].join(".").to_ascii_lowercase()
|
||
}
|
||
}
|
||
|
||
/// Normalize a URL for dedup: strip fragment, remove trailing slash (except root "/"),
|
||
/// lowercase scheme + host. Preserves query params and path case.
|
||
fn normalize(url: &Url) -> String {
|
||
let scheme = url.scheme();
|
||
let host = url.host_str().unwrap_or("").to_ascii_lowercase();
|
||
let port_suffix = match url.port() {
|
||
Some(p) => format!(":{p}"),
|
||
None => String::new(),
|
||
};
|
||
|
||
let mut path = url.path().to_string();
|
||
if path.len() > 1 && path.ends_with('/') {
|
||
path.pop();
|
||
}
|
||
|
||
let query = match url.query() {
|
||
Some(q) => format!("?{q}"),
|
||
None => String::new(),
|
||
};
|
||
|
||
// Fragment intentionally omitted
|
||
format!("{scheme}://{host}{port_suffix}{path}{query}")
|
||
}
|
||
|
||
/// Maximum number of `**` wildcards allowed in a single user glob. Each
|
||
/// additional `**` multiplies the backtracking fan-out of `glob_match_inner`
|
||
/// against adversarial paths; 4 is a practical ceiling for legitimate
|
||
/// nested include/exclude patterns and still keeps the matcher linear-ish.
|
||
const MAX_GLOB_DOUBLESTAR: usize = 4;
|
||
|
||
/// Maximum glob pattern length. Keeps a single pattern from taking
|
||
/// megabytes of RAM if someone copy-pastes garbage into --include.
|
||
const MAX_GLOB_LEN: usize = 1024;
|
||
|
||
/// Validate a user-supplied glob pattern before it hits the matcher.
|
||
/// Rejects patterns that would drive `glob_match_inner` into pathological
|
||
/// backtracking (too many `**`, excessive length).
|
||
fn validate_glob(pat: &str) -> Result<(), FetchError> {
|
||
if pat.len() > MAX_GLOB_LEN {
|
||
return Err(FetchError::Build(format!(
|
||
"glob pattern exceeds {MAX_GLOB_LEN} chars ({} given)",
|
||
pat.len()
|
||
)));
|
||
}
|
||
// Count non-overlapping occurrences of `**`.
|
||
let bytes = pat.as_bytes();
|
||
let mut count = 0usize;
|
||
let mut i = 0;
|
||
while i + 1 < bytes.len() {
|
||
if bytes[i] == b'*' && bytes[i + 1] == b'*' {
|
||
count += 1;
|
||
// Skip run of consecutive `*` so `***` counts as one.
|
||
while i < bytes.len() && bytes[i] == b'*' {
|
||
i += 1;
|
||
}
|
||
} else {
|
||
i += 1;
|
||
}
|
||
}
|
||
if count > MAX_GLOB_DOUBLESTAR {
|
||
return Err(FetchError::Build(format!(
|
||
"glob pattern has {count} `**` wildcards (max {MAX_GLOB_DOUBLESTAR})"
|
||
)));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// Simple glob matching for URL paths. Supports:
/// - `*` matches any characters within a single path segment (no `/`)
/// - `**` matches any characters including `/` (any number of segments). A `/` directly
///   after `**` is consumed with it, so `/api/**/x` also matches `/api/x`.
/// - `?` matches exactly one byte (including `/`)
/// - Literal characters match exactly (case-sensitive, byte-wise)
///
/// Examples:
/// - `/api/*` matches `/api/users` but not `/api/users/123`
/// - `/api/**` matches `/api/users`, `/api/users/123`, `/api/a/b/c`
/// - `/docs/*/intro` matches `/docs/v2/intro`
fn glob_match(pattern: &str, path: &str) -> bool {
    glob_match_inner(pattern.as_bytes(), path.as_bytes())
}

/// Memo of already-decided `(pattern offset, text offset)` sub-matches.
type GlobMemo = std::collections::HashMap<(usize, usize), bool>;

/// Byte-level matcher behind [`glob_match`].
///
/// At each `**`, the rest of the pattern is tried against every remaining suffix of
/// `text`. Those sub-matches are pure functions of their start offsets, so they are
/// memoized. Without the memo, k `**` groups against a path of length n cost O(n^k);
/// with it, each `(pattern offset, text offset)` pair is decided only once.
fn glob_match_inner(pat: &[u8], text: &[u8]) -> bool {
    let mut memo = GlobMemo::new();
    glob_match_from(pat, 0, text, 0, &mut memo)
}

/// Memoized entry point: does `pat[pi..]` match `text[ti..]`?
fn glob_match_from(pat: &[u8], pi: usize, text: &[u8], ti: usize, memo: &mut GlobMemo) -> bool {
    if let Some(&known) = memo.get(&(pi, ti)) {
        return known;
    }
    let matched = glob_match_scan(pat, pi, text, ti, memo);
    memo.insert((pi, ti), matched);
    matched
}

/// Greedy scan of `pat[start_pi..]` against `text[start_ti..]` with single-`*`
/// backtracking. It recurses (via the memo) only at `**`.
fn glob_match_scan(
    pat: &[u8],
    start_pi: usize,
    text: &[u8],
    start_ti: usize,
    memo: &mut GlobMemo,
) -> bool {
    let mut pi = start_pi;
    let mut ti = start_ti;
    // Position of the most recent single `*` in the pattern and the text offset it
    // currently absorbs up to; `usize::MAX` means "no star seen yet".
    let mut star_pi = usize::MAX;
    let mut star_ti = 0;

    while ti < text.len() {
        if pi < pat.len() && pat[pi] == b'*' && pi + 1 < pat.len() && pat[pi + 1] == b'*' {
            // `**` -- match everything including slashes.
            // Skip all consecutive `*`.
            while pi < pat.len() && pat[pi] == b'*' {
                pi += 1;
            }
            // Skip trailing `/` after `**`.
            if pi < pat.len() && pat[pi] == b'/' {
                pi += 1;
            }
            if pi >= pat.len() {
                return true; // `**` at end matches everything
            }
            // Try matching the rest of pattern against every suffix of text.
            return (ti..=text.len()).any(|start| glob_match_from(pat, pi, text, start, memo));
        } else if pi < pat.len() && pat[pi] == b'*' {
            // `*` -- match any chars except `/`.
            star_pi = pi;
            star_ti = ti;
            pi += 1;
        } else if pi < pat.len() && (pat[pi] == text[ti] || pat[pi] == b'?') {
            pi += 1;
            ti += 1;
        } else if star_pi != usize::MAX {
            // Backtrack: `*` absorbs one more char (but not `/`).
            if text[star_ti] == b'/' {
                return false;
            }
            star_ti += 1;
            ti = star_ti;
            pi = star_pi + 1;
        } else {
            return false;
        }
    }

    // Consume trailing `*` or `**` in pattern.
    while pi < pat.len() && pat[pi] == b'*' {
        pi += 1;
    }

    pi >= pat.len()
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a URL literal used by a test; every literal here is known-valid.
    fn url(s: &str) -> Url {
        Url::parse(s).expect("test URL literal should parse")
    }

    // -- normalize tests --

    #[test]
    fn normalize_strips_fragment() {
        assert_eq!(normalize(&url("https://example.com/page#section")), "https://example.com/page");
    }

    #[test]
    fn normalize_strips_trailing_slash() {
        assert_eq!(normalize(&url("https://example.com/docs/")), "https://example.com/docs");
    }

    #[test]
    fn normalize_keeps_root_slash() {
        assert_eq!(normalize(&url("https://example.com/")), "https://example.com/");
    }

    #[test]
    fn normalize_preserves_query() {
        let normalized = normalize(&url("https://example.com/search?q=rust&page=2"));
        assert_eq!(normalized, "https://example.com/search?q=rust&page=2");
    }

    #[test]
    fn normalize_lowercases_host() {
        assert_eq!(normalize(&url("https://Example.COM/Path")), "https://example.com/Path");
    }

    // -- origin_key tests --

    #[test]
    fn origin_includes_explicit_port() {
        assert_eq!(origin_key(&url("https://example.com:8443/foo")), "https://example.com:8443");
    }

    #[test]
    fn origin_omits_default_port() {
        assert_eq!(origin_key(&url("https://example.com/foo")), "https://example.com");
    }

    #[test]
    fn different_schemes_are_different_origins() {
        let plain = origin_key(&url("http://example.com/"));
        let secure = origin_key(&url("https://example.com/"));
        assert_ne!(plain, secure);
    }

    // -- root_domain tests --

    #[test]
    fn root_domain_simple() {
        assert_eq!(root_domain(&url("https://example.com/page")), "example.com");
    }

    #[test]
    fn root_domain_subdomain() {
        assert_eq!(root_domain(&url("https://blog.example.com/page")), "example.com");
    }

    #[test]
    fn root_domain_deep_subdomain() {
        assert_eq!(root_domain(&url("https://a.b.c.example.com/")), "example.com");
    }

    #[test]
    fn root_domain_country_tld() {
        assert_eq!(root_domain(&url("https://blog.example.co.uk/")), "example.co.uk");
    }

    #[test]
    fn root_domain_strips_www() {
        assert_eq!(root_domain(&url("https://www.example.com/")), "example.com");
    }

    // -- validate_glob tests --

    #[test]
    fn validate_glob_accepts_reasonable_patterns() {
        for pat in ["/api/*", "/api/**", "/docs/**/page-*.html", "/a/**/b/**/c/**/d/**"] {
            assert!(validate_glob(pat).is_ok(), "pattern should be accepted: {pat}");
        }
    }

    #[test]
    fn validate_glob_rejects_too_many_doublestars() {
        // 5 `**` exceeds MAX_GLOB_DOUBLESTAR = 4.
        let err = validate_glob("/a/**/b/**/c/**/d/**/e/**").unwrap_err();
        assert!(matches!(err, FetchError::Build(ref m) if m.contains("`**` wildcards")));
    }

    #[test]
    fn validate_glob_treats_triple_star_as_one() {
        // `***` is still one run, should not count as 2.
        assert!(validate_glob("/a/***/b/***/c/***/d/***").is_ok());
    }

    #[test]
    fn validate_glob_rejects_oversized_pattern() {
        let err = validate_glob(&"x".repeat(2048)).unwrap_err();
        assert!(matches!(err, FetchError::Build(ref m) if m.contains("exceeds")));
    }

    // -- glob_match tests --

    #[test]
    fn glob_star_matches_single_segment() {
        assert!(glob_match("/api/*", "/api/users"));
        assert!(glob_match("/api/*", "/api/products"));
        assert!(!glob_match("/api/*", "/api/users/123"));
    }

    #[test]
    fn glob_doublestar_matches_multiple_segments() {
        for path in ["/api/users", "/api/users/123", "/api/a/b/c/d"] {
            assert!(glob_match("/api/**", path), "should match: {path}");
        }
        assert!(!glob_match("/api/**", "/docs/intro"));
    }

    #[test]
    fn glob_exact_match() {
        assert!(glob_match("/about", "/about"));
        assert!(!glob_match("/about", "/about/team"));
    }

    #[test]
    fn glob_middle_wildcard() {
        assert!(glob_match("/docs/*/intro", "/docs/v2/intro"));
        assert!(!glob_match("/docs/*/intro", "/docs/v2/v3/intro"));
    }

    #[test]
    fn glob_no_pattern_matches_nothing() {
        // Empty pattern only matches empty string
        assert!(glob_match("", ""));
        assert!(!glob_match("", "/foo"));
    }

    #[test]
    fn glob_trailing_star() {
        assert!(glob_match("/blog*", "/blog"));
        assert!(glob_match("/blog*", "/blog-post"));
        assert!(!glob_match("/blog*", "/blog/post")); // * doesn't cross /
    }
}
|