fix: close validated init and multi-graph gaps

This commit is contained in:
Ragnor Comerford 2026-05-28 15:41:04 +02:00
parent 37ec7373f5
commit eab99e6f48
No known key found for this signature in database
45 changed files with 1058 additions and 454 deletions

View file

@ -2071,11 +2071,16 @@ fn graphs_subcommand_help_lists_list_only() {
/// message — the CLI only operates against remote multi-graph servers.
#[test]
fn graphs_list_against_local_uri_errors_with_remote_only_message() {
let output = output_failure(cli().arg("graphs").arg("list").arg("--uri").arg("/tmp/local"));
let output = output_failure(
cli()
.arg("graphs")
.arg("list")
.arg("--uri")
.arg("/tmp/local"),
);
let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
assert!(
stderr.contains("remote multi-graph server URL"),
"expected 'remote multi-graph server URL' rejection in stderr; got:\n{stderr}"
);
}

View file

@ -37,6 +37,17 @@ rules:
target_branch_scope: protected
"#;
/// Server-scoped policy fixture for the `graphs list` success-path test.
///
/// Declares an `admins` group containing the single actor `act-admin` and one
/// rule, `admins-can-list-graphs`, that allows that group the `graph_list`
/// action. The text is written verbatim to a policy file by the test, so it
/// must remain valid policy YAML.
const GRAPH_LIST_SERVER_POLICY_YAML: &str = r#"
version: 1
groups:
admins: [act-admin]
rules:
- id: admins-can-list-graphs
allow:
actors: { group: admins }
actions: [graph_list]
"#;
/// Renders `value` as a YAML single-quoted scalar.
///
/// The text is wrapped in `'…'` and every embedded single quote is doubled,
/// which is the only escape a single-quoted YAML scalar recognises.
fn yaml_string(value: &str) -> String {
    // Two extra bytes for the surrounding quotes; doubled quotes may grow it.
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            // Emit the escape quote first, then the quote itself below.
            quoted.push('\'');
        }
        quoted.push(ch);
    }
    quoted.push('\'');
    quoted
}
@ -918,13 +929,24 @@ fn graphs_list_against_multi_graph_server() {
.unwrap();
});
fs::write(
cfg_dir.path().join("server-policy.yaml"),
GRAPH_LIST_SERVER_POLICY_YAML,
)
.unwrap();
// Server config with `graphs:` map and no `server.graph` selector
// — multi mode (rule 4 of the inference matrix).
// — multi mode (rule 4 of the inference matrix). `GET /graphs` is a
// server-scoped action, so the success path needs an explicit server
// policy and bearer token.
let server_config_path = cfg_dir.path().join("omnigraph.yaml");
fs::write(
&server_config_path,
format!(
"\
server:
policy:
file: ./server-policy.yaml
graphs:
alpha:
uri: {}
@ -934,7 +956,13 @@ graphs:
)
.unwrap();
let server = spawn_server_with_config(&server_config_path);
let server = spawn_server_with_config_env(
&server_config_path,
&[(
"OMNIGRAPH_SERVER_BEARER_TOKENS_JSON",
r#"{"act-admin":"admin-token"}"#,
)],
);
// Client config — the CLI's `--target dev` resolves to `server.base_url`.
let client_config_path = cfg_dir.path().join("client.yaml");
@ -945,13 +973,21 @@ graphs:
graphs:
dev:
uri: {}
bearer_token_env: GRAPH_LIST_TOKEN
cli:
graph: dev
auth:
env_file: ./.env.omni
",
yaml_string(&server.base_url)
),
)
.unwrap();
fs::write(
cfg_dir.path().join(".env.omni"),
"GRAPH_LIST_TOKEN=admin-token\n",
)
.unwrap();
// `graphs list` lists `alpha`.
let payload = parse_stdout_json(&output_success(

View file

@ -150,9 +150,7 @@ impl SchemaMigrationStep {
/// non-`UnsupportedChange` variant).
pub fn diagnostic(&self) -> Option<&'static crate::lint::DiagnosticCode> {
match self {
Self::UnsupportedChange {
code: Some(c), ..
} => crate::lint::lookup(c),
Self::UnsupportedChange { code: Some(c), .. } => crate::lint::lookup(c),
_ => None,
}
}
@ -1037,10 +1035,7 @@ node Person {
.unwrap();
let plan = plan_schema_migration(&accepted, &desired).unwrap();
assert!(
plan.supported,
"drop-type plan must be supported: {plan:?}"
);
assert!(plan.supported, "drop-type plan must be supported: {plan:?}");
assert!(
plan.steps.iter().any(|step| matches!(
step,
@ -1182,8 +1177,7 @@ node Person @description("new") {
for step in steps {
let json = serde_json::to_string(&step).expect("serialize");
let round_trip: SchemaMigrationStep =
serde_json::from_str(&json).expect("deserialize");
let round_trip: SchemaMigrationStep = serde_json::from_str(&json).expect("deserialize");
assert_eq!(step, round_trip, "round-trip mismatch on {json}");
}
}

View file

@ -271,9 +271,7 @@ fn lower_clauses(
.traversals
.iter()
.find(|rt| {
rt.src == traversal.src
&& rt.dst == traversal.dst
&& rt.edge_type == edge.name
rt.src == traversal.src && rt.dst == traversal.dst && rt.edge_type == edge.name
})
.map(|rt| rt.direction)
.unwrap_or(Direction::Out);

View file

@ -205,12 +205,8 @@ insert Knows { from: $name, to: $friend }
let ir = lower_mutation_query(&qf.queries[0]).unwrap();
assert_eq!(ir.ops.len(), 2);
assert!(
matches!(&ir.ops[0], MutationOpIR::Insert { type_name, .. } if type_name == "Person")
);
assert!(
matches!(&ir.ops[1], MutationOpIR::Insert { type_name, .. } if type_name == "Knows")
);
assert!(matches!(&ir.ops[0], MutationOpIR::Insert { type_name, .. } if type_name == "Person"));
assert!(matches!(&ir.ops[1], MutationOpIR::Insert { type_name, .. } if type_name == "Knows"));
}
/// Destination binding is deferred: NodeScan + Expand + Filter (no cross-join).

View file

@ -18,9 +18,9 @@ pub use catalog::schema_ir::{
pub use catalog::schema_plan::{
DropMode, SchemaMigrationPlan, SchemaMigrationStep, SchemaTypeKind, plan_schema_migration,
};
pub use lint::{DiagnosticCode, Family, SafetyTier, Severity};
pub use ir::ParamMap;
pub use ir::lower::{lower_mutation_query, lower_query};
pub use lint::{DiagnosticCode, Family, SafetyTier, Severity};
pub use query::ast::Literal;
pub use query::lint::{
QueryLintFinding, QueryLintOutput, QueryLintQueryKind, QueryLintQueryResult,

View file

@ -116,7 +116,13 @@ pub const ALL_CODES: &[DiagnosticCode] = &[
];
/// Codes actually emitted by the planner in v0 (i.e. not reserved).
pub const EMITTED_IN_V0: &[&str] = &["OG-DS-102", "OG-DS-103", "OG-DS-104", "OG-MF-103", "OG-MF-106"];
pub const EMITTED_IN_V0: &[&str] = &[
"OG-DS-102",
"OG-DS-103",
"OG-DS-104",
"OG-MF-103",
"OG-MF-106",
];
/// Look up a code by its string identifier.
pub fn lookup(code: &str) -> Option<&'static DiagnosticCode> {

View file

@ -24,5 +24,5 @@
pub mod codes;
pub mod diagnostic;
pub use codes::{lookup, DiagnosticCode, ALL_CODES};
pub use codes::{ALL_CODES, DiagnosticCode, lookup};
pub use diagnostic::{Family, SafetyTier, Severity};

View file

@ -137,12 +137,11 @@ fn parse_query_decl(pair: pest::iterators::Pair<Rule>) -> Result<QueryDecl> {
Rule::mutation_body => {
for mutation_pair in body.into_inner() {
if let Rule::mutation_stmt = mutation_pair.as_rule() {
let stmt =
mutation_pair.into_inner().next().ok_or_else(|| {
NanoError::Parse(
"mutation statement cannot be empty".to_string(),
)
})?;
let stmt = mutation_pair.into_inner().next().ok_or_else(|| {
NanoError::Parse(
"mutation statement cannot be empty".to_string(),
)
})?;
mutations.push(parse_mutation_stmt(stmt)?);
}
}

View file

@ -271,9 +271,9 @@ age: I32?
match &schema.declarations[0] {
SchemaDecl::Node(n) => {
assert!(
n.constraints.iter().any(
|c| matches!(c, Constraint::Range { property, .. } if property == "age")
)
n.constraints
.iter()
.any(|c| matches!(c, Constraint::Range { property, .. } if property == "age"))
);
}
_ => panic!("expected Node"),

View file

@ -358,8 +358,7 @@ impl PolicyConfig {
);
}
if server_scoped
&& (rule.allow.branch_scope.is_some()
|| rule.allow.target_branch_scope.is_some())
&& (rule.allow.branch_scope.is_some() || rule.allow.target_branch_scope.is_some())
{
bail!(
"policy rule '{}' uses branch_scope/target_branch_scope with a \
@ -985,8 +984,8 @@ impl PolicyChecker for PolicyEngine {
#[cfg(test)]
mod tests {
use super::{
PolicyAction, PolicyCompiler, PolicyConfig, PolicyEngine, PolicyExpectation,
PolicyRequest, PolicyTestCase, PolicyTestConfig,
PolicyAction, PolicyCompiler, PolicyConfig, PolicyEngine, PolicyExpectation, PolicyRequest,
PolicyTestCase, PolicyTestConfig,
};
#[test]

View file

@ -235,7 +235,9 @@ pub struct CommitListOutput {
pub struct ReadRequest {
/// GQ query source. May declare one or more named queries; pick one with
/// `query_name` if there is more than one.
#[schema(example = "query get_person($name: String) {\n match {\n $p: Person { name: $name }\n }\n return { $p.name, $p.age }\n}")]
#[schema(
example = "query get_person($name: String) {\n match {\n $p: Person { name: $name }\n }\n return { $p.name, $p.age }\n}"
)]
pub query_source: String,
/// Name of the query to run when `query_source` declares multiple. Optional
/// when only one query is declared.
@ -252,7 +254,9 @@ pub struct ReadRequest {
pub struct ChangeRequest {
/// GQ mutation source containing `insert`, `update`, or `delete` statements.
/// May declare multiple named mutations; pick one with `query_name`.
#[schema(example = "query insert_person($name: String, $age: I32) {\n insert Person { name: $name, age: $age }\n}")]
#[schema(
example = "query insert_person($name: String, $age: I32) {\n insert Person { name: $name, age: $age }\n}"
)]
pub query_source: String,
/// Name of the mutation to run when `query_source` declares multiple.
pub query_name: Option<String>,
@ -266,7 +270,9 @@ pub struct ChangeRequest {
pub struct SchemaApplyRequest {
/// Project schema in `.pg` source form. The diff against the current
/// schema produces the migration steps that will be applied.
#[schema(example = "node Person {\n name: String @key\n age: I32?\n}\n\nedge Knows: Person -> Person")]
#[schema(
example = "node Person {\n name: String @key\n age: I32?\n}\n\nedge Knows: Person -> Person"
)]
pub schema_source: String,
/// When true, promote every `DropMode::Soft` step in the plan to
/// `DropMode::Hard`, making the prior column data unreachable
@ -303,7 +309,9 @@ pub struct IngestRequest {
pub mode: Option<LoadMode>,
/// NDJSON payload: one record per line, each shaped
/// `{"type": "<TypeName>", "data": {...}}`.
#[schema(example = "{\"type\": \"Person\", \"data\": {\"name\": \"Alice\", \"age\": 30}}\n{\"type\": \"Person\", \"data\": {\"name\": \"Bob\", \"age\": 25}}")]
#[schema(
example = "{\"type\": \"Person\", \"data\": {\"name\": \"Alice\", \"age\": 30}}\n{\"type\": \"Person\", \"data\": {\"name\": \"Bob\", \"age\": 25}}"
)]
pub data: String,
}
@ -492,4 +500,3 @@ pub struct GraphInfo {
pub struct GraphListResponse {
pub graphs: Vec<GraphInfo>,
}

View file

@ -119,7 +119,10 @@ pub(crate) fn parse_json_secret_payload(payload: &str) -> Result<Vec<(String, St
bail!("bearer-token secret contains a blank actor id");
}
if token.is_empty() {
bail!("bearer-token secret has a blank token for actor '{}'", actor);
bail!(
"bearer-token secret has a blank token for actor '{}'",
actor
);
}
pairs.push((actor, token));
}
@ -151,8 +154,7 @@ pub mod aws {
/// Construct a new source. Resolves AWS credentials + region via the
/// default chain — no explicit configuration needed on EC2/ECS/EKS.
pub async fn new(secret_id: impl Into<String>) -> Result<Self> {
let config =
aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
let client = aws_sdk_secretsmanager::Client::new(&config);
Ok(Self {
client,
@ -200,8 +202,8 @@ pub use aws::SecretsManagerTokenSource;
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use serial_test::serial;
use std::env;
fn clear_env() {
unsafe {
@ -232,7 +234,10 @@ mod tests {
unsafe {
env::remove_var("OMNIGRAPH_SERVER_BEARER_TOKEN");
}
assert_eq!(tokens, vec![("default".to_string(), "some-token".to_string())]);
assert_eq!(
tokens,
vec![("default".to_string(), "some-token".to_string())]
);
}
#[tokio::test]

View file

@ -177,10 +177,7 @@ mod tests {
#[test]
fn rejects_path_separators() {
for bad in ["alpha/beta", "../etc", "..", "alpha\\beta"] {
assert!(
GraphId::try_from(bad).is_err(),
"expected reject: {bad}"
);
assert!(GraphId::try_from(bad).is_err(), "expected reject: {bad}");
}
}

View file

@ -95,10 +95,7 @@ fn validate_tenant_id(value: &str) -> Result<()> {
);
}
if !tenant_id_regex().is_match(value) {
bail!(
"tenant_id '{}' must match ^[a-zA-Z0-9-]{{1,64}}$",
value
);
bail!("tenant_id '{}' must match ^[a-zA-Z0-9-]{{1,64}}$", value);
}
Ok(())
}

View file

@ -45,6 +45,7 @@ pub use config::{
use futures::stream;
use omnigraph::db::{Omnigraph, ReadTarget};
use omnigraph::error::{ManifestConflictDetails, ManifestErrorKind, OmniError};
use omnigraph::storage::normalize_root_uri;
use omnigraph_compiler::json_params_to_param_map;
use omnigraph_compiler::query::parser::parse_query;
use omnigraph_compiler::{JsonParamMode, ParamMap};
@ -62,6 +63,8 @@ use tower_http::trace::TraceLayer;
use tracing::{error, info, warn};
use tracing_subscriber::EnvFilter;
use utoipa::OpenApi;
use utoipa::openapi::path::{Parameter, ParameterIn};
use utoipa::openapi::schema::{Object, Type};
use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
type BearerTokenHash = [u8; 32];
@ -361,7 +364,7 @@ impl AppState {
uri: impl Into<String>,
bearer_tokens: Vec<(String, String)>,
) -> Result<Self> {
let uri = uri.into();
let uri = normalize_root_uri(&uri.into()).wrap_err("normalize graph URI")?;
let db = Omnigraph::open(&uri).await?;
Ok(Self::new_with_bearer_tokens(uri, db, bearer_tokens))
}
@ -376,7 +379,7 @@ impl AppState {
// single-mode or multi-mode construction is reached. By the
// time we get here, the (policy, no-tokens) combination has
// already been rejected — no second bail needed.
let uri = uri.into();
let uri = normalize_root_uri(&uri.into()).wrap_err("normalize graph URI")?;
let db = Omnigraph::open(&uri).await?;
let policy_engine = match policy_file {
Some(path) => Some(PolicyEngine::load_graph(path, &uri)?),
@ -420,9 +423,9 @@ impl AppState {
// log label, not a routing key — when the future cluster
// catalog ships, single mode may carry the catalog-assigned
// id here instead.
let uri = normalize_root_uri(&uri).unwrap_or(uri);
let key = GraphKey::cluster(
GraphId::try_from("default")
.expect("'default' is a valid GraphId log label"),
GraphId::try_from("default").expect("'default' is a valid GraphId log label"),
);
let handle = Arc::new(GraphHandle {
key,
@ -488,9 +491,7 @@ impl AppState {
// cached `any_per_graph_policy` flag on the registry snapshot.
match &self.routing {
GraphRouting::Single { handle } => handle.policy.is_some(),
GraphRouting::Multi { registry, .. } => {
registry.snapshot_ref().any_per_graph_policy
}
GraphRouting::Multi { registry, .. } => registry.snapshot_ref().any_per_graph_policy,
}
}
@ -509,9 +510,7 @@ impl AppState {
}
}
fn hash_bearer_tokens(
bearer_tokens: Vec<(String, String)>,
) -> Arc<[(BearerTokenHash, Arc<str>)]> {
fn hash_bearer_tokens(bearer_tokens: Vec<(String, String)>) -> Arc<[(BearerTokenHash, Arc<str>)]> {
let tokens: Vec<(BearerTokenHash, Arc<str>)> = bearer_tokens
.into_iter()
.map(|(actor, token)| (hash_bearer_token(&token), Arc::<str>::from(actor)))
@ -519,7 +518,6 @@ fn hash_bearer_tokens(
Arc::from(tokens)
}
impl ApiError {
pub fn unauthorized(message: impl Into<String>) -> Self {
Self {
@ -789,14 +787,25 @@ pub fn load_server_settings(
let mode = if has_cli_uri || has_cli_target || has_server_graph {
// Rules 1, 2, or 3 → Single mode.
let uri = config.resolve_target_uri(
let raw_uri = config.resolve_target_uri(
cli_uri,
cli_target.as_deref(),
config.server_graph_name(),
)?;
let uri = normalize_root_uri(&raw_uri).wrap_err_with(|| {
format!("normalize single-graph URI '{raw_uri}' from server settings")
})?;
let policy_file = config.resolve_policy_file();
ServerConfigMode::Single { uri, policy_file }
} else if has_explicit_config && has_graphs_map {
if config.resolve_policy_file().is_some() {
bail!(
"top-level `policy.file` is single-graph/CLI-local policy only; \
in multi-graph mode move per-graph rules to \
`graphs.<graph_id>.policy.file` and move `graph_list` rules to \
`server.policy.file`."
);
}
// Rule 4 → Multi mode. Build a startup config per graph.
let mut graphs = Vec::with_capacity(config.graphs.len());
for (name, target) in &config.graphs {
@ -806,9 +815,13 @@ pub fn load_server_settings(
GraphId::try_from(name.clone()).map_err(|err| {
color_eyre::eyre::eyre!("invalid graph id '{name}' in omnigraph.yaml: {err}")
})?;
let raw_uri = config.resolve_uri_value(&target.uri);
let uri = normalize_root_uri(&raw_uri).wrap_err_with(|| {
format!("normalize URI '{raw_uri}' for graph '{name}' in omnigraph.yaml")
})?;
graphs.push(GraphStartupConfig {
graph_id: name.clone(),
uri: config.resolve_uri_value(&target.uri),
uri,
policy_file: config.resolve_target_policy_file(name),
});
}
@ -1033,13 +1046,7 @@ pub async fn serve(config: ServerConfig) -> Result<()> {
config = %config_path.display(),
"serving omnigraph"
);
open_multi_graph_state(
graphs,
tokens,
server_policy_file.as_ref(),
config_path,
)
.await?
open_multi_graph_state(graphs, tokens, server_policy_file.as_ref(), config_path).await?
}
};
@ -1090,14 +1097,8 @@ async fn open_multi_graph_state(
.await?;
let workload = workload::WorkloadController::from_env();
let state = AppState::new_multi(
handles,
tokens,
server_policy,
workload,
Some(config_path),
)
.map_err(|err| color_eyre::eyre::eyre!("multi-graph registry: {err}"))?;
let state = AppState::new_multi(handles, tokens, server_policy, workload, Some(config_path))
.map_err(|err| color_eyre::eyre::eyre!("multi-graph registry: {err}"))?;
Ok(state)
}
@ -1106,10 +1107,12 @@ async fn open_multi_graph_state(
async fn open_single_graph(cfg: GraphStartupConfig) -> Result<Arc<GraphHandle>> {
let graph_id = GraphId::try_from(cfg.graph_id.clone())
.map_err(|err| color_eyre::eyre::eyre!("graph id '{}': {err}", cfg.graph_id))?;
let uri = normalize_root_uri(&cfg.uri)
.wrap_err_with(|| format!("normalize URI for graph '{}'", cfg.graph_id))?;
let db = Omnigraph::open(&cfg.uri)
let db = Omnigraph::open(&uri)
.await
.map_err(|err| color_eyre::eyre::eyre!("open graph '{}' at {}: {err}", graph_id, cfg.uri))?;
.map_err(|err| color_eyre::eyre::eyre!("open graph '{}' at {}: {err}", graph_id, uri))?;
let (policy_arc, db) = match &cfg.policy_file {
Some(path) => {
@ -1123,7 +1126,7 @@ async fn open_single_graph(cfg: GraphStartupConfig) -> Result<Arc<GraphHandle>>
Ok(Arc::new(GraphHandle {
key: GraphKey::cluster(graph_id),
uri: cfg.uri,
uri,
engine: Arc::new(db),
policy: policy_arc,
}))
@ -1260,9 +1263,9 @@ const ALWAYS_FLAT_PATHS: &[&str] = &["/healthz", "/graphs"];
/// In multi-mode `server_openapi`, every protected path-item is
/// reattached under the cluster prefix. Operation IDs gain the
/// `cluster_` prefix so SDK generators don't collide if/when both
/// surfaces are merged. The `{graph_id}` URL placeholder is left
/// implicit in the path; consuming clients see it as a standard
/// OpenAPI path parameter.
/// surfaces are merged. Every rewritten operation also declares the
/// required `{graph_id}` path parameter so the served OpenAPI document
/// remains internally valid.
///
/// Removing the flat protected paths matches the runtime router —
/// in multi mode, requests to `/snapshot` etc. return 404, so the
@ -1276,15 +1279,46 @@ fn nest_paths_under_cluster_prefix(doc: &mut utoipa::openapi::OpenApi) {
continue;
}
rename_operation_ids(&mut item, CLUSTER_OPERATION_ID_PREFIX);
add_cluster_graph_id_parameter(&mut item);
let new_path = format!("{CLUSTER_PATH_PREFIX}{path}");
rewritten.insert(new_path, item);
}
doc.paths.paths = rewritten;
}
/// Ensures every operation on `item` declares a `graph_id` path parameter.
///
/// Operations that already carry a parameter named `graph_id` located in the
/// path are left untouched; otherwise the parameter is inserted at the front
/// of the operation's parameter list (creating the list when it is absent).
fn add_cluster_graph_id_parameter(item: &mut utoipa::openapi::PathItem) {
    for operation in path_item_operations_mut(item) {
        let declared = operation.parameters.get_or_insert_with(Vec::new);
        let already_present = declared
            .iter()
            .any(|existing| existing.name == "graph_id" && existing.parameter_in == ParameterIn::Path);
        if already_present {
            continue;
        }
        // Front insertion keeps the routing parameter ahead of any others.
        declared.insert(0, graph_id_path_parameter());
    }
}
/// Builds the OpenAPI parameter describing the `{graph_id}` path segment.
///
/// The parameter is named `graph_id`, is located in the path, is string-typed
/// and carries a short description. The `required` flag is left at whatever
/// `Parameter::new` sets — presumably "required", which the served-document
/// test asserts; confirm against the utoipa version in use.
fn graph_id_path_parameter() -> Parameter {
    let mut graph_id = Parameter::new("graph_id");
    graph_id.schema = Some(Object::with_type(Type::String).into());
    graph_id.description = Some("Graph id to route the request to.".to_string());
    graph_id.parameter_in = ParameterIn::Path;
    graph_id
}
/// Prefix every operation_id in this PathItem with `prefix`.
fn rename_operation_ids(item: &mut utoipa::openapi::PathItem, prefix: &str) {
for op in [
for op in path_item_operations_mut(item) {
if let Some(id) = op.operation_id.as_deref() {
op.operation_id = Some(format!("{prefix}{id}"));
}
}
}
fn path_item_operations_mut(
item: &mut utoipa::openapi::PathItem,
) -> impl Iterator<Item = &mut utoipa::openapi::path::Operation> {
[
item.get.as_mut(),
item.post.as_mut(),
item.put.as_mut(),
@ -1296,11 +1330,6 @@ fn rename_operation_ids(item: &mut utoipa::openapi::PathItem, prefix: &str) {
]
.into_iter()
.flatten()
{
if let Some(id) = op.operation_id.as_deref() {
op.operation_id = Some(format!("{prefix}{id}"));
}
}
}
fn strip_security(doc: &mut utoipa::openapi::OpenApi) {
@ -1405,9 +1434,7 @@ async fn resolve_graph_handle(
match registry.get(&key) {
RegistryLookup::Ready(handle) => handle,
RegistryLookup::Gone => {
return Err(ApiError::not_found(format!(
"graph '{graph_id}' not found"
)));
return Err(ApiError::not_found(format!("graph '{graph_id}' not found")));
}
}
}
@ -1731,7 +1758,9 @@ async fn server_change(
.as_ref()
.map(|Extension(actor)| Arc::clone(&actor.actor_id))
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
let actor_id = actor.as_ref().map(|Extension(actor)| actor.actor_id.as_ref());
let actor_id = actor
.as_ref()
.map(|Extension(actor)| actor.actor_id.as_ref());
authorize_request(
actor.as_ref().map(|Extension(actor)| actor),
handle.policy.as_deref(),
@ -1850,7 +1879,9 @@ async fn server_schema_apply(
.as_ref()
.map(|Extension(actor)| Arc::clone(&actor.actor_id))
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
let actor_id = actor.as_ref().map(|Extension(actor)| actor.actor_id.as_ref());
let actor_id = actor
.as_ref()
.map(|Extension(actor)| actor.actor_id.as_ref());
authorize_request(
actor.as_ref().map(|Extension(actor)| actor),
handle.policy.as_deref(),
@ -1921,7 +1952,9 @@ async fn server_ingest(
.as_ref()
.map(|Extension(actor)| Arc::clone(&actor.actor_id))
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
let actor_id = actor.as_ref().map(|Extension(actor)| actor.actor_id.as_ref());
let actor_id = actor
.as_ref()
.map(|Extension(actor)| actor.actor_id.as_ref());
let branch_exists = {
let db = &handle.engine;
@ -2120,7 +2153,9 @@ async fn server_branch_delete(
.as_ref()
.map(|Extension(actor)| Arc::clone(&actor.actor_id))
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
let actor_id = actor.as_ref().map(|Extension(actor)| actor.actor_id.as_ref());
let actor_id = actor
.as_ref()
.map(|Extension(actor)| actor.actor_id.as_ref());
authorize_request(
actor.as_ref().map(|Extension(actor)| actor),
handle.policy.as_deref(),
@ -2181,7 +2216,9 @@ async fn server_branch_merge(
.as_ref()
.map(|Extension(actor)| Arc::clone(&actor.actor_id))
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
let actor_id = actor.as_ref().map(|Extension(actor)| actor.actor_id.as_ref());
let actor_id = actor
.as_ref()
.map(|Extension(actor)| actor.actor_id.as_ref());
authorize_request(
actor.as_ref().map(|Extension(actor)| actor),
handle.policy.as_deref(),
@ -2417,8 +2454,7 @@ mod tests {
use super::{
GraphStartupConfig, ServerConfig, ServerConfigMode, ServerRuntimeState,
classify_server_runtime_state, hash_bearer_token, load_server_settings,
normalize_bearer_token, parse_bearer_tokens_json, serve,
server_bearer_tokens_from_env,
normalize_bearer_token, parse_bearer_tokens_json, serve, server_bearer_tokens_from_env,
};
use serial_test::serial;
use std::env;
@ -2770,8 +2806,8 @@ server:
// and multi mode get the same enforcement from one source of
// truth.
for allow_unauthenticated in [false, true] {
let err = classify_server_runtime_state(false, true, allow_unauthenticated)
.unwrap_err();
let err =
classify_server_runtime_state(false, true, allow_unauthenticated).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("policy file is configured but no bearer tokens"),

View file

@ -23,6 +23,7 @@ use std::sync::Arc;
use arc_swap::ArcSwap;
use omnigraph::db::Omnigraph;
use omnigraph::storage::normalize_root_uri;
#[cfg(test)]
use tokio::sync::Mutex;
@ -104,6 +105,9 @@ pub enum InsertError {
/// Maps to HTTP 409.
#[error("URI '{0}' is already registered as another graph")]
DuplicateUri(String),
/// A handle carried an invalid graph URI. Maps to startup failure.
#[error("URI '{uri}' is invalid: {message}")]
InvalidUri { uri: String, message: String },
}
pub struct GraphRegistry {
@ -132,13 +136,14 @@ impl GraphRegistry {
let mut graphs: HashMap<GraphKey, Arc<GraphHandle>> = HashMap::with_capacity(handles.len());
let mut seen_uris: HashMap<String, GraphKey> = HashMap::with_capacity(handles.len());
for handle in handles {
let (canonical_uri, handle) = canonicalize_handle_uri(handle)?;
if graphs.contains_key(&handle.key) {
return Err(InsertError::DuplicateKey(handle.key.clone()));
}
if seen_uris.contains_key(&handle.uri) {
if seen_uris.contains_key(&canonical_uri) {
return Err(InsertError::DuplicateUri(handle.uri.clone()));
}
seen_uris.insert(handle.uri.clone(), handle.key.clone());
seen_uris.insert(canonical_uri, handle.key.clone());
graphs.insert(handle.key.clone(), handle);
}
Ok(Self {
@ -203,11 +208,17 @@ impl GraphRegistry {
pub async fn insert(&self, handle: Arc<GraphHandle>) -> Result<(), InsertError> {
let _guard = self.mutate.lock().await;
let current = self.snapshot.load();
let (canonical_uri, handle) = canonicalize_handle_uri(handle)?;
if current.graphs.contains_key(&handle.key) {
return Err(InsertError::DuplicateKey(handle.key.clone()));
}
for existing in current.graphs.values() {
if existing.uri == handle.uri {
let existing_uri =
normalize_root_uri(&existing.uri).map_err(|err| InsertError::InvalidUri {
uri: existing.uri.clone(),
message: err.to_string(),
})?;
if existing_uri == canonical_uri {
return Err(InsertError::DuplicateUri(handle.uri.clone()));
}
}
@ -219,6 +230,25 @@ impl GraphRegistry {
}
}
/// Normalizes the URI carried by `handle` and returns it alongside a handle
/// that stores that normalized form.
///
/// When the handle's URI is already in normalized form the same `Arc` is
/// handed back. Otherwise a new `GraphHandle` is built that shares the
/// original engine and policy but records the normalized URI.
///
/// # Errors
///
/// Returns `InsertError::InvalidUri` (with the offending URI and the
/// normalizer's message) when `normalize_root_uri` rejects the handle's URI.
fn canonicalize_handle_uri(
    handle: Arc<GraphHandle>,
) -> Result<(String, Arc<GraphHandle>), InsertError> {
    let normalized = match normalize_root_uri(&handle.uri) {
        Ok(uri) => uri,
        Err(err) => {
            return Err(InsertError::InvalidUri {
                uri: handle.uri.clone(),
                message: err.to_string(),
            });
        }
    };

    let handle = if normalized == handle.uri {
        // Already canonical: reuse the caller's allocation.
        handle
    } else {
        Arc::new(GraphHandle {
            key: handle.key.clone(),
            uri: normalized.clone(),
            engine: Arc::clone(&handle.engine),
            policy: handle.policy.clone(),
        })
    };

    Ok((normalized, handle))
}
impl Default for GraphRegistry {
fn default() -> Self {
Self::new()

View file

@ -270,12 +270,13 @@ mod tests {
let err = controller
.try_admit(&actor, 100)
.expect_err("third should reject on count");
assert!(matches!(err, RejectReason::InFlightCountExceeded { cap: 2 }));
assert!(matches!(
err,
RejectReason::InFlightCountExceeded { cap: 2 }
));
drop(g1);
// After drop, a new admit succeeds again.
let _g3 = controller
.try_admit(&actor, 100)
.expect("admit after drop");
let _g3 = controller.try_admit(&actor, 100).expect("admit after drop");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -356,7 +357,9 @@ mod tests {
let bob: Arc<str> = "bob".into();
let _ga = controller.try_admit(&alice, 100).expect("alice ok");
// Alice over count cap, Bob unaffected.
let err = controller.try_admit(&alice, 100).expect_err("alice rejected");
let err = controller
.try_admit(&alice, 100)
.expect_err("alice rejected");
assert!(matches!(err, RejectReason::InFlightCountExceeded { .. }));
let _gb = controller.try_admit(&bob, 100).expect("bob ok");
}

View file

@ -1128,6 +1128,68 @@ async fn multi_mode_openapi_prefixes_operation_ids_with_cluster() {
);
}
/// In multi mode, every operation under a cluster-prefixed path must declare
/// a required, string-typed `graph_id` path parameter, while the flat
/// `/healthz` and `/graphs` paths must not declare one.
#[tokio::test]
async fn multi_mode_openapi_declares_graph_id_path_parameter() {
    // HTTP verbs inspected on each path item.
    const METHODS: [&str; 5] = ["get", "post", "put", "delete", "patch"];

    let (_dirs, app) = app_for_multi_mode(&["alpha"]).await;
    let openapi_request = Request::builder()
        .method(Method::GET)
        .uri("/openapi.json")
        .body(Body::empty())
        .unwrap();
    let (_, document) = json_response(&app, openapi_request).await;
    let paths = document["paths"].as_object().unwrap();

    // Cluster paths: each present operation carries the parameter.
    for expected_path in EXPECTED_CLUSTER_PATHS {
        let Some(item) = paths.get(*expected_path) else {
            panic!("missing cluster path {expected_path}")
        };
        for method in METHODS {
            let operation = match item.get(method) {
                Some(value) if value.is_object() => value,
                _ => continue,
            };
            let Some(parameters) = operation["parameters"].as_array() else {
                panic!("{expected_path}.{method} missing parameters")
            };
            let Some(graph_id) = parameters
                .iter()
                .find(|candidate| candidate["name"] == "graph_id" && candidate["in"] == "path")
            else {
                panic!("{expected_path}.{method} missing graph_id path parameter")
            };
            assert_eq!(
                graph_id["required"].as_bool(),
                Some(true),
                "{expected_path}.{method} graph_id parameter must be required"
            );
            assert_eq!(
                graph_id["schema"]["type"].as_str(),
                Some("string"),
                "{expected_path}.{method} graph_id parameter must be string typed"
            );
        }
    }

    // Flat paths: no operation may mention the parameter.
    for flat in ["/healthz", "/graphs"] {
        let item = paths.get(flat).unwrap();
        for method in METHODS {
            let Some(operation) = item.get(method).filter(|value| value.is_object()) else {
                continue;
            };
            let declares_graph_id = operation["parameters"].as_array().is_some_and(|declared| {
                declared
                    .iter()
                    .any(|candidate| candidate["name"] == "graph_id" && candidate["in"] == "path")
            });
            assert!(
                !declares_graph_id,
                "{flat}.{method} must not declare graph_id; it remains flat"
            );
        }
    }
}
#[tokio::test]
async fn multi_mode_operation_ids_are_unique() {
// Sanity check: the cluster_ prefix prevents collision with flat ids

View file

@ -4411,8 +4411,10 @@ async fn schema_apply_route_additive_property_preserves_existing_rows() {
mod multi_graph_startup {
use super::*;
use omnigraph::storage::normalize_root_uri;
use omnigraph_server::{
GraphHandle, GraphId, GraphKey, ServerConfig, ServerConfigMode, load_server_settings,
GraphHandle, GraphId, GraphKey, GraphRegistry, InsertError, ServerConfig, ServerConfigMode,
load_server_settings,
};
use std::sync::Arc;
@ -4509,16 +4511,38 @@ mod multi_graph_startup {
(Method::GET, "/graphs/alpha/schema", None),
(Method::GET, "/graphs/alpha/branches", None),
(Method::GET, "/graphs/alpha/commits", None),
(Method::POST, "/graphs/alpha/read", Some(r#"{"query_source":"query q() { return {} }"}"#)),
(Method::POST, "/graphs/alpha/change", Some(r#"{"query_source":"query q() { return {} }"}"#)),
(Method::POST, "/graphs/alpha/export", Some(r#"{"branch":"main"}"#)),
(Method::POST, "/graphs/alpha/schema/apply", Some(r#"{"schema_source":"","allow_data_loss":false}"#)),
(
Method::POST,
"/graphs/alpha/read",
Some(r#"{"query_source":"query q() { return {} }"}"#),
),
(
Method::POST,
"/graphs/alpha/change",
Some(r#"{"query_source":"query q() { return {} }"}"#),
),
(
Method::POST,
"/graphs/alpha/export",
Some(r#"{"branch":"main"}"#),
),
(
Method::POST,
"/graphs/alpha/schema/apply",
Some(r#"{"schema_source":"","allow_data_loss":false}"#),
),
(Method::POST, "/graphs/alpha/ingest", Some(r#"{"data":""}"#)),
(Method::POST, "/graphs/alpha/branches/merge", Some(r#"{"source":"main","target":"main"}"#)),
(
Method::POST,
"/graphs/alpha/branches/merge",
Some(r#"{"source":"main","target":"main"}"#),
),
];
for (method, path, body) in cases {
let req_body = body.map(|s| Body::from(s.to_string())).unwrap_or_else(Body::empty);
let req_body = body
.map(|s| Body::from(s.to_string()))
.unwrap_or_else(Body::empty);
let req = Request::builder()
.method(method.clone())
.uri(*path)
@ -4690,6 +4714,57 @@ graphs:
);
}
/// Two handles whose URIs differ only in spelling (bare path vs.
/// `file://…/`) must be rejected by `GraphRegistry::from_handles` as a
/// duplicate URI.
#[tokio::test(flavor = "multi_thread")]
async fn registry_rejects_duplicate_normalized_graph_uris() {
    let dir = tempfile::tempdir().unwrap();
    let graph_uri = dir.path().join("same").to_str().unwrap().to_string();
    let schema = fs::read_to_string(fixture("test.pg")).unwrap();
    let engine = Arc::new(Omnigraph::init(&graph_uri, &schema).await.unwrap());

    // Both handles share one engine; only the id and URI spelling vary.
    let handle_for = |id: &'static str, uri: String| {
        Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from(id).unwrap()),
            uri,
            engine: Arc::clone(&engine),
            policy: None,
        })
    };
    let alpha = handle_for("alpha", graph_uri.clone());
    let beta = handle_for("beta", format!("file://{graph_uri}/"));

    let reported_uri = match GraphRegistry::from_handles(vec![alpha, beta]) {
        Err(InsertError::DuplicateUri(uri)) => uri,
        Err(err) => panic!("expected DuplicateUri for normalized aliases, got {err:?}"),
        Ok(_) => panic!("expected DuplicateUri for normalized aliases, got Ok"),
    };
    assert!(
        normalize_root_uri(&reported_uri).is_ok(),
        "duplicate URI should still be parseable, got {reported_uri}"
    );
}
/// A handle registered under a `file://…/` alias must be listed by the
/// registry under the plain path it was initialised with.
#[tokio::test(flavor = "multi_thread")]
async fn registry_stores_canonical_graph_uri() {
    let dir = tempfile::tempdir().unwrap();
    let graph_uri = dir.path().join("canonical").to_str().unwrap().to_string();
    let schema = fs::read_to_string(fixture("test.pg")).unwrap();
    let engine = Arc::new(Omnigraph::init(&graph_uri, &schema).await.unwrap());

    // Register the graph under a non-canonical spelling of the same location.
    let aliased_uri = format!("file://{graph_uri}/");
    let key = GraphKey::cluster(GraphId::try_from("alpha").unwrap());
    let handle = Arc::new(GraphHandle {
        key,
        uri: aliased_uri,
        engine,
        policy: None,
    });

    let registry = GraphRegistry::from_handles(vec![handle]).unwrap();
    let listed = registry.list();
    assert_eq!(listed.len(), 1);
    assert_eq!(listed[0].uri, graph_uri);
}
// ── Four-rule mode inference matrix ───────────────────────────────
/// Rule 1: CLI positional URI → Single.
@ -4752,8 +4827,7 @@ server:
"#,
)
.unwrap();
let settings =
load_server_settings(Some(&config_path), None, None, None, true).unwrap();
let settings = load_server_settings(Some(&config_path), None, None, None, true).unwrap();
match settings.mode {
ServerConfigMode::Single { uri, .. } => assert_eq!(uri, "/tmp/beta.omni"),
ServerConfigMode::Multi { .. } => panic!("expected Single (rule 3), got Multi"),
@ -4776,8 +4850,7 @@ graphs:
"#,
)
.unwrap();
let settings =
load_server_settings(Some(&config_path), None, None, None, true).unwrap();
let settings = load_server_settings(Some(&config_path), None, None, None, true).unwrap();
match settings.mode {
ServerConfigMode::Multi { graphs, .. } => {
let ids: Vec<&str> = graphs.iter().map(|g| g.graph_id.as_str()).collect();
@ -4788,6 +4861,63 @@ graphs:
}
}
/// A config that declares a `graphs:` map together with a `policy.file`
/// entry must be rejected by `load_server_settings`, and the error text must
/// name both migration targets (`graphs.<graph_id>.policy.file` and
/// `server.policy.file`).
#[test]
fn mode_inference_multi_rejects_top_level_policy_file() {
let temp = tempfile::tempdir().unwrap();
let config_path = temp.path().join("omnigraph.yaml");
// Write the offending config: a policy file plus a `graphs:` map.
fs::write(
&config_path,
r#"
policy:
file: ./policy.yaml
graphs:
alpha:
uri: /tmp/alpha.omni
"#,
)
.unwrap();
// Loading is expected to fail; the test inspects the rendered error message.
let err = load_server_settings(Some(&config_path), None, None, None, true).unwrap_err();
let msg = err.to_string();
// The message explains that the top-level policy is single-graph/CLI-local only.
assert!(
msg.contains("top-level `policy.file` is single-graph/CLI-local policy only"),
"expected single-graph policy guidance, got: {msg}"
);
// The message points at the per-graph policy key.
assert!(
msg.contains("graphs.<graph_id>.policy.file"),
"expected per-graph migration guidance, got: {msg}"
);
// The message points at the server-level policy key.
assert!(
msg.contains("server.policy.file"),
"expected server policy migration guidance, got: {msg}"
);
}
/// A graph URI written as `file://<path>/` in the `graphs:` map must come
/// back from `load_server_settings` as the plain `<path>` string.
#[test]
fn mode_inference_normalizes_multi_graph_uris() {
let temp = tempfile::tempdir().unwrap();
let graph = temp.path().join("alpha.omni");
let config_path = temp.path().join("omnigraph.yaml");
// The config spells the graph location with a `file://` prefix and a
// trailing slash around the temp-dir path.
fs::write(
&config_path,
format!(
r#"
graphs:
alpha:
uri: file://{}/
"#,
graph.display()
),
)
.unwrap();
let settings = load_server_settings(Some(&config_path), None, None, None, true).unwrap();
match settings.mode {
ServerConfigMode::Multi { graphs, .. } => {
// The stored URI equals the bare path: no scheme, no trailing slash.
assert_eq!(graphs[0].uri, graph.to_string_lossy());
}
ServerConfigMode::Single { .. } => panic!("expected Multi"),
}
}
/// Rule 5: nothing → error with migration hint.
#[test]
fn mode_inference_no_inputs_errors_with_migration_hint() {
@ -4806,8 +4936,7 @@ graphs:
let temp = tempfile::tempdir().unwrap();
let config_path = temp.path().join("omnigraph.yaml");
fs::write(&config_path, "server:\n bind: 127.0.0.1:8080\n").unwrap();
let err =
load_server_settings(Some(&config_path), None, None, None, true).unwrap_err();
let err = load_server_settings(Some(&config_path), None, None, None, true).unwrap_err();
assert!(err.to_string().contains("no graph to serve"));
}
@ -4865,8 +4994,7 @@ graphs:
"#,
)
.unwrap();
let settings =
load_server_settings(Some(&config_path), None, None, None, true).unwrap();
let settings = load_server_settings(Some(&config_path), None, None, None, true).unwrap();
let graphs = match settings.mode {
ServerConfigMode::Multi { graphs, .. } => graphs,
_ => panic!("expected Multi"),
@ -4900,8 +5028,7 @@ graphs:
"#,
)
.unwrap();
let settings =
load_server_settings(Some(&config_path), None, None, None, true).unwrap();
let settings = load_server_settings(Some(&config_path), None, None, None, true).unwrap();
match settings.mode {
ServerConfigMode::Multi {
server_policy_file, ..
@ -5000,8 +5127,7 @@ graphs:
});
let tokens = vec![("act-andrew".to_string(), "secret-token".to_string())];
let workload = omnigraph_server::workload::WorkloadController::from_env();
let state =
AppState::new_multi(vec![handle], tokens, None, workload, None).unwrap();
let state = AppState::new_multi(vec![handle], tokens, None, workload, None).unwrap();
let app = build_app(state);
// No Authorization header → 401.
@ -5092,8 +5218,8 @@ rules:
("act-bruno".to_string(), "bruno-token".to_string()),
];
let workload = omnigraph_server::workload::WorkloadController::from_env();
let state = AppState::new_multi(handles, tokens, Some(server_policy), workload, None)
.unwrap();
let state =
AppState::new_multi(handles, tokens, Some(server_policy), workload, None).unwrap();
let app = build_app(state);
// Admin → 200, body returns both graphs alphabetically sorted.

View file

@ -239,7 +239,9 @@ async fn main() {
let jsonl = generate_jsonl(n, avg_deg, 42);
let t = Instant::now();
load_jsonl(&mut db, &jsonl, LoadMode::Overwrite).await.unwrap();
load_jsonl(&mut db, &jsonl, LoadMode::Overwrite)
.await
.unwrap();
let load_elapsed = t.elapsed();
println!(

View file

@ -10,11 +10,11 @@ pub(crate) mod write_queue;
pub use commit_graph::GraphCommit;
pub use graph_coordinator::{GraphCoordinator, ReadTarget, ResolvedTarget, SnapshotId};
pub use manifest::{Snapshot, SubTableEntry, SubTableUpdate};
pub(crate) use omnigraph::ensure_public_branch_ref;
pub use omnigraph::{
CleanupPolicyOptions, InitOptions, MergeOutcome, Omnigraph, OpenMode, SchemaApplyOptions,
SchemaApplyResult, TableCleanupStats, TableOptimizeStats,
};
pub(crate) use omnigraph::ensure_public_branch_ref;
pub(crate) use run_registry::is_internal_run_branch;
pub(crate) const SCHEMA_APPLY_LOCK_BRANCH: &str = "__schema_apply_lock__";
@ -59,9 +59,7 @@ impl MutationOpKind {
pub(crate) fn strict_pre_stage_version_check(self) -> bool {
match self {
MutationOpKind::Insert | MutationOpKind::Merge => false,
MutationOpKind::Update
| MutationOpKind::Delete
| MutationOpKind::SchemaRewrite => true,
MutationOpKind::Update | MutationOpKind::Delete | MutationOpKind::SchemaRewrite => true,
}
}
}

View file

@ -231,9 +231,7 @@ impl Omnigraph {
schema_state_uri(&root),
] {
if storage.exists(&candidate).await? {
return Err(OmniError::AlreadyInitialized {
uri: root.clone(),
});
return Err(OmniError::AlreadyInitialized { uri: root.clone() });
}
}
}
@ -242,15 +240,34 @@ impl Omnigraph {
let mut catalog = build_catalog_from_ir(&schema_ir)?;
fixup_blob_schemas(&mut catalog);
// Run the I/O phase. On any error, best-effort-clean the schema
// artifacts that may have been written to disk before returning
// the original error. The cleanup is safe in strict mode because
// the preflight above guarantees the three schema URIs did NOT
// exist before this call, so any file there afterward is ours
// to delete. In `force` mode the operator opted in to overwrite
// semantics, so cleanup-on-failure of force re-inits may delete
// schema files that were present pre-call — that's part of the
// force contract.
// Establish an atomic ownership claim on `_schema.pg` before
// writing the remaining init artifacts. A check-then-write preflight
// is not enough under concurrent `init` calls: two callers can both
// observe an empty root, one can successfully initialize, and the
// loser can then fail in Lance `WriteMode::Create`. Only the caller
// that atomically created `_schema.pg` may clean up schema artifacts
// on later failure.
let schema_pg_claimed = if options.force {
false
} else {
let schema_path = join_uri(&root, SCHEMA_SOURCE_FILENAME);
if !storage
.write_text_if_absent(&schema_path, schema_source)
.await?
{
return Err(OmniError::AlreadyInitialized { uri: root.clone() });
}
if let Err(err) = crate::failpoints::maybe_fail("init.after_schema_pg_written") {
best_effort_cleanup_init_artifacts(&root, storage.as_ref()).await;
return Err(err);
}
true
};
// Run the I/O phase. On any error, best-effort-clean schema
// artifacts only when this invocation owns them: strict mode owns
// them after the atomic `_schema.pg` claim above; force mode owns
// destructive overwrite semantics by explicit operator request.
//
// Coverage gap: Lance per-type datasets and `__manifest/`
// directory created by `GraphCoordinator::init` are NOT cleaned
@ -267,12 +284,15 @@ impl Omnigraph {
&schema_ir,
&catalog,
&storage,
!schema_pg_claimed,
)
.await
{
Ok(coordinator) => coordinator,
Err(err) => {
best_effort_cleanup_init_artifacts(&root, storage.as_ref()).await;
if schema_pg_claimed || options.force {
best_effort_cleanup_init_artifacts(&root, storage.as_ref()).await;
}
return Err(err);
}
};
@ -1567,8 +1587,10 @@ fn read_schema_ir_from_source(schema_source: &str) -> Result<SchemaIR> {
/// can pattern-match on the result and run cleanup on error before
/// returning the original error.
///
/// Failpoints fire at the three phase boundaries:
/// * `init.after_schema_pg_written` — `_schema.pg` is on disk.
/// Failpoints fire at the phase boundaries:
/// * `init.after_schema_pg_written` — `_schema.pg` is on disk. In strict mode
/// this fires in the caller immediately after the atomic ownership claim; in
/// force mode it fires here after the explicit overwrite.
/// * `init.after_schema_contract_written` — `_schema.pg` + `_schema.ir.json`
/// + `__schema_state.json` are on disk.
/// * `init.after_coordinator_init` — all schema files plus Lance per-type
@ -1581,10 +1603,13 @@ async fn init_storage_phase(
schema_ir: &SchemaIR,
catalog: &Catalog,
storage: &Arc<dyn StorageAdapter>,
write_schema_pg: bool,
) -> Result<GraphCoordinator> {
let schema_path = join_uri(root, SCHEMA_SOURCE_FILENAME);
storage.write_text(&schema_path, schema_source).await?;
crate::failpoints::maybe_fail("init.after_schema_pg_written")?;
if write_schema_pg {
let schema_path = join_uri(root, SCHEMA_SOURCE_FILENAME);
storage.write_text(&schema_path, schema_source).await?;
crate::failpoints::maybe_fail("init.after_schema_pg_written")?;
}
write_schema_contract(root, storage.as_ref(), schema_ir).await?;
crate::failpoints::maybe_fail("init.after_schema_contract_written")?;
@ -1832,7 +1857,7 @@ mod tests {
use crate::db::manifest::ManifestCoordinator;
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Mutex;
use std::sync::{Arc, Mutex};
use crate::storage::{LocalStorageAdapter, StorageAdapter, join_uri};
@ -1886,6 +1911,11 @@ edge WorksAt: Person -> Company
self.inner.write_text(uri, contents).await
}
/// Records the target URI in `writes`, then forwards the conditional write
/// to the wrapped adapter and returns its result unchanged.
async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool> {
    {
        // Keep the guard in its own scope so it is released before awaiting.
        let mut recorded = self.writes.lock().unwrap();
        recorded.push(uri.to_string());
    }
    self.inner.write_text_if_absent(uri, contents).await
}
async fn exists(&self, uri: &str) -> Result<bool> {
self.exists_checks.lock().unwrap().push(uri.to_string());
self.inner.exists(uri).await
@ -1909,6 +1939,89 @@ edge WorksAt: Person -> Company
}
}
/// Test-only storage adapter that forwards every call to a
/// `LocalStorageAdapter`, except that a successful `exists` probe of the
/// schema-state URI under `root` also waits on `barrier` before returning.
#[derive(Debug)]
struct InitRaceStorageAdapter {
    /// Adapter that performs the real I/O.
    inner: LocalStorageAdapter,
    /// Graph root used to derive the schema-state URI that triggers the wait.
    root: String,
    /// Rendezvous point shared by the callers using this adapter.
    barrier: Arc<tokio::sync::Barrier>,
}

#[async_trait]
impl StorageAdapter for InitRaceStorageAdapter {
    /// Existence check with a rendezvous: the wrapped adapter is queried
    /// first (an error returns immediately, without waiting), and only a
    /// probe of the schema-state URI blocks on the barrier afterwards.
    async fn exists(&self, uri: &str) -> Result<bool> {
        let found = self.inner.exists(uri).await?;
        let is_schema_state_probe = uri == schema_state_uri(&self.root);
        if is_schema_state_probe {
            self.barrier.wait().await;
        }
        Ok(found)
    }

    // Everything below is a plain pass-through to the wrapped adapter.

    async fn read_text(&self, uri: &str) -> Result<String> {
        self.inner.read_text(uri).await
    }

    async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
        self.inner.write_text(uri, contents).await
    }

    async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool> {
        self.inner.write_text_if_absent(uri, contents).await
    }

    async fn rename_text(&self, from_uri: &str, to_uri: &str) -> Result<()> {
        self.inner.rename_text(from_uri, to_uri).await
    }

    async fn delete(&self, uri: &str) -> Result<()> {
        self.inner.delete(uri).await
    }

    async fn list_dir(&self, dir_uri: &str) -> Result<Vec<String>> {
        self.inner.list_dir(dir_uri).await
    }
}
/// Runs two default-option `init_with_storage` calls against the same root
/// through an `InitRaceStorageAdapter` with a two-party barrier, then checks
/// that exactly one call succeeded and that all three schema files are still
/// on disk afterwards.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn concurrent_strict_init_does_not_delete_winning_schema_files() {
    let workspace = tempfile::tempdir().unwrap();
    let uri = workspace.path().to_str().unwrap().to_string();
    let normalized_root = normalize_root_uri(&uri).unwrap();

    // Both init calls share one adapter, and therefore one barrier.
    let rendezvous = Arc::new(tokio::sync::Barrier::new(2));
    let storage: Arc<dyn StorageAdapter> = Arc::new(InitRaceStorageAdapter {
        inner: LocalStorageAdapter,
        root: normalized_root,
        barrier: rendezvous,
    });

    let (first_outcome, second_outcome) = tokio::join!(
        Omnigraph::init_with_storage(
            &uri,
            TEST_SCHEMA,
            Arc::clone(&storage),
            InitOptions::default(),
        ),
        Omnigraph::init_with_storage(
            &uri,
            TEST_SCHEMA,
            Arc::clone(&storage),
            InitOptions::default(),
        )
    );

    let winners = [first_outcome.is_ok(), second_outcome.is_ok()]
        .iter()
        .filter(|succeeded| **succeeded)
        .count();
    assert_eq!(winners, 1, "exactly one concurrent init should win");

    // The failed call must not have removed the successful call's artifacts.
    let graph_dir = workspace.path();
    assert!(
        graph_dir.join("_schema.pg").exists(),
        "winning init must leave _schema.pg in place"
    );
    assert!(
        graph_dir.join("_schema.ir.json").exists(),
        "winning init must leave _schema.ir.json in place"
    );
    assert!(
        graph_dir.join("__schema_state.json").exists(),
        "winning init must leave __schema_state.json in place"
    );
}
#[tokio::test]
async fn test_init_and_open_route_graph_metadata_through_storage_adapter() {
let dir = tempfile::tempdir().unwrap();

View file

@ -16,7 +16,12 @@ pub(super) async fn entity_at(
id: &str,
version: u64,
) -> Result<Option<serde_json::Value>> {
let snap = db.coordinator.read().await.snapshot_at_version(version).await?;
let snap = db
.coordinator
.read()
.await
.snapshot_at_version(version)
.await?;
entity_from_snapshot(db, &snap, table_key, id).await
}

View file

@ -22,7 +22,12 @@ pub(super) async fn graph_index_for_resolved(
}
pub(super) async fn ensure_indices(db: &Omnigraph) -> Result<()> {
let current_branch = db.coordinator.read().await.current_branch().map(str::to_string);
let current_branch = db
.coordinator
.read()
.await
.current_branch()
.map(str::to_string);
ensure_indices_for_branch(db, current_branch.as_deref()).await
}
@ -68,10 +73,7 @@ pub(super) async fn failpoint_publish_table_head_without_index_rebuild_for_test(
.await
}
pub(super) async fn ensure_indices_for_branch(
db: &Omnigraph,
branch: Option<&str>,
) -> Result<()> {
pub(super) async fn ensure_indices_for_branch(db: &Omnigraph, branch: Option<&str>) -> Result<()> {
db.ensure_schema_state_valid().await?;
db.ensure_schema_apply_idle("ensure_indices").await?;
let resolved = db.resolved_branch_target(branch).await?;
@ -403,7 +405,12 @@ pub(super) async fn open_for_mutation(
table_key: &str,
op_kind: crate::db::MutationOpKind,
) -> Result<(Dataset, String, Option<String>)> {
let current_branch = db.coordinator.read().await.current_branch().map(str::to_string);
let current_branch = db
.coordinator
.read()
.await
.current_branch()
.map(str::to_string);
open_for_mutation_on_branch(db, current_branch.as_deref(), table_key, op_kind).await
}
@ -807,7 +814,12 @@ pub(super) async fn commit_prepared_updates_on_branch(
updates: &[crate::db::SubTableUpdate],
actor_id: Option<&str>,
) -> Result<u64> {
let current_branch = db.coordinator.read().await.current_branch().map(str::to_string);
let current_branch = db
.coordinator
.read()
.await
.current_branch()
.map(str::to_string);
let requested_branch = branch.map(str::to_string);
if requested_branch == current_branch {
return commit_prepared_updates(db, updates, actor_id).await;
@ -835,7 +847,12 @@ pub(super) async fn commit_prepared_updates_on_branch_with_expected(
expected_table_versions: &std::collections::HashMap<String, u64>,
actor_id: Option<&str>,
) -> Result<u64> {
let current_branch = db.coordinator.read().await.current_branch().map(str::to_string);
let current_branch = db
.coordinator
.read()
.await
.current_branch()
.map(str::to_string);
let requested_branch = branch.map(str::to_string);
if requested_branch == current_branch {
return commit_prepared_updates_with_expected(
@ -870,7 +887,12 @@ pub(super) async fn commit_updates(
updates: &[crate::db::SubTableUpdate],
) -> Result<u64> {
db.ensure_schema_apply_not_locked("write commit").await?;
let current_branch = db.coordinator.read().await.current_branch().map(str::to_string);
let current_branch = db
.coordinator
.read()
.await
.current_branch()
.map(str::to_string);
let prepared = prepare_updates_for_commit(db, current_branch.as_deref(), updates).await?;
commit_prepared_updates(db, &prepared, None).await
}
@ -879,7 +901,11 @@ pub(super) async fn commit_manifest_updates(
db: &Omnigraph,
updates: &[crate::db::SubTableUpdate],
) -> Result<u64> {
db.coordinator.write().await.commit_manifest_updates(updates).await
db.coordinator
.write()
.await
.commit_manifest_updates(updates)
.await
}
pub(super) async fn record_merge_commit(
@ -889,7 +915,9 @@ pub(super) async fn record_merge_commit(
merged_parent_commit_id: &str,
actor_id: Option<&str>,
) -> Result<String> {
db.coordinator.write().await
db.coordinator
.write()
.await
.record_merge_commit(
manifest_version,
parent_commit_id,
@ -923,7 +951,11 @@ pub(super) async fn commit_updates_on_branch_with_expected(
}
pub(super) async fn ensure_commit_graph_initialized(db: &Omnigraph) -> Result<()> {
db.coordinator.write().await.ensure_commit_graph_initialized().await
db.coordinator
.write()
.await
.ensure_commit_graph_initialized()
.await
}
pub(super) async fn invalidate_graph_index(db: &Omnigraph) {

View file

@ -91,10 +91,7 @@ impl WriteQueueManager {
/// Empty input returns an empty Vec without touching the map.
/// Duplicates in `keys` are deduped before acquisition (the same
/// key acquired twice would deadlock against itself).
pub(crate) async fn acquire_many(
&self,
keys: &[TableQueueKey],
) -> Vec<OwnedMutexGuard<()>> {
pub(crate) async fn acquire_many(&self, keys: &[TableQueueKey]) -> Vec<OwnedMutexGuard<()>> {
if keys.is_empty() {
return Vec::new();
}
@ -167,7 +164,10 @@ mod tests {
qm2.acquire_many(&[z_clone, a_clone]).await
})
.await;
assert!(result.is_err(), "acquire_many should block on `a`, the lex-first key");
assert!(
result.is_err(),
"acquire_many should block on `a`, the lex-first key"
);
}
#[tokio::test]
@ -180,9 +180,10 @@ mod tests {
// Second acquire on same key should NOT complete within 200ms.
let qm2 = Arc::clone(&qm);
let k2 = k.clone();
let blocked = timeout(Duration::from_millis(200), async move {
qm2.acquire(&k2).await
})
let blocked = timeout(
Duration::from_millis(200),
async move { qm2.acquire(&k2).await },
)
.await;
assert!(blocked.is_err(), "second acquire on same key must block");

View file

@ -794,11 +794,8 @@ impl Omnigraph {
// post_commit_pin) and tidies up. Failing the user
// here would return an error for a write that
// already landed.
if let Err(err) = crate::db::manifest::delete_sidecar(
&handle,
self.storage_adapter(),
)
.await
if let Err(err) =
crate::db::manifest::delete_sidecar(&handle, self.storage_adapter()).await
{
tracing::warn!(
error = %err,
@ -852,15 +849,8 @@ impl Omnigraph {
assignments,
predicate,
} => {
self.execute_update(
type_name,
assignments,
predicate,
params,
branch,
staging,
)
.await?
self.execute_update(type_name, assignments, predicate, params, branch, staging)
.await?
}
MutationOpIR::Delete {
type_name,
@ -981,14 +971,8 @@ impl Omnigraph {
// + iterate pending edges in-memory for the `src` column,
// group-by-src. The pending side already includes the row
// we just appended (above).
validate_edge_cardinality_with_pending(
self,
&ds,
staging,
&table_key,
edge_type,
)
.await?;
validate_edge_cardinality_with_pending(self, &ds, staging, &table_key, edge_type)
.await?;
self.invalidate_graph_index().await;
@ -1379,14 +1363,8 @@ async fn validate_edge_cardinality_with_pending(
if edge_type.cardinality.is_default() {
return Ok(());
}
let counts = super::staging::count_src_per_edge(
db,
committed_ds,
table_key,
staging,
None,
)
.await?;
let counts =
super::staging::count_src_per_edge(db, committed_ds, table_key, staging, None).await?;
super::staging::enforce_cardinality_bounds(edge_type, &counts)
}

View file

@ -345,10 +345,7 @@ fn evaluate_projection(
IRExpr::PropAccess { variable, property } => {
let col_name = format!("{}.{}", variable, property);
let col = wide_batch.column_by_name(&col_name).ok_or_else(|| {
OmniError::manifest(format!(
"column '{}' not found in wide batch",
col_name
))
OmniError::manifest(format!("column '{}' not found in wide batch", col_name))
})?;
Ok((col_name, col.clone()))
}
@ -516,12 +513,10 @@ fn aggregate_return(
}
let num_groups = group_indices.len();
let mut result_columns: Vec<(usize, String, ArrayRef)> =
Vec::with_capacity(projections.len());
let mut result_columns: Vec<(usize, String, ArrayRef)> = Vec::with_capacity(projections.len());
for gk in &group_keys {
let first_row_indices: Vec<u32> =
group_indices.iter().map(|rows| rows[0] as u32).collect();
let first_row_indices: Vec<u32> = group_indices.iter().map(|rows| rows[0] as u32).collect();
let take_idx = UInt32Array::from(first_row_indices);
let col = arrow_select::take::take(gk.column.as_ref(), &take_idx, None)
.map_err(|e| OmniError::Lance(e.to_string()))?;
@ -584,11 +579,19 @@ fn compute_aggregate(
}
}
fn compute_sum(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
fn compute_sum(
arg: &ArrayRef,
group_indices: &[Vec<usize>],
num_groups: usize,
) -> Result<ArrayRef> {
macro_rules! sum_numeric {
($arr_type:ty, $arg:expr, $dt:expr) => {{
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
OmniError::manifest(format!("sum: expected {:?}, got {:?}", $dt, $arg.data_type()))
OmniError::manifest(format!(
"sum: expected {:?}, got {:?}",
$dt,
$arg.data_type()
))
})?;
let mut builder = Float64Builder::with_capacity(num_groups);
for group in group_indices {
@ -613,24 +616,42 @@ fn compute_sum(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize)
dt @ DataType::UInt64 => sum_numeric!(UInt64Array, arg, dt),
dt @ DataType::Float32 => sum_numeric!(Float32Array, arg, dt),
dt @ DataType::Float64 => sum_numeric!(Float64Array, arg, dt),
dt => Err(OmniError::manifest(format!("sum: unsupported type {:?}", dt))),
dt => Err(OmniError::manifest(format!(
"sum: unsupported type {:?}",
dt
))),
}
}
fn compute_avg(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize) -> Result<ArrayRef> {
fn compute_avg(
arg: &ArrayRef,
group_indices: &[Vec<usize>],
num_groups: usize,
) -> Result<ArrayRef> {
macro_rules! avg_typed {
($arr_type:ty, $arg:expr) => {{
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
OmniError::manifest(format!("avg: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
OmniError::manifest(format!(
"avg: expected {:?}, got {:?}",
stringify!($arr_type),
$arg.data_type()
))
})?;
let mut builder = Float64Builder::with_capacity(num_groups);
for group in group_indices {
let mut sum = 0.0f64;
let mut count = 0usize;
for &i in group {
if !arr.is_null(i) { sum += arr.value(i) as f64; count += 1; }
if !arr.is_null(i) {
sum += arr.value(i) as f64;
count += 1;
}
}
if count > 0 {
builder.append_value(sum / count as f64);
} else {
builder.append_null();
}
if count > 0 { builder.append_value(sum / count as f64); } else { builder.append_null(); }
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}};
@ -642,15 +663,27 @@ fn compute_avg(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize)
DataType::UInt64 => avg_typed!(UInt64Array, arg),
DataType::Float32 => avg_typed!(Float32Array, arg),
DataType::Float64 => avg_typed!(Float64Array, arg),
dt => Err(OmniError::manifest(format!("avg: unsupported type {:?}", dt))),
dt => Err(OmniError::manifest(format!(
"avg: unsupported type {:?}",
dt
))),
}
}
fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usize, is_min: bool) -> Result<ArrayRef> {
fn compute_min_max(
arg: &ArrayRef,
group_indices: &[Vec<usize>],
num_groups: usize,
is_min: bool,
) -> Result<ArrayRef> {
macro_rules! minmax_typed {
($arr_type:ty, $builder_type:ty, $arg:expr, $is_min:expr) => {{
let arr = $arg.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
OmniError::manifest(format!("min/max: expected {:?}, got {:?}", stringify!($arr_type), $arg.data_type()))
OmniError::manifest(format!(
"min/max: expected {:?}, got {:?}",
stringify!($arr_type),
$arg.data_type()
))
})?;
let mut builder = <$builder_type>::with_capacity(num_groups);
for group in group_indices {
@ -660,11 +693,20 @@ fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usi
let v = arr.value(i);
result = Some(match result {
None => v,
Some(cur) => if $is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
Some(cur) => {
if $is_min {
if v < cur { v } else { cur }
} else {
if v > cur { v } else { cur }
}
}
});
}
}
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
match result {
Some(v) => builder.append_value(v),
None => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}};
@ -688,15 +730,27 @@ fn compute_min_max(arg: &ArrayRef, group_indices: &[Vec<usize>], num_groups: usi
let v = arr.value(i);
result = Some(match result {
None => v,
Some(cur) => if is_min { if v < cur { v } else { cur } } else { if v > cur { v } else { cur } },
Some(cur) => {
if is_min {
if v < cur { v } else { cur }
} else {
if v > cur { v } else { cur }
}
}
});
}
}
match result { Some(v) => builder.append_value(v), None => builder.append_null() }
match result {
Some(v) => builder.append_value(v),
None => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
dt => Err(OmniError::manifest(format!("min/max: unsupported type {:?}", dt))),
dt => Err(OmniError::manifest(format!(
"min/max: unsupported type {:?}",
dt
))),
}
}
@ -715,7 +769,8 @@ fn build_empty_aggregate_result(projections: &[IRProjection]) -> Result<RecordBa
}
_ => {
fields.push(Field::new(name, DataType::Float64, true));
columns.push(Arc::new(Float64Array::from(vec![None as Option<f64>])) as ArrayRef);
columns
.push(Arc::new(Float64Array::from(vec![None as Option<f64>])) as ArrayRef);
}
},
_ => {

View file

@ -75,14 +75,7 @@ impl Omnigraph {
None
};
execute_query(
&ir,
params,
&snapshot,
graph_index.as_deref(),
&catalog,
)
.await
execute_query(&ir, params, &snapshot, graph_index.as_deref(), &catalog).await
}
}
@ -360,11 +353,23 @@ pub async fn execute_query(
}
let mut wide: Option<RecordBatch> = None;
execute_pipeline(&ir.pipeline, params, snapshot, graph_index, catalog, &mut wide, &search_mode).await?;
execute_pipeline(
&ir.pipeline,
params,
snapshot,
graph_index,
catalog,
&mut wide,
&search_mode,
)
.await?;
let wide_batch = wide.unwrap_or_else(|| RecordBatch::new_empty(Arc::new(Schema::empty())));
// Project return expressions
let has_aggregates = ir.return_exprs.iter().any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
let has_aggregates = ir
.return_exprs
.iter()
.any(|p| matches!(&p.expr, IRExpr::Aggregate { .. }));
let mut result_batch = project_return(&wide_batch, &ir.return_exprs, params)?;
// Apply ordering (skip if search mode already ordered the results)
@ -516,9 +521,9 @@ async fn execute_rrf_query(
}
fn extract_id_column_by_name(batch: &RecordBatch, col_name: &str) -> Result<Vec<String>> {
let col = batch
.column_by_name(col_name)
.ok_or_else(|| OmniError::manifest(format!("batch missing '{}' column for RRF", col_name)))?;
let col = batch.column_by_name(col_name).ok_or_else(|| {
OmniError::manifest(format!("batch missing '{}' column for RRF", col_name))
})?;
let ids = col
.as_any()
.downcast_ref::<StringArray>()
@ -653,8 +658,19 @@ fn execute_pipeline<'a>(
})?;
if let Some(batch) = wide.as_mut() {
execute_expand(
batch, gi, snapshot, catalog, src_var, dst_var, edge_type, *direction,
dst_type, *min_hops, *max_hops, dst_filters, params,
batch,
gi,
snapshot,
catalog,
src_var,
dst_var,
edge_type,
*direction,
dst_type,
*min_hops,
*max_hops,
dst_filters,
params,
)
.await?;
}
@ -691,7 +707,9 @@ async fn execute_expand(
let src_id_col_name = format!("{}.id", src_var);
let src_ids = wide
.column_by_name(&src_id_col_name)
.ok_or_else(|| OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name)))?
.ok_or_else(|| {
OmniError::manifest(format!("wide batch missing '{}' column", src_id_col_name))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| OmniError::manifest(format!("'{}' column is not Utf8", src_id_col_name)))?
@ -1421,22 +1439,39 @@ fn literal_to_expr(lit: &Literal) -> Option<datafusion::prelude::Expr> {
}
fn prefix_batch(batch: &RecordBatch, variable: &str) -> Result<RecordBatch> {
let fields: Vec<Field> = batch.schema().fields().iter().map(|f| {
Field::new(format!("{}.{}", variable, f.name()), f.data_type().clone(), f.is_nullable())
}).collect();
let fields: Vec<Field> = batch
.schema()
.fields()
.iter()
.map(|f| {
Field::new(
format!("{}.{}", variable, f.name()),
f.data_type().clone(),
f.is_nullable(),
)
})
.collect();
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, batch.columns().to_vec()).map_err(|e| OmniError::Lance(e.to_string()))
RecordBatch::try_new(schema, batch.columns().to_vec())
.map_err(|e| OmniError::Lance(e.to_string()))
}
fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
let n = left.num_rows();
let m = right.num_rows();
if n == 0 || m == 0 {
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
let mut fields: Vec<Field> = left
.schema()
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect();
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
return Ok(RecordBatch::new_empty(Arc::new(Schema::new(fields))));
}
let left_indices: Vec<u32> = (0..n as u32).flat_map(|i| std::iter::repeat(i).take(m)).collect();
let left_indices: Vec<u32> = (0..n as u32)
.flat_map(|i| std::iter::repeat(i).take(m))
.collect();
let right_indices: Vec<u32> = (0..n).flat_map(|_| 0..m as u32).collect();
let left_expanded = take_batch(left, &UInt32Array::from(left_indices))?;
let right_expanded = take_batch(right, &UInt32Array::from(right_indices))?;
@ -1444,23 +1479,39 @@ fn cross_join_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordB
}
fn hconcat_batches(left: &RecordBatch, right: &RecordBatch) -> Result<RecordBatch> {
let mut fields: Vec<Field> = left.schema().fields().iter().map(|f| f.as_ref().clone()).collect();
let mut fields: Vec<Field> = left
.schema()
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect();
if cfg!(debug_assertions) {
let left_schema = left.schema();
let left_names: HashSet<&str> = left_schema.fields().iter().map(|f| f.name().as_str()).collect();
let left_names: HashSet<&str> = left_schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
let right_schema = right.schema();
for f in right_schema.fields() {
debug_assert!(!left_names.contains(f.name().as_str()), "hconcat_batches: duplicate column '{}'", f.name());
debug_assert!(
!left_names.contains(f.name().as_str()),
"hconcat_batches: duplicate column '{}'",
f.name()
);
}
}
fields.extend(right.schema().fields().iter().map(|f| f.as_ref().clone()));
let mut columns: Vec<ArrayRef> = left.columns().to_vec();
columns.extend(right.columns().to_vec());
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).map_err(|e| OmniError::Lance(e.to_string()))
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
.map_err(|e| OmniError::Lance(e.to_string()))
}
fn take_batch(batch: &RecordBatch, indices: &UInt32Array) -> Result<RecordBatch> {
let columns: Vec<ArrayRef> = batch.columns().iter()
let columns: Vec<ArrayRef> = batch
.columns()
.iter()
.map(|col| arrow_select::take::take(col.as_ref(), indices, None))
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| OmniError::Lance(e.to_string()))?;

View file

@ -212,12 +212,7 @@ impl Omnigraph {
.await
}
pub async fn load_file(
&self,
branch: &str,
path: &str,
mode: LoadMode,
) -> Result<LoadResult> {
pub async fn load_file(&self, branch: &str, path: &str, mode: LoadMode) -> Result<LoadResult> {
self.load_file_as(branch, path, mode, None).await
}
@ -457,13 +452,7 @@ async fn load_jsonl_reader<R: BufRead>(
for (edge_name, rows) in &edge_rows {
let edge_type = &catalog.edge_types[edge_name];
let from_ids = if use_staging {
collect_node_ids_with_pending(
db,
branch,
&edge_type.from_type,
&staging,
)
.await?
collect_node_ids_with_pending(db, branch, &edge_type.from_type, &staging).await?
} else {
collect_node_ids(
db,
@ -476,13 +465,7 @@ async fn load_jsonl_reader<R: BufRead>(
.await?
};
let to_ids = if use_staging {
collect_node_ids_with_pending(
db,
branch,
&edge_type.to_type,
&staging,
)
.await?
collect_node_ids_with_pending(db, branch, &edge_type.to_type, &staging).await?
} else {
collect_node_ids(
db,
@ -581,12 +564,7 @@ async fn load_jsonl_reader<R: BufRead>(
let table_key = format!("edge:{}", edge_name);
if use_staging {
validate_edge_cardinality_with_pending_loader(
db,
branch,
edge_type,
&table_key,
&staging,
mode,
db, branch, edge_type, &table_key, &staging, mode,
)
.await?;
} else if let Some(update) = overwrite_updates.iter().find(|u| u.table_key == table_key) {
@ -1699,8 +1677,7 @@ async fn validate_edge_cardinality_with_pending_loader(
LoadMode::Append | LoadMode::Overwrite => None,
};
let counts =
crate::exec::staging::count_src_per_edge(db, &ds, table_key, staging, dedupe_key)
.await?;
crate::exec::staging::count_src_per_edge(db, &ds, table_key, staging, dedupe_key).await?;
crate::exec::staging::enforce_cardinality_bounds(edge_type, &counts)
}

View file

@ -7,7 +7,8 @@ use async_trait::async_trait;
use futures::TryStreamExt;
use object_store::aws::AmazonS3Builder;
use object_store::path::Path as ObjectPath;
use object_store::{DynObjectStore, ObjectStore, PutPayload};
use object_store::{DynObjectStore, ObjectStore, PutMode, PutPayload};
use tokio::io::AsyncWriteExt;
use url::Url;
use crate::error::{OmniError, Result};
@ -19,6 +20,13 @@ const S3_SCHEME_PREFIX: &str = "s3://";
pub trait StorageAdapter: Debug + Send + Sync {
async fn read_text(&self, uri: &str) -> Result<String>;
async fn write_text(&self, uri: &str, contents: &str) -> Result<()>;
/// Write a text object only if no object exists at `uri`.
///
/// Returns `Ok(true)` when this call created the object, `Ok(false)`
/// when the object already existed, and propagates every other storage
/// error. Callers use this to establish ownership before running
/// best-effort cleanup on partial failure.
async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool>;
async fn exists(&self, uri: &str) -> Result<bool>;
/// Move a file from `from_uri` to `to_uri`, replacing any existing file at
/// `to_uri`. Atomic on local POSIX; on S3 implemented as copy + delete
@ -77,6 +85,30 @@ impl StorageAdapter for LocalStorageAdapter {
Ok(())
}
/// Create the local file addressed by `uri` with `contents`, but only if
/// nothing exists at that path yet.
///
/// Returns `Ok(true)` when this call created the file and `Ok(false)` when
/// opening failed with `AlreadyExists`. Missing parent directories are
/// created first.
///
/// # Errors
///
/// Propagates URI resolution errors and every I/O error other than
/// `AlreadyExists`. If writing or flushing fails after the file was
/// created, the partial file is removed on a best-effort basis before the
/// error is returned.
async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool> {
    let path = local_path_from_uri(uri)?;
    if let Some(parent) = path.parent() {
        if !parent.as_os_str().is_empty() {
            tokio::fs::create_dir_all(parent).await?;
        }
    }
    // `create_new` folds the existence check and the creation into a single
    // open call, so two racing callers cannot both observe "absent".
    let mut file = match tokio::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&path)
        .await
    {
        Ok(file) => file,
        Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => return Ok(false),
        Err(err) => return Err(err.into()),
    };
    // Tokio completes file writes in the background; `flush` waits for the
    // pending write so that its error (if any) is observed here and the
    // contents are in place before we report success.
    let outcome = match file.write_all(contents.as_bytes()).await {
        Ok(()) => file.flush().await,
        Err(err) => Err(err),
    };
    if let Err(err) = outcome {
        // Best effort: do not leave a half-written claim file behind.
        let _ = tokio::fs::remove_file(&path).await;
        return Err(err.into());
    }
    Ok(true)
}
/// Report whether something exists at the local path that `uri` resolves to.
///
/// Fails only when `uri` cannot be turned into a local path.
async fn exists(&self, uri: &str) -> Result<bool> {
    let path = local_path_from_uri(uri)?;
    Ok(path.exists())
}
@ -146,6 +178,24 @@ impl StorageAdapter for S3StorageAdapter {
Ok(())
}
async fn write_text_if_absent(&self, uri: &str, contents: &str) -> Result<bool> {
let location = self.object_path(uri)?;
match self
.store
.put_opts(
&location,
PutPayload::from(contents.as_bytes().to_vec()),
PutMode::Create.into(),
)
.await
{
Ok(_) => Ok(true),
Err(object_store::Error::AlreadyExists { .. })
| Err(object_store::Error::Precondition { .. }) => Ok(false),
Err(err) => Err(storage_backend_error("write_if_absent", uri, err)),
}
}
async fn exists(&self, uri: &str) -> Result<bool> {
let location = self.object_path(uri)?;
match self.store.head(&location).await {
@ -447,4 +497,16 @@ mod tests {
assert_eq!(location.bucket, "bucket");
assert_eq!(location.key, "graph/_schema.pg");
}
/// The first conditional write creates the file; a second one reports
/// "already present" and leaves the original contents untouched.
#[tokio::test]
async fn local_write_text_if_absent_creates_once_without_overwrite() {
    let dir = tempfile::tempdir().unwrap();
    let claim_path = dir.path().join("claim.txt");
    let uri = claim_path.to_str().unwrap();
    let storage = LocalStorageAdapter;

    let created_first = storage.write_text_if_absent(uri, "first").await.unwrap();
    assert!(created_first);

    let created_second = storage.write_text_if_absent(uri, "second").await.unwrap();
    assert!(!created_second);

    let stored = storage.read_text(uri).await.unwrap();
    assert_eq!(stored, "first");
}
}

View file

@ -94,7 +94,9 @@ impl SnapshotHandle {
/// Construct from a Lance dataset. `pub(crate)` — only
/// `TableStore` should produce these.
///
/// The dataset is moved into a fresh `Arc` held by the handle.
pub(crate) fn new(ds: Dataset) -> Self {
    Self {
        inner: Arc::new(ds),
    }
}
/// Borrow the underlying Lance dataset. `pub(crate)` so only the
@ -242,16 +244,10 @@ pub trait TableStorage: sealed::Sealed + Send + Sync + Debug {
async fn scan_batches(&self, snapshot: &SnapshotHandle) -> Result<Vec<RecordBatch>>;
async fn scan_batches_for_rewrite(
&self,
snapshot: &SnapshotHandle,
) -> Result<Vec<RecordBatch>>;
async fn scan_batches_for_rewrite(&self, snapshot: &SnapshotHandle)
-> Result<Vec<RecordBatch>>;
async fn count_rows(
&self,
snapshot: &SnapshotHandle,
filter: Option<String>,
) -> Result<usize>;
async fn count_rows(&self, snapshot: &SnapshotHandle, filter: Option<String>) -> Result<usize>;
async fn count_rows_with_staged(
&self,
@ -284,11 +280,8 @@ pub trait TableStorage: sealed::Sealed + Send + Sync + Debug {
filter: &str,
) -> Result<Option<u64>>;
async fn table_state(
&self,
dataset_uri: &str,
snapshot: &SnapshotHandle,
) -> Result<TableState>;
async fn table_state(&self, dataset_uri: &str, snapshot: &SnapshotHandle)
-> Result<TableState>;
// ── Staged writes (no HEAD advance) ────────────────────────────────
@ -565,11 +558,7 @@ impl TableStorage for TableStore {
TableStore::scan_batches_for_rewrite(self, snapshot.dataset()).await
}
async fn count_rows(
&self,
snapshot: &SnapshotHandle,
filter: Option<String>,
) -> Result<usize> {
async fn count_rows(&self, snapshot: &SnapshotHandle, filter: Option<String>) -> Result<usize> {
TableStore::count_rows(self, snapshot.dataset(), filter).await
}
@ -591,14 +580,8 @@ impl TableStorage for TableStore {
filter: Option<&str>,
) -> Result<Vec<RecordBatch>> {
let staged_writes = staged_handles_as_writes(staged);
TableStore::scan_with_staged(
self,
snapshot.dataset(),
&staged_writes,
projection,
filter,
)
.await
TableStore::scan_with_staged(self, snapshot.dataset(), &staged_writes, projection, filter)
.await
}
async fn scan_with_pending(
@ -658,18 +641,10 @@ impl TableStorage for TableStore {
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<StagedHandle> {
let ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
TableStore::stage_merge_insert(
self,
ds,
batch,
key_columns,
when_matched,
when_not_matched,
)
.await
.map(StagedHandle::new)
let ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
TableStore::stage_merge_insert(self, ds, batch, key_columns, when_matched, when_not_matched)
.await
.map(StagedHandle::new)
}
async fn commit_staged(
@ -720,8 +695,7 @@ impl TableStorage for TableStore {
snapshot: SnapshotHandle,
batch: RecordBatch,
) -> Result<(SnapshotHandle, TableState)> {
let mut ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let mut ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
let state = TableStore::append_batch(self, dataset_uri, &mut ds, batch).await?;
Ok((SnapshotHandle::new(ds), state))
}
@ -735,8 +709,7 @@ impl TableStorage for TableStore {
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<TableState> {
let ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
TableStore::merge_insert_batches(
self,
dataset_uri,
@ -755,8 +728,7 @@ impl TableStorage for TableStore {
snapshot: SnapshotHandle,
batch: RecordBatch,
) -> Result<(SnapshotHandle, TableState)> {
let mut ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let mut ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
let state = TableStore::overwrite_batch(self, dataset_uri, &mut ds, batch).await?;
Ok((SnapshotHandle::new(ds), state))
}
@ -767,8 +739,7 @@ impl TableStorage for TableStore {
snapshot: SnapshotHandle,
filter: &str,
) -> Result<(SnapshotHandle, DeleteState)> {
let mut ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let mut ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
let state = TableStore::delete_where(self, dataset_uri, &mut ds, filter).await?;
Ok((SnapshotHandle::new(ds), state))
}
@ -790,8 +761,7 @@ impl TableStorage for TableStore {
snapshot: SnapshotHandle,
columns: &[&str],
) -> Result<SnapshotHandle> {
let mut ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let mut ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
TableStore::create_btree_index(self, &mut ds, columns).await?;
Ok(SnapshotHandle::new(ds))
}
@ -801,8 +771,7 @@ impl TableStorage for TableStore {
snapshot: SnapshotHandle,
column: &str,
) -> Result<SnapshotHandle> {
let mut ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let mut ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
TableStore::create_inverted_index(self, &mut ds, column).await?;
Ok(SnapshotHandle::new(ds))
}
@ -812,8 +781,7 @@ impl TableStorage for TableStore {
snapshot: SnapshotHandle,
column: &str,
) -> Result<SnapshotHandle> {
let mut ds = Arc::try_unwrap(snapshot.into_arc())
.unwrap_or_else(|arc| (*arc).clone());
let mut ds = Arc::try_unwrap(snapshot.into_arc()).unwrap_or_else(|arc| (*arc).clone());
TableStore::create_vector_index(self, &mut ds, column).await?;
Ok(SnapshotHandle::new(ds))
}
@ -837,6 +805,13 @@ impl TableStorage for TableStore {
// Note: existing TableStore::scan_stream is an associated fn that
// takes &Dataset, so we delegate via the dataset reference held by
// the snapshot.
TableStore::scan_stream(snapshot.dataset(), projection, filter, order_by, with_row_id).await
TableStore::scan_stream(
snapshot.dataset(),
projection,
filter,
order_by,
with_row_id,
)
.await
}
}

View file

@ -1793,25 +1793,24 @@ mod tests {
/// A batch that repeats the key `a` must be rejected, and the error text
/// must name the offending key and reference MR-957.
#[test]
fn check_batch_unique_by_keys_errors_on_duplicate_id() {
    let batch = batch_with_ids(&["a", "b", "a"]);
    let err = check_batch_unique_by_keys(&batch, &["id".to_string()], "test").unwrap_err();
    let msg = err.to_string();
    assert!(
        msg.contains("duplicate source row for key 'a'"),
        "unexpected error: {msg}"
    );
    assert!(
        msg.contains("MR-957"),
        "error should reference MR-957: {msg}"
    );
}
/// Passing more than one key column must fail with the
/// "single-column keys only" message.
#[test]
fn check_batch_unique_by_keys_rejects_multi_column_keys() {
    let batch = batch_with_ids(&["a"]);
    let err =
        check_batch_unique_by_keys(&batch, &["id".to_string(), "other".to_string()], "test")
            .unwrap_err();
    assert!(err.to_string().contains("single-column keys only"));
}
}

View file

@ -1910,9 +1910,14 @@ query docs_with_tag($tag: String) {
return { $d.slug }
}
"#;
let result = query_main(&mut db, queries, "docs_with_tag", &params(&[("$tag", "red")]))
.await
.unwrap();
let result = query_main(
&mut db,
queries,
"docs_with_tag",
&params(&[("$tag", "red")]),
)
.await
.unwrap();
let batch = result.concat_batches().unwrap();
let slugs = batch

View file

@ -95,11 +95,11 @@ const FORBIDDEN_PATTERNS: &[&str] = &[
/// provide the staged primitives or to maintain the system tables
/// (commit graph, manifest).
// One entry per exempt file name; the trailing comment on each entry
// records why that file is exempt.
const ALLOW_LIST_FILES: &[&str] = &[
    "table_store.rs",       // The storage layer itself.
    "storage_layer.rs",     // The trait module.
    "commit_graph.rs",      // Maintains `_graph_commits.lance` system table.
    "graph_coordinator.rs", // Drives the manifest publisher / branch coordinator.
    "recovery_audit.rs", // Maintains `_graph_commit_recoveries.lance` (recovery audit trail).
];
/// Directories exempt from the guard. Files under these paths may use
@ -168,10 +168,7 @@ fn engine_code_does_not_call_forbidden_lance_apis() {
// comments are documentation, not code use. The trait
// surface (sealed + trait-only) is the actual enforcement;
// this test only catches code use.
if trimmed.starts_with("//")
|| trimmed.starts_with("/*")
|| trimmed.starts_with("*")
{
if trimmed.starts_with("//") || trimmed.starts_with("/*") || trimmed.starts_with("*") {
continue;
}
// Allow lines marked with the sentinel on the SAME line or

View file

@ -23,8 +23,8 @@ use std::path::Path;
use std::sync::Arc;
use omnigraph::db::{Omnigraph, ReadTarget, SchemaApplyOptions};
use omnigraph::loader::LoadMode;
use omnigraph::error::OmniError;
use omnigraph::loader::LoadMode;
use omnigraph_policy::{PolicyChecker, PolicyEngine};
use helpers::*;
@ -58,7 +58,10 @@ rules:
"#;
fn additive_schema() -> String {
helpers::TEST_SCHEMA.replace(" age: I32?\n}", " age: I32?\n nickname: String?\n}")
helpers::TEST_SCHEMA.replace(
" age: I32?\n}",
" age: I32?\n nickname: String?\n}",
)
}
fn install_policy(db: Omnigraph, dir_path: &Path) -> (Omnigraph, Arc<PolicyEngine>) {
@ -238,7 +241,12 @@ async fn load_as_denies_when_policy_rejects_actor() {
let (db, _engine) = init_with_policy(&dir).await;
let result = db
.load_as("main", ONE_PERSON_JSONL, LoadMode::Merge, Some("act-denied"))
.load_as(
"main",
ONE_PERSON_JSONL,
LoadMode::Merge,
Some("act-denied"),
)
.await;
assert_denied(result, "load_as");
}

View file

@ -127,10 +127,7 @@ async fn multi_statement_mutation_is_atomic_with_read_your_writes() {
"main",
MUTATION_QUERIES,
"insert_person_and_friend",
&mixed_params(
&[("$name", "Eve"), ("$friend", "Alice")],
&[("$age", 22)],
),
&mixed_params(&[("$name", "Eve"), ("$friend", "Alice")], &[("$age", 22)]),
)
.await
.unwrap();
@ -187,10 +184,7 @@ async fn partial_failure_leaves_target_queryable_and_unblocks_next_mutation() {
"main",
MUTATION_QUERIES,
"insert_person_and_friend",
&mixed_params(
&[("$name", "Eve"), ("$friend", "Missing")],
&[("$age", 22)],
),
&mixed_params(&[("$name", "Eve"), ("$friend", "Missing")], &[("$age", 22)]),
)
.await
.expect_err("op-2 must fail");
@ -543,10 +537,7 @@ async fn mutation_rejects_mixed_insert_and_delete_at_parse_time() {
"main",
STAGED_QUERIES,
"mixed_insert_and_delete",
&mixed_params(
&[("$name", "Eve"), ("$victim", "Alice")],
&[("$age", 22)],
),
&mixed_params(&[("$name", "Eve"), ("$victim", "Alice")], &[("$age", 22)]),
)
.await
.expect_err("D₂ must reject mixed insert+delete");
@ -559,7 +550,9 @@ async fn mutation_rejects_mixed_insert_and_delete_at_parse_time() {
manifest_err.message,
);
assert!(
manifest_err.message.contains("split into separate mutations"),
manifest_err
.message
.contains("split into separate mutations"),
"error message should direct user to split: {}",
manifest_err.message,
);
@ -668,11 +661,7 @@ async fn multiple_appends_to_same_edge_coalesce_to_one_append() {
"main",
STAGED_QUERIES,
"insert_two_friends",
&params(&[
("$from", "Alice"),
("$a", "Bob"),
("$b", "Eve"),
]),
&params(&[("$from", "Alice"), ("$a", "Bob"), ("$b", "Eve")]),
)
.await
.unwrap();
@ -782,8 +771,14 @@ async fn load_with_bad_edge_reference_unblocks_next_load() {
// No write made it to disk: counts unchanged.
let mid_persons = count_rows(&db, "node:Person").await;
let mid_edges = count_rows(&db, "edge:Knows").await;
assert_eq!(mid_persons, pre_persons, "failed load must not advance Person count");
assert_eq!(mid_edges, pre_edges, "failed load must not advance Knows count");
assert_eq!(
mid_persons, pre_persons,
"failed load must not advance Person count"
);
assert_eq!(
mid_edges, pre_edges,
"failed load must not advance Knows count"
);
// Second load against the same tables — succeeds (no HEAD drift).
let good = r#"{"type": "Person", "data": {"name": "Pat", "age": 55}}"#;
@ -824,7 +819,9 @@ edge WorksAt: Person -> Company @card(0..1)
{"type": "Company", "data": {"name": "Acme"}}
{"type": "Company", "data": {"name": "Bigco"}}
"#;
load_jsonl(&mut db, seed, LoadMode::Overwrite).await.unwrap();
load_jsonl(&mut db, seed, LoadMode::Overwrite)
.await
.unwrap();
let pre_works = count_rows(&db, "edge:WorksAt").await;
@ -1014,7 +1011,10 @@ query cascade_then_explicit($name: String, $other: String) {
// — Bob→Diana would survive. The exact-count check makes both ops
// independently observable.
let pre_knows = count_rows(&db, "edge:Knows").await;
assert_eq!(pre_knows, 3, "fixture invariant: TEST_DATA seeds 3 Knows edges");
assert_eq!(
pre_knows, 3,
"fixture invariant: TEST_DATA seeds 3 Knows edges"
);
db.mutate(
"main",
@ -1066,7 +1066,9 @@ query add_friend($from: String, $to: String) {
let seed = r#"{"type": "Person", "data": {"name": "Alice"}}
{"type": "Person", "data": {"name": "Bob"}}
"#;
load_jsonl(&mut db, seed, LoadMode::Overwrite).await.unwrap();
load_jsonl(&mut db, seed, LoadMode::Overwrite)
.await
.unwrap();
// Single insert: count=1 < min=2 → reject with clear message.
let err = db
@ -1082,8 +1084,7 @@ query add_friend($from: String, $to: String) {
panic!("expected Manifest error, got {err:?}");
};
assert!(
manifest_err.message.contains("@card violation")
&& manifest_err.message.contains("min 2"),
manifest_err.message.contains("@card violation") && manifest_err.message.contains("min 2"),
"unexpected error: {}",
manifest_err.message,
);
@ -1121,7 +1122,9 @@ edge WorksAt: Person -> Company @card(0..1)
{"type": "Company", "data": {"name": "Bigco"}}
{"edge": "WorksAt", "from": "Alice", "to": "Acme", "data": {"id": "w1"}}
"#;
load_jsonl(&mut db, seed, LoadMode::Overwrite).await.unwrap();
load_jsonl(&mut db, seed, LoadMode::Overwrite)
.await
.unwrap();
// Merge-update the same edge id w1 to point at Bigco. Counted naively
// as union, Alice has 2 WorksAt (committed Acme + pending Bigco) which
@ -1167,7 +1170,9 @@ edge WorksAt: Person -> Company @card(0..1)
{"type": "Company", "data": {"name": "Acme"}}
{"type": "Company", "data": {"name": "Bigco"}}
"#;
load_jsonl(&mut db, seed, LoadMode::Overwrite).await.unwrap();
load_jsonl(&mut db, seed, LoadMode::Overwrite)
.await
.unwrap();
// Merge load with the SAME edge id twice — the second row supersedes
// the first in the finalize-time dedupe. If pending-counting doesn't
@ -1364,7 +1369,11 @@ query insert_then_update_note(
)
.await
.unwrap();
assert_eq!(qr.num_rows(), 0, "letter must not be visible after early error");
assert_eq!(
qr.num_rows(),
0,
"letter must not be visible after early error"
);
}
/// MR-920 regression: two sequential `update T set {f:v} where x=y`
@ -1446,5 +1455,9 @@ async fn second_sequential_update_on_same_row_succeeds() {
}
}
}
assert_eq!(alice_age, Some(42), "Alice's age must reflect the second update");
assert_eq!(
alice_age,
Some(42),
"Alice's age must reflect the second update"
);
}

View file

@ -132,7 +132,11 @@ async fn stage_merge_insert_dedupes_superseded_committed_fragment() {
.await
.unwrap();
let ids = collect_ids(&batches);
assert_eq!(ids, vec!["alice"], "merge_insert must not surface duplicates");
assert_eq!(
ids,
vec!["alice"],
"merge_insert must not surface duplicates"
);
// Confirm the visible row is the rewritten one.
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
@ -382,12 +386,7 @@ async fn scan_with_staged_with_filter_silently_drops_staged_rows() {
// Actual: dave (staged, age=35) is dropped — only the committed matches
// come back.
let batches = store
.scan_with_staged(
&ds,
std::slice::from_ref(&staged),
None,
Some("age >= 30"),
)
.scan_with_staged(&ds, std::slice::from_ref(&staged), None, Some("age >= 30"))
.await
.unwrap();
assert_eq!(
@ -403,12 +402,7 @@ async fn scan_with_staged_with_filter_silently_drops_staged_rows() {
// Without filter, staged data IS visible — confirms the issue is
// specifically filter pushdown, not fragment scanning per se.
let unfiltered = store
.scan_with_staged(
&ds,
std::slice::from_ref(&staged),
None,
None,
)
.scan_with_staged(&ds, std::slice::from_ref(&staged), None, None)
.await
.unwrap();
assert_eq!(
@ -686,10 +680,7 @@ async fn stage_create_inverted_index_does_not_advance_head_until_commit() {
.unwrap();
let pre_version = ds.version().version;
let staged = store
.stage_create_inverted_index(&ds, "id")
.await
.unwrap();
let staged = store.stage_create_inverted_index(&ds, "id").await.unwrap();
assert_eq!(
ds.version().version,
pre_version,
@ -781,13 +772,9 @@ async fn create_vector_index_advances_head_inline_documents_residual() {
let id_arr = StringArray::from(ids);
let flat: Vec<f32> = (0..(n_rows * dim)).map(|i| i as f32).collect();
let values = arrow_array::Float32Array::from(flat);
let vec_arr =
FixedSizeListArray::new(item_field, dim as i32, Arc::new(values), None);
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id_arr), Arc::new(vec_arr)],
)
.unwrap();
let vec_arr = FixedSizeListArray::new(item_field, dim as i32, Arc::new(values), None);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(id_arr), Arc::new(vec_arr)]).unwrap();
let mut ds = TableStore::write_dataset(&uri, batch).await.unwrap();
let pre_version = ds.version().version;

View file

@ -504,9 +504,21 @@ query fof_chain($name: String) {
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let col0 = batch.column(0).as_any().downcast_ref::<StringArray>().unwrap();
let col1 = batch.column(1).as_any().downcast_ref::<StringArray>().unwrap();
let col2 = batch.column(2).as_any().downcast_ref::<StringArray>().unwrap();
let col0 = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let col1 = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let col2 = batch
.column(2)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(col0.value(0), "Alice");
assert_eq!(col1.value(0), "Bob");
assert_eq!(col2.value(0), "Diana");
@ -574,8 +586,16 @@ query at_acme_named() {
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let person = batch.column(0).as_any().downcast_ref::<StringArray>().unwrap();
let company = batch.column(1).as_any().downcast_ref::<StringArray>().unwrap();
let person = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let company = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(person.value(0), "Alice");
assert_eq!(company.value(0), "Acme");
}
@ -608,8 +628,16 @@ query at_company($company: String) {
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 1);
let person = batch.column(0).as_any().downcast_ref::<StringArray>().unwrap();
let company = batch.column(1).as_any().downcast_ref::<StringArray>().unwrap();
let person = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let company = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(person.value(0), "Bob");
assert_eq!(company.value(0), "Globex");
}
@ -633,19 +661,22 @@ query fan_out($name: String) {
"#;
// Alice knows Bob and Charlie, works at Acme.
// Each friend paired with her company → 2 rows.
let result = query_main(
&mut db,
queries,
"fan_out",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let result = query_main(&mut db, queries, "fan_out", &params(&[("$name", "Alice")]))
.await
.unwrap();
let batch = result.concat_batches().unwrap();
assert_eq!(batch.num_rows(), 2);
let friends = batch.column(0).as_any().downcast_ref::<StringArray>().unwrap();
let companies = batch.column(1).as_any().downcast_ref::<StringArray>().unwrap();
let friends = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let companies = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut pairs: Vec<(&str, &str)> = (0..batch.num_rows())
.map(|i| (friends.value(i), companies.value(i)))

View file

@ -76,7 +76,9 @@ async fn init_with(schema: &str, data: &str) -> (tempfile::TempDir, Omnigraph) {
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, schema).await.unwrap();
if !data.is_empty() {
load_jsonl(&mut db, data, LoadMode::Overwrite).await.unwrap();
load_jsonl(&mut db, data, LoadMode::Overwrite)
.await
.unwrap();
}
(dir, db)
}