Initial public Omnigraph repository

This commit is contained in:
andrew 2026-04-10 20:49:41 +03:00
commit 338289656a
110 changed files with 60747 additions and 0 deletions

4
.dockerignore Normal file
View file

@ -0,0 +1,4 @@
**
!Dockerfile
!docker/entrypoint.sh
!target/release/omnigraph-server

123
.github/workflows/ci.yml vendored Normal file
View file

@ -0,0 +1,123 @@
name: CI
on:
pull_request:
push:
branches:
- main
tags:
- "v*"
workflow_dispatch:
concurrency:
group: ci-${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
test:
name: Test Workspace
runs-on: ubuntu-latest
timeout-minutes: 45
permissions:
contents: read
env:
CARGO_TERM_COLOR: always
steps:
- name: Checkout source
uses: actions/checkout@v4
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev
- name: Install Rust stable
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
- name: Cache Rust build data
uses: Swatinem/rust-cache@v2
with:
workspaces: |
. -> target
- name: Run workspace tests
run: cargo test --workspace --locked
rustfs_integration:
name: RustFS S3 Integration
needs: test
runs-on: ubuntu-latest
timeout-minutes: 45
permissions:
contents: read
env:
AWS_ACCESS_KEY_ID: rustfsadmin
AWS_SECRET_ACCESS_KEY: rustfsadmin
AWS_REGION: us-east-1
AWS_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ENDPOINT_URL_S3: http://127.0.0.1:9000
AWS_ALLOW_HTTP: "true"
AWS_S3_FORCE_PATH_STYLE: "true"
OMNIGRAPH_S3_TEST_BUCKET: omnigraph-ci
OMNIGRAPH_S3_TEST_PREFIX: github-actions
CARGO_TERM_COLOR: always
steps:
- name: Checkout source
uses: actions/checkout@v4
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev python3-pip
- name: Install Rust stable
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
- name: Cache Rust build data
uses: Swatinem/rust-cache@v2
with:
workspaces: |
. -> target
- name: Start RustFS
run: |
docker rm -f rustfs >/dev/null 2>&1 || true
docker run -d \
--name rustfs \
-p 9000:9000 \
-p 9001:9001 \
-e RUSTFS_ACCESS_KEY="${AWS_ACCESS_KEY_ID}" \
-e RUSTFS_SECRET_KEY="${AWS_SECRET_ACCESS_KEY}" \
rustfs/rustfs:latest \
/data
- name: Install AWS CLI
run: |
python3 -m pip install --user awscli
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
- name: Create RustFS test bucket
run: |
for _ in $(seq 1 30); do
if aws --endpoint-url "${AWS_ENDPOINT_URL_S3}" s3api list-buckets >/dev/null 2>&1; then
break
fi
sleep 2
done
aws --endpoint-url "${AWS_ENDPOINT_URL_S3}" \
s3api create-bucket \
--bucket "${OMNIGRAPH_S3_TEST_BUCKET}" >/dev/null 2>&1 || true
- name: Run RustFS-backed repo tests
run: |
cargo test --locked -p omnigraph --test s3_storage -- --nocapture
cargo test --locked -p omnigraph-server --test server server_opens_s3_repo_directly_and_serves_snapshot_and_read -- --nocapture
cargo test --locked -p omnigraph-cli --test system_local local_cli_s3_end_to_end_init_load_read_flow -- --nocapture
- name: Dump RustFS logs on failure
if: failure()
run: docker logs rustfs

68
.github/workflows/release.yml vendored Normal file
View file

@ -0,0 +1,68 @@
name: Release
on:
push:
tags:
- "v*"
workflow_dispatch:
jobs:
build_release:
name: Build ${{ matrix.asset_name }}
runs-on: ${{ matrix.runner }}
permissions:
contents: write
strategy:
fail-fast: false
matrix:
include:
- runner: ubuntu-latest
asset_name: omnigraph-linux-x86_64
- runner: macos-13
asset_name: omnigraph-macos-x86_64
- runner: macos-14
asset_name: omnigraph-macos-arm64
env:
CARGO_TERM_COLOR: always
steps:
- name: Checkout source
uses: actions/checkout@v4
- name: Install Linux dependencies
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev
- name: Install macOS dependencies
if: runner.os == 'macOS'
run: brew install protobuf
- name: Install Rust stable
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
- name: Cache Rust build data
uses: Swatinem/rust-cache@v2
with:
workspaces: |
. -> target
- name: Build release binaries
run: cargo build --release --locked -p omnigraph-cli -p omnigraph-server
- name: Package release archive
run: |
mkdir -p release
install -m 0755 target/release/omnigraph release/omnigraph
install -m 0755 target/release/omnigraph-server release/omnigraph-server
tar -C release -czf "${{ matrix.asset_name }}.tar.gz" omnigraph omnigraph-server
shasum -a 256 "${{ matrix.asset_name }}.tar.gz" > "${{ matrix.asset_name }}.sha256"
- name: Publish GitHub release assets
uses: softprops/action-gh-release@v2
with:
files: |
${{ matrix.asset_name }}.tar.gz
${{ matrix.asset_name }}.sha256

18
.gitignore vendored Normal file
View file

@ -0,0 +1,18 @@
/target
**/target
*.lance
*.nano
*.nanograph
.DS_Store
.env
.env.*
*.tfstate
*.tfstate.*
.terraform/
.terraform.lock.hcl
*.tfvars
!*.tfvars.example
__pycache__/
*.pyc
demo/*.omni/
.omnigraph-rustfs-demo/

13
CODE_OF_CONDUCT.md Normal file
View file

@ -0,0 +1,13 @@
# Code Of Conduct
This project follows a simple rule: be direct, respectful, and constructive.
Expected behavior:
- focus on technical substance
- assume good intent
- give actionable feedback
- avoid harassment, personal attacks, and pile-ons
Maintainers may remove comments, issues, or pull requests that make the project
harder to collaborate in productively.

23
CONTRIBUTING.md Normal file
View file

@ -0,0 +1,23 @@
# Contributing
Small bug fixes and documentation improvements are welcome directly through pull
requests.
For larger changes, please open an issue or design discussion first so the
proposed direction is clear before implementation starts.
## Development
```bash
cargo build --workspace
cargo test --workspace
```
If you touch S3-backed flows, the CI model uses a local RustFS instance for
integration tests.
## Pull Requests
- keep changes focused
- include tests for behavior changes when practical
- update public docs when the user-facing surface changes

7646
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

79
Cargo.toml Normal file
View file

@ -0,0 +1,79 @@
[workspace]
resolver = "2"
members = [
"crates/omnigraph-compiler",
"crates/omnigraph",
"crates/omnigraph-cli",
"crates/omnigraph-server",
]
default-members = [
"crates/omnigraph",
"crates/omnigraph-cli",
"crates/omnigraph-server",
]
[workspace.dependencies]
arrow-array = "57"
arrow-ipc = "57"
arrow-schema = "57"
arrow-select = "57"
arrow-cast = { version = "57", features = ["prettyprint"] }
arrow-ord = "57"
datafusion-physical-plan = "52"
datafusion-physical-expr = "52"
datafusion-execution = "52"
datafusion-common = "52"
datafusion-expr = "52"
datafusion-functions-aggregate = "52"
lance = { version = "4.0.0", default-features = false, features = ["aws"] }
lance-datafusion = "4.0.0"
lance-file = "4.0.0"
lance-index = "4.0.0"
lance-linalg = "4.0.0"
lance-namespace = "4.0.0"
lance-namespace-impls = "4.0.0"
lance-table = "4.0.0"
ulid = "1"
futures = "0.3"
async-trait = "0.1"
pest = "2"
pest_derive = "2"
thiserror = "2"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "net", "signal", "sync"] }
clap = { version = "4", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_yaml = "0.9"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["trace"] }
color-eyre = "0.6"
tempfile = "3"
ahash = "0.8"
base64 = "0.22"
ariadne = "0.4"
regex = "1"
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
object_store = { version = "0.12.5", default-features = false, features = ["aws"] }
fail = "0.5"
time = { version = "0.3", features = ["formatting"] }
axum = { version = "0.8", features = ["json", "macros"] }
url = "2"
cedar-policy = "4.9"
sha2 = "0.10"
[profile.dev]
debug = 0
[profile.dev.package."*"]
opt-level = 2
[profile.release]
opt-level = 2
lto = "thin"
codegen-units = 16
strip = true

25
Dockerfile Normal file
View file

@ -0,0 +1,25 @@
FROM debian:bookworm-slim
RUN apt-get update \
&& apt-get install -y --no-install-recommends ca-certificates curl \
&& rm -rf /var/lib/apt/lists/*
RUN groupadd --system omnigraph \
&& useradd --system --gid omnigraph --create-home --home-dir /var/lib/omnigraph omnigraph
COPY target/release/omnigraph-server /usr/local/bin/omnigraph-server
COPY docker/entrypoint.sh /usr/local/bin/omnigraph-entrypoint
RUN chmod 0755 /usr/local/bin/omnigraph-server /usr/local/bin/omnigraph-entrypoint
ENV OMNIGRAPH_BIND=0.0.0.0:8080
WORKDIR /var/lib/omnigraph
USER omnigraph:omnigraph
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
CMD curl -fsS http://127.0.0.1:8080/healthz >/dev/null || exit 1
ENTRYPOINT ["/usr/local/bin/omnigraph-entrypoint"]

21
LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 NanoGraph Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

154
README.md Normal file
View file

@ -0,0 +1,154 @@
# Omnigraph
Omnigraph is a typed property graph database built on Lance. It combines
schema-first graph modeling, typed queries and mutations, Git-style graph
workflows, and storage that runs equally well on a local directory or an
`s3://` URI.
## Quick Install
```bash
curl -fsSL https://raw.githubusercontent.com/ModernRelay/omnigraph-public/main/scripts/install.sh | bash
```
This installs `omnigraph` and `omnigraph-server` into `~/.local/bin`. If no
tagged release exists for your platform yet, the installer falls back to a
source build.
## One-Command Local RustFS Bootstrap
```bash
curl -fsSL https://raw.githubusercontent.com/ModernRelay/omnigraph-public/main/scripts/local-rustfs-bootstrap.sh | bash
```
That bootstrap:
- starts RustFS on `127.0.0.1:9000`
- creates a bucket and S3-backed repo
- loads the checked-in context fixture
- launches `omnigraph-server` on `127.0.0.1:8080`
Docker must be installed and running first.
## Good Fit For
- Team knowledge graphs and internal context graphs
- Research, decisions, and evidence tracking
- Collaborative knowledge systems with reviewable changes
- Private self-hosted graph backends for local or on-prem AI tooling
## Why Omnigraph
- Typed schema, typed queries, and typed mutations
- Git-style graph workflows: branches, commits, merges, and transactional runs
- Local-first and S3-native storage with snapshot-pinned reads
- Graph traversal plus text, fuzzy, BM25, vector, and RRF search in one runtime
- Policy as code for server-side access control
## Quick Start
From a checkout of this repo:
```bash
cargo build --workspace
cargo run -p omnigraph-cli -- init \
--schema crates/omnigraph/tests/fixtures/test.pg \
./repo.omni
cargo run -p omnigraph-cli -- load \
--data crates/omnigraph/tests/fixtures/test.jsonl \
./repo.omni
cargo run -p omnigraph-cli -- read \
./repo.omni \
--query crates/omnigraph/tests/fixtures/test.gq \
--name friends_of \
--params '{"name":"Alice"}'
```
`init` also scaffolds an `omnigraph.yaml` next to the repo if one does not
already exist.
## Run A Server
Serve the same repo over HTTP:
```bash
cargo run -p omnigraph-server -- ./repo.omni --bind 127.0.0.1:8080
```
Then query it remotely:
```bash
cargo run -p omnigraph-cli -- read \
--target http://127.0.0.1:8080 \
--query crates/omnigraph/tests/fixtures/test.gq \
--name get_person \
--params '{"name":"Alice"}'
```
Server routes include `/healthz`, `/snapshot`, `/export`, `/read`, `/change`,
`/ingest`, `/branches`, `/runs`, and `/commits`.
To require auth, set `OMNIGRAPH_SERVER_BEARER_TOKEN` on the server and set the
matching bearer token env var in your CLI target config.
## Common Commands
Core repo flow:
```bash
omnigraph init --schema ./schema.pg ./repo.omni
omnigraph load --data ./data.jsonl --mode overwrite ./repo.omni
omnigraph snapshot ./repo.omni --branch main --json
omnigraph read ./repo.omni --query ./queries.gq --name get_person --params '{"name":"Alice"}'
omnigraph change ./repo.omni --query ./queries.gq --name insert_person --params '{"name":"Mina","age":28}'
omnigraph branch create --uri ./repo.omni --from main feature-x
omnigraph branch merge --uri ./repo.omni feature-x --into main
```
More CLI examples, config patterns, and admin commands live in
[docs/cli.md](docs/cli.md).
## Production Features
- Branches, commits, merge-base-aware graph merges, and transactional runs
- Snapshot-pinned reads across local and S3-backed repos
- Traversal plus text, fuzzy, BM25, vector, and RRF search
- Axum server for reads, changes, export, branches, commits, and runs
- Cedar-based server-side authorization
## Docs
- [Install guide](docs/install.md)
- [CLI guide](docs/cli.md)
- [Deployment guide](docs/deployment.md)
## Build And Test
```bash
cargo build --workspace
cargo check --workspace
cargo test --workspace
```
Notes:
- Rust stable toolchain, edition 2024
- CI runs `cargo test --workspace --locked`
- Full CI and some local test flows require `protobuf-compiler`
- S3 integration tests expect an S3-compatible endpoint such as RustFS
## Workspace Crates
- `crates/omnigraph-compiler`: shared schema/query parser, typechecker, catalog, and IR lowering
- `crates/omnigraph`: storage/runtime, branching, merge, change detection, and query execution
- `crates/omnigraph-cli`: CLI for init/load/ingest/read/change/branch/snapshot/export/policy operations
- `crates/omnigraph-server`: Axum HTTP server for remote reads, changes, ingest, export, branches, commits, and runs
## Contributing
Please open an issue, spec, or design discussion before sending large code
changes. Design feedback and concrete problem statements are the fastest way to
collaborate on the roadmap.

14
SECURITY.md Normal file
View file

@ -0,0 +1,14 @@
# Security Policy
Please do not report security issues through public GitHub issues.
If GitHub private vulnerability reporting is enabled for this repository, use
that channel. Otherwise, contact the maintainers directly through a private
channel before publishing details.
When reporting an issue, include:
- affected version or commit
- impact
- reproduction steps
- any proposed mitigation

View file

@ -0,0 +1,28 @@
[package]
name = "omnigraph-cli"
version = "0.4.0"
edition = "2024"
description = "CLI for the Omnigraph graph database."
license = "MIT"
[[bin]]
name = "omnigraph"
path = "src/main.rs"
[dependencies]
omnigraph = { path = "../omnigraph", version = "0.4.0" }
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
omnigraph-server = { path = "../omnigraph-server", version = "0.4.0" }
clap = { workspace = true }
color-eyre = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
tokio = { workspace = true }
reqwest = { workspace = true, features = ["blocking"] }
[dev-dependencies]
assert_cmd = "2"
predicates = "3"
serde_json = { workspace = true }
tempfile = { workspace = true }

View file

@ -0,0 +1,586 @@
use std::collections::{BTreeMap, HashSet};
use std::fs::{self, File};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use clap::Args;
use color_eyre::eyre::{Result, bail, eyre};
use omnigraph::embedding::EmbeddingClient;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value, json};
const DEFAULT_EMBED_MODEL: &str = "gemini-embedding-2-preview";
#[derive(Debug, Args, Clone)]
pub(crate) struct EmbedArgs {
/// Seed manifest path
#[arg(long, conflicts_with_all = ["input", "output", "spec"])]
pub seed: Option<PathBuf>,
/// Raw seed JSONL input path
#[arg(long, requires_all = ["output", "spec"], conflicts_with = "seed")]
pub input: Option<PathBuf>,
/// Embedded JSONL output path
#[arg(long)]
pub output: Option<PathBuf>,
/// Embedding spec JSON path
#[arg(long, requires_all = ["input", "output"], conflicts_with = "seed")]
pub spec: Option<PathBuf>,
/// Remove embedding fields instead of generating embeddings
#[arg(long, conflicts_with = "reembed_all")]
pub clean: bool,
/// Regenerate embeddings for all matching rows
#[arg(long, conflicts_with = "clean")]
pub reembed_all: bool,
/// Restrict processing to these type names
#[arg(long = "type")]
pub types: Vec<String>,
/// Reembed or clean matching rows only. Syntax: Type:field=value or field=value
#[arg(long = "select")]
pub selectors: Vec<String>,
/// Print JSON summary
#[arg(long)]
pub json: bool,
}
#[derive(Debug, Clone, Serialize)]
pub(crate) struct EmbedOutput {
pub input: String,
pub output: String,
pub rows: usize,
pub selected_rows: usize,
pub embedded_rows: usize,
pub cleaned_rows: usize,
pub mode: &'static str,
pub dimension: usize,
pub model: String,
}
#[derive(Debug, Clone)]
pub(crate) struct EmbedJob {
input: PathBuf,
output: PathBuf,
spec: EmbedSpec,
mode: EmbedMode,
type_filter: HashSet<String>,
selectors: Vec<RowSelector>,
}
#[derive(Debug, Clone, Copy)]
enum EmbedMode {
FillMissing,
ReembedAll,
Clean,
}
impl EmbedMode {
fn as_str(self, selectors_present: bool) -> &'static str {
match self {
Self::FillMissing if selectors_present => "reembed_selected",
Self::FillMissing => "fill_missing",
Self::ReembedAll => "reembed_all",
Self::Clean => "clean",
}
}
}
#[derive(Debug, Clone, Deserialize)]
struct EmbedSpec {
#[serde(default = "default_embed_model")]
model: String,
dimension: usize,
types: BTreeMap<String, EmbedTypeSpec>,
}
#[derive(Debug, Clone, Deserialize)]
struct EmbedTypeSpec {
target: String,
fields: Vec<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct SeedManifest {
#[serde(default)]
sources: Option<SeedSources>,
#[serde(default)]
artifacts: Option<SeedArtifacts>,
#[serde(default)]
embeddings: Option<EmbedSpec>,
#[serde(default)]
seed: Option<LegacySeed>,
}
#[derive(Debug, Clone, Deserialize)]
struct SeedSources {
raw_seed: PathBuf,
}
#[derive(Debug, Clone, Deserialize)]
struct SeedArtifacts {
embedded_seed: PathBuf,
}
#[derive(Debug, Clone, Deserialize)]
struct LegacySeed {
data: PathBuf,
}
#[derive(Debug, Clone)]
struct RowSelector {
type_name: Option<String>,
field: String,
expected: String,
}
#[derive(Debug)]
enum EmbedRow {
Entity {
type_name: String,
data: Map<String, Value>,
root: Map<String, Value>,
},
Passthrough(Map<String, Value>),
}
pub(crate) fn resolve_embed_job(args: &EmbedArgs) -> Result<EmbedJob> {
let mode = if args.clean {
EmbedMode::Clean
} else if args.reembed_all {
EmbedMode::ReembedAll
} else {
EmbedMode::FillMissing
};
let selectors = args
.selectors
.iter()
.map(|selector| RowSelector::parse(selector))
.collect::<Result<Vec<_>>>()?;
let type_filter = args.types.iter().cloned().collect::<HashSet<_>>();
let (input, output, spec) = if let Some(seed_path) = &args.seed {
let manifest = load_seed_manifest(seed_path)?;
(
manifest.raw_seed,
args.output.clone().unwrap_or(manifest.embedded_seed),
manifest.spec,
)
} else {
let input = args
.input
.clone()
.ok_or_else(|| eyre!("--input is required when --seed is not provided"))?;
let output = args
.output
.clone()
.ok_or_else(|| eyre!("--output is required when --seed is not provided"))?;
let spec_path = args
.spec
.clone()
.ok_or_else(|| eyre!("--spec is required when --seed is not provided"))?;
let spec = load_embed_spec(&spec_path)?;
(input, output, spec)
};
if spec.model != DEFAULT_EMBED_MODEL {
bail!(
"only {} is supported for explicit seed embeddings right now",
DEFAULT_EMBED_MODEL
);
}
Ok(EmbedJob {
input,
output,
spec,
mode,
type_filter,
selectors,
})
}
pub(crate) async fn execute_embed(args: &EmbedArgs) -> Result<EmbedOutput> {
let job = resolve_embed_job(args)?;
run_embed_job(&job).await
}
pub(crate) async fn run_embed_job(job: &EmbedJob) -> Result<EmbedOutput> {
if !job.input.exists() {
bail!("seed input does not exist: {}", job.input.display());
}
if let Some(parent) = job.output.parent() {
fs::create_dir_all(parent)?;
}
let temp_output = temp_output_path(&job.output);
let mut reader = BufReader::new(File::open(&job.input)?);
let mut writer = BufWriter::new(File::create(&temp_output)?);
let client = match job.mode {
EmbedMode::Clean => None,
_ => Some(EmbeddingClient::from_env()?),
};
let mut line = String::new();
let mut rows = 0usize;
let mut selected_rows = 0usize;
let mut embedded_rows = 0usize;
let mut cleaned_rows = 0usize;
loop {
line.clear();
let bytes = reader.read_line(&mut line)?;
if bytes == 0 {
break;
}
let raw = line.trim();
if raw.is_empty() {
continue;
}
rows += 1;
let mut row = parse_row(raw, rows)?;
let selected = row_matches_selection(&row, &job.type_filter, &job.selectors);
if selected {
selected_rows += 1;
}
if let Some(type_spec) = row
.type_name()
.and_then(|type_name| job.spec.types.get(type_name))
{
match job.mode {
EmbedMode::Clean => {
if selected
&& row
.data_mut()
.is_some_and(|data| data.remove(&type_spec.target).is_some())
{
cleaned_rows += 1;
}
}
EmbedMode::ReembedAll => {
if selected {
embed_row(
&mut row,
type_spec,
job.spec.dimension,
client.as_ref().unwrap(),
)
.await?;
embedded_rows += 1;
}
}
EmbedMode::FillMissing => {
let reembed_selected = !job.selectors.is_empty();
if selected
&& (reembed_selected
|| embedding_missing(
row.data().and_then(|data| data.get(&type_spec.target)),
))
{
embed_row(
&mut row,
type_spec,
job.spec.dimension,
client.as_ref().unwrap(),
)
.await?;
embedded_rows += 1;
}
}
}
}
writer.write_all(serde_json::to_string(&row.into_value())?.as_bytes())?;
writer.write_all(b"\n")?;
}
writer.flush()?;
fs::rename(&temp_output, &job.output)?;
Ok(EmbedOutput {
input: job.input.display().to_string(),
output: job.output.display().to_string(),
rows,
selected_rows,
embedded_rows,
cleaned_rows,
mode: job.mode.as_str(!job.selectors.is_empty()),
dimension: job.spec.dimension,
model: job.spec.model.clone(),
})
}
fn temp_output_path(output: &Path) -> PathBuf {
let mut temp = output.as_os_str().to_os_string();
temp.push(".tmp");
PathBuf::from(temp)
}
fn default_embed_model() -> String {
DEFAULT_EMBED_MODEL.to_string()
}
fn load_embed_spec(path: &Path) -> Result<EmbedSpec> {
Ok(serde_json::from_str(&fs::read_to_string(path)?)?)
}
struct ResolvedSeedManifest {
raw_seed: PathBuf,
embedded_seed: PathBuf,
spec: EmbedSpec,
}
fn load_seed_manifest(path: &Path) -> Result<ResolvedSeedManifest> {
let base_dir = path
.parent()
.map(Path::to_path_buf)
.unwrap_or(std::env::current_dir()?);
let manifest: SeedManifest = serde_yaml::from_str(&fs::read_to_string(path)?)?;
let raw_seed = manifest
.sources
.as_ref()
.map(|sources| sources.raw_seed.clone())
.or_else(|| manifest.seed.as_ref().map(|seed| seed.data.clone()))
.ok_or_else(|| eyre!("seed manifest is missing sources.raw_seed"))?;
let embedded_seed = manifest
.artifacts
.as_ref()
.map(|artifacts| artifacts.embedded_seed.clone())
.unwrap_or_else(|| PathBuf::from("./build/seed.embedded.jsonl"));
let spec = manifest
.embeddings
.ok_or_else(|| eyre!("seed manifest is missing embeddings"))?;
Ok(ResolvedSeedManifest {
raw_seed: base_dir.join(raw_seed),
embedded_seed: base_dir.join(embedded_seed),
spec,
})
}
impl RowSelector {
fn parse(value: &str) -> Result<Self> {
let (lhs, expected) = value
.split_once('=')
.ok_or_else(|| eyre!("selector must be field=value or Type:field=value"))?;
let (type_name, field) = if let Some((type_name, field)) = lhs.split_once(':') {
(
Some(type_name.trim().to_string()).filter(|value| !value.is_empty()),
field.trim().to_string(),
)
} else {
(None, lhs.trim().to_string())
};
if field.is_empty() {
bail!("selector field cannot be empty");
}
Ok(Self {
type_name,
field,
expected: expected.trim().to_string(),
})
}
fn matches(&self, type_name: &str, data: &Map<String, Value>) -> bool {
if self
.type_name
.as_deref()
.is_some_and(|expected| expected != type_name)
{
return false;
}
data.get(&self.field)
.map(render_value)
.is_some_and(|value| value == self.expected)
}
}
fn parse_row(raw: &str, line_number: usize) -> Result<EmbedRow> {
let mut root = serde_json::from_str::<Map<String, Value>>(raw)
.map_err(|err| eyre!("line {} is not valid JSON: {}", line_number, err))?;
let Some(type_name) = root.get("type").and_then(Value::as_str).map(str::to_string) else {
return Ok(EmbedRow::Passthrough(root));
};
let data = root
.remove("data")
.and_then(|value| value.as_object().cloned())
.ok_or_else(|| eyre!("line {} is missing object field 'data'", line_number))?;
Ok(EmbedRow::Entity {
type_name,
data,
root,
})
}
impl EmbedRow {
fn into_value(self) -> Value {
match self {
Self::Entity {
type_name,
data,
mut root,
} => {
root.insert("type".to_string(), Value::String(type_name));
root.insert("data".to_string(), Value::Object(data));
Value::Object(root)
}
Self::Passthrough(root) => Value::Object(root),
}
}
fn type_name(&self) -> Option<&str> {
match self {
Self::Entity { type_name, .. } => Some(type_name.as_str()),
Self::Passthrough(_) => None,
}
}
fn data(&self) -> Option<&Map<String, Value>> {
match self {
Self::Entity { data, .. } => Some(data),
Self::Passthrough(_) => None,
}
}
fn data_mut(&mut self) -> Option<&mut Map<String, Value>> {
match self {
Self::Entity { data, .. } => Some(data),
Self::Passthrough(_) => None,
}
}
}
fn row_matches_selection(
row: &EmbedRow,
type_filter: &HashSet<String>,
selectors: &[RowSelector],
) -> bool {
let Some(type_name) = row.type_name() else {
return false;
};
let Some(data) = row.data() else {
return false;
};
let matches_type = type_filter.is_empty() || type_filter.contains(type_name);
if !matches_type {
return false;
}
if selectors.is_empty() {
return true;
}
selectors
.iter()
.any(|selector| selector.matches(type_name, data))
}
fn embedding_missing(value: Option<&Value>) -> bool {
match value {
None | Some(Value::Null) => true,
Some(Value::Array(values)) => values.is_empty(),
_ => false,
}
}
fn render_value(value: &Value) -> String {
match value {
Value::Null => String::new(),
Value::String(value) => value.trim().to_string(),
Value::Bool(value) => {
if *value {
"true".to_string()
} else {
"false".to_string()
}
}
Value::Number(value) => value.to_string(),
Value::Array(values) => values
.iter()
.map(render_value)
.filter(|value| !value.is_empty())
.collect::<Vec<_>>()
.join(", "),
other => other.to_string(),
}
}
fn build_embedding_text(type_name: &str, data: &Map<String, Value>, fields: &[String]) -> String {
let mut parts = vec![format!("type: {}", type_name)];
for field in fields {
if let Some(value) = data.get(field) {
let rendered = render_value(value);
if !rendered.is_empty() {
parts.push(format!("{}: {}", field, rendered));
}
}
}
parts.join("\n")
}
async fn embed_row(
row: &mut EmbedRow,
spec: &EmbedTypeSpec,
dimension: usize,
client: &EmbeddingClient,
) -> Result<()> {
let type_name = row
.type_name()
.ok_or_else(|| eyre!("cannot embed non-entity seed rows"))?
.to_string();
let data = row
.data_mut()
.ok_or_else(|| eyre!("cannot embed non-entity seed rows"))?;
let text = build_embedding_text(&type_name, data, &spec.fields);
if text.trim().is_empty() {
return Ok(());
}
let embedding = client.embed_document_text(&text, dimension).await?;
data.insert(spec.target.clone(), json!(embedding));
Ok(())
}
#[cfg(test)]
mod tests {
use super::{RowSelector, build_embedding_text, render_value};
use serde_json::json;
#[test]
fn selector_parses_type_and_field_forms() {
let typed = RowSelector::parse("Decision:slug=dec-1").unwrap();
assert_eq!(typed.type_name.as_deref(), Some("Decision"));
assert_eq!(typed.field, "slug");
assert_eq!(typed.expected, "dec-1");
let plain = RowSelector::parse("slug=dec-2").unwrap();
assert_eq!(plain.type_name, None);
assert_eq!(plain.field, "slug");
assert_eq!(plain.expected, "dec-2");
}
#[test]
fn render_value_handles_lists_and_scalars() {
assert_eq!(render_value(&json!(["a", "b"])), "a, b");
assert_eq!(render_value(&json!(true)), "true");
assert_eq!(render_value(&json!(3)), "3");
}
#[test]
fn build_embedding_text_prefixes_type_and_fields() {
let data = json!({
"slug": "dec-1",
"intent": "Ship it"
});
let object = data.as_object().unwrap();
let text = build_embedding_text(
"Decision",
object,
&["slug".to_string(), "intent".to_string()],
);
assert!(text.contains("type: Decision"));
assert!(text.contains("slug: dec-1"));
assert!(text.contains("intent: Ship it"));
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,356 @@
use color_eyre::eyre::Result;
use omnigraph_server::ReadOutputFormat;
use omnigraph_server::api::ReadOutput;
use omnigraph_server::config::TableCellLayout;
use serde_json::{Map, Value};
pub struct ReadRenderOptions {
pub max_column_width: usize,
pub cell_layout: TableCellLayout,
}
pub fn render_read(
output: &ReadOutput,
format: ReadOutputFormat,
options: &ReadRenderOptions,
) -> Result<String> {
match format {
ReadOutputFormat::Json => Ok(serde_json::to_string_pretty(output)?),
ReadOutputFormat::Jsonl => render_jsonl(output),
ReadOutputFormat::Csv => render_csv(output),
ReadOutputFormat::Kv => Ok(render_kv(output)),
ReadOutputFormat::Table => Ok(render_table(output, options)),
}
}
fn render_jsonl(output: &ReadOutput) -> Result<String> {
let mut lines = Vec::new();
lines.push(serde_json::to_string(&serde_json::json!({
"kind": "metadata",
"query_name": output.query_name,
"target": output.target,
"row_count": output.row_count,
}))?);
for row in rows(output) {
lines.push(serde_json::to_string(&row)?);
}
Ok(lines.join("\n"))
}
fn render_csv(output: &ReadOutput) -> Result<String> {
let rows = rows(output);
let columns = columns(output, &rows);
let mut lines = Vec::new();
lines.push(
columns
.iter()
.map(|column| csv_escape(column))
.collect::<Vec<_>>()
.join(","),
);
for row in rows {
lines.push(
columns
.iter()
.map(|column| csv_escape(&stringify_value(row.get(column).unwrap_or(&Value::Null))))
.collect::<Vec<_>>()
.join(","),
);
}
Ok(lines.join("\n"))
}
fn render_kv(output: &ReadOutput) -> String {
let mut lines = vec![header_line(output)];
let rows = rows(output);
if rows.is_empty() {
lines.push("(no rows)".to_string());
return lines.join("\n");
}
for (idx, row) in rows.iter().enumerate() {
if idx > 0 {
lines.push(String::new());
}
lines.push(format!("row {}", idx + 1));
for column in columns(output, &rows) {
lines.push(format!(
"{}: {}",
column,
stringify_value(row.get(&column).unwrap_or(&Value::Null))
));
}
}
lines.join("\n")
}
fn render_table(output: &ReadOutput, options: &ReadRenderOptions) -> String {
let mut lines = vec![header_line(output)];
let rows = rows(output);
let columns = columns(output, &rows);
if columns.is_empty() {
lines.push("(no rows)".to_string());
return lines.join("\n");
}
let widths = columns
.iter()
.map(|column| {
let mut width = column.chars().count();
for row in &rows {
let rendered =
normalize_cell(&stringify_value(row.get(column).unwrap_or(&Value::Null)));
let longest = rendered
.lines()
.map(|line| line.chars().count())
.max()
.unwrap_or(0);
width = width.max(longest.min(options.max_column_width));
}
width.min(options.max_column_width.max(8))
})
.collect::<Vec<_>>();
lines.push(render_table_line(&columns, &widths));
lines.push(
widths
.iter()
.map(|width| "-".repeat(*width))
.collect::<Vec<_>>()
.join("-+-"),
);
for row in rows {
let cell_lines = columns
.iter()
.zip(widths.iter())
.map(|(column, width)| {
split_cell(
&normalize_cell(&stringify_value(row.get(column).unwrap_or(&Value::Null))),
*width,
options.cell_layout,
)
})
.collect::<Vec<_>>();
let line_count = cell_lines.iter().map(Vec::len).max().unwrap_or(1);
for line_idx in 0..line_count {
let rendered = cell_lines
.iter()
.zip(widths.iter())
.map(|(segments, width)| {
let segment = segments.get(line_idx).cloned().unwrap_or_default();
pad_to_width(&segment, *width)
})
.collect::<Vec<_>>();
lines.push(rendered.join(" | "));
}
}
lines.join("\n")
}
fn render_table_line(columns: &[String], widths: &[usize]) -> String {
columns
.iter()
.zip(widths.iter())
.map(|(column, width)| pad_to_width(column, *width))
.collect::<Vec<_>>()
.join(" | ")
}
fn header_line(output: &ReadOutput) -> String {
format!(
"{} rows from {} via {}",
output.row_count,
output
.target
.snapshot
.as_deref()
.map(|id| format!("snapshot {}", id))
.or_else(|| {
output
.target
.branch
.as_deref()
.map(|branch| format!("branch {}", branch))
})
.unwrap_or_else(|| "target".to_string()),
output.query_name
)
}
fn rows(output: &ReadOutput) -> Vec<Map<String, Value>> {
output
.rows
.as_array()
.into_iter()
.flatten()
.map(|row| match row {
Value::Object(map) => map.clone(),
other => {
let mut map = Map::new();
map.insert("value".to_string(), other.clone());
map
}
})
.collect()
}
fn columns(output: &ReadOutput, rows: &[Map<String, Value>]) -> Vec<String> {
if !output.columns.is_empty() {
return output.columns.clone();
}
let mut columns = rows
.iter()
.flat_map(|row| row.keys().cloned())
.collect::<Vec<_>>();
columns.sort();
columns.dedup();
columns
}
fn stringify_value(value: &Value) -> String {
match value {
Value::Null => "null".to_string(),
Value::String(text) => text.clone(),
Value::Bool(boolean) => boolean.to_string(),
Value::Number(number) => number.to_string(),
other => serde_json::to_string(other).unwrap_or_else(|_| "<invalid json>".to_string()),
}
}
fn normalize_cell(value: &str) -> String {
value.replace('\n', "\\n")
}
fn split_cell(value: &str, width: usize, layout: TableCellLayout) -> Vec<String> {
if value.is_empty() {
return vec![String::new()];
}
if value.chars().count() <= width {
return vec![value.to_string()];
}
match layout {
TableCellLayout::Truncate => vec![truncate(value, width)],
TableCellLayout::Wrap => wrap(value, width),
}
}
fn truncate(value: &str, width: usize) -> String {
if width <= 1 {
return value.chars().take(width).collect();
}
let keep = width.saturating_sub(1);
let mut out = value.chars().take(keep).collect::<String>();
out.push('…');
out
}
fn wrap(value: &str, width: usize) -> Vec<String> {
let chars = value.chars().collect::<Vec<_>>();
chars
.chunks(width.max(1))
.map(|chunk| chunk.iter().collect::<String>())
.collect()
}
fn pad_to_width(value: &str, width: usize) -> String {
let value_width = value.chars().count();
if value_width >= width {
value.to_string()
} else {
format!("{}{}", value, " ".repeat(width - value_width))
}
}
fn csv_escape(value: &str) -> String {
if value.contains(',') || value.contains('"') || value.contains('\n') || value.contains('\r') {
format!("\"{}\"", value.replace('"', "\"\""))
} else {
value.to_string()
}
}
#[cfg(test)]
mod tests {
use omnigraph_server::api::{ReadOutput, ReadTargetOutput};
use super::*;
fn sample_output() -> ReadOutput {
ReadOutput {
query_name: "get_person".to_string(),
target: ReadTargetOutput {
branch: Some("main".to_string()),
snapshot: None,
},
row_count: 1,
columns: vec!["name".to_string(), "age".to_string()],
rows: serde_json::json!([{ "name": "Alice", "age": 30 }]),
}
}
#[test]
fn csv_format_outputs_header_and_rows() {
let rendered = render_read(
&sample_output(),
ReadOutputFormat::Csv,
&ReadRenderOptions {
max_column_width: 80,
cell_layout: TableCellLayout::Truncate,
},
)
.unwrap();
assert!(rendered.lines().next().unwrap().contains("name,age"));
assert!(rendered.contains("Alice,30"));
}
#[test]
fn jsonl_format_emits_metadata_first() {
let rendered = render_read(
&sample_output(),
ReadOutputFormat::Jsonl,
&ReadRenderOptions {
max_column_width: 80,
cell_layout: TableCellLayout::Truncate,
},
)
.unwrap();
let first = rendered.lines().next().unwrap();
assert!(first.contains("\"kind\":\"metadata\""));
assert!(
rendered
.lines()
.nth(1)
.unwrap()
.contains("\"name\":\"Alice\"")
);
}
#[test]
fn render_falls_back_to_discovered_columns_for_legacy_payloads() {
let mut output = sample_output();
output.columns.clear();
let rendered = render_read(
&output,
ReadOutputFormat::Csv,
&ReadRenderOptions {
max_column_width: 80,
cell_layout: TableCellLayout::Truncate,
},
)
.unwrap();
assert!(rendered.lines().next().unwrap().contains("age,name"));
}
#[test]
fn csv_quotes_carriage_returns() {
assert_eq!(csv_escape("hello\rworld"), "\"hello\rworld\"");
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,292 @@
#![allow(dead_code)]
use std::fs;
use std::net::TcpListener;
use std::path::{Path, PathBuf};
use std::process::{Child, Command as StdCommand, Output, Stdio};
use std::thread::sleep;
use std::time::Duration;
use assert_cmd::Command;
use omnigraph::db::Omnigraph;
use omnigraph::loader::LoadMode;
use reqwest::blocking::Client;
use serde_json::Value;
use tempfile::{TempDir, tempdir};
pub fn cli() -> Command {
Command::cargo_bin("omnigraph").unwrap()
}
pub fn cli_process() -> StdCommand {
StdCommand::new(assert_cmd::cargo::cargo_bin("omnigraph"))
}
fn server_process() -> StdCommand {
if let Some(path) = std::env::var_os("CARGO_BIN_EXE_omnigraph-server") {
StdCommand::new(path)
} else if let Some(path) = built_server_binary() {
StdCommand::new(path)
} else {
let cargo = std::env::var_os("CARGO").unwrap_or_else(|| "cargo".into());
let mut cmd = StdCommand::new(cargo);
cmd.arg("run")
.arg("--quiet")
.arg("-p")
.arg("omnigraph-server")
.arg("--");
cmd
}
}
fn built_server_binary() -> Option<PathBuf> {
let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../..");
let candidate = workspace_root
.join("target")
.join("debug")
.join(format!("omnigraph-server{}", std::env::consts::EXE_SUFFIX));
candidate.exists().then_some(candidate)
}
pub fn fixture(name: &str) -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("../omnigraph/tests/fixtures")
.join(name)
}
pub fn repo_path(root: &Path) -> PathBuf {
root.join("demo.omni")
}
pub fn output_success(cmd: &mut Command) -> Output {
let output = cmd.output().unwrap();
assert!(
output.status.success(),
"command failed\nstdout:\n{}\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
output
}
pub fn output_failure(cmd: &mut Command) -> Output {
let output = cmd.output().unwrap();
assert!(
!output.status.success(),
"command unexpectedly succeeded\nstdout:\n{}\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
output
}
pub fn stdout_string(output: &Output) -> String {
String::from_utf8(output.stdout.clone()).unwrap()
}
pub fn parse_stdout_json(output: &Output) -> Value {
serde_json::from_slice(&output.stdout).unwrap()
}
pub fn init_repo(repo: &Path) {
let schema = fixture("test.pg");
output_success(cli().arg("init").arg("--schema").arg(&schema).arg(repo));
}
pub fn load_fixture(repo: &Path) {
let data = fixture("test.jsonl");
output_success(cli().arg("load").arg("--data").arg(&data).arg(repo));
}
pub fn write_jsonl(path: &Path, rows: &str) {
fs::write(path, rows).unwrap();
}
pub fn write_query_file(path: &Path, source: &str) {
fs::write(path, source).unwrap();
}
pub fn write_config(path: &Path, source: &str) {
fs::write(path, source).unwrap();
}
fn yaml_string(value: &str) -> String {
format!("'{}'", value.replace('\'', "''"))
}
pub fn local_yaml_config(repo: &Path) -> String {
format!(
"\
targets:
local:
uri: {}
cli:
target: local
branch: main
query:
roots:
- .
policy: {{}}
",
yaml_string(&repo.to_string_lossy())
)
}
pub fn remote_yaml_config(url: &str) -> String {
format!(
"\
targets:
dev:
uri: {}
cli:
target: dev
branch: main
query:
roots:
- .
policy: {{}}
",
yaml_string(url)
)
}
pub struct TestServer {
child: Child,
pub base_url: String,
}
impl Drop for TestServer {
fn drop(&mut self) {
let _ = self.child.kill();
let _ = self.child.wait();
}
}
fn free_port() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
fn spawn_server_process(mut command: StdCommand) -> TestServer {
let port = free_port();
let bind = format!("127.0.0.1:{}", port);
let mut child = command
.arg("--bind")
.arg(&bind)
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.unwrap();
let base_url = format!("http://{}", bind);
let client = Client::new();
for _ in 0..300 {
if client
.get(format!("{}/healthz", base_url))
.send()
.map(|response| response.status().is_success())
.unwrap_or(false)
{
return TestServer { child, base_url };
}
if let Some(status) = child.try_wait().unwrap() {
panic!("server exited before becoming healthy: {status}");
}
sleep(Duration::from_millis(100));
}
panic!("server did not become healthy");
}
pub fn spawn_server(repo: &Path) -> TestServer {
let mut command = server_process();
command.arg(repo);
spawn_server_process(command)
}
pub fn spawn_server_with_config(config: &Path) -> TestServer {
let mut command = server_process();
command.arg("--config").arg(config);
spawn_server_process(command)
}
pub fn spawn_server_with_config_env(config: &Path, envs: &[(&str, &str)]) -> TestServer {
let mut command = server_process();
command.arg("--config").arg(config);
for (name, value) in envs {
command.env(name, value);
}
spawn_server_process(command)
}
pub async fn begin_manual_run(repo: &Path, target_branch: &str) -> String {
let mut db = Omnigraph::open(repo.to_str().unwrap()).await.unwrap();
let run = db
.begin_run(target_branch, Some("cli-test-run"))
.await
.unwrap();
db.load(
&run.run_branch,
r#"{"type":"Person","data":{"name":"Eve","age":29}}"#,
LoadMode::Append,
)
.await
.unwrap();
run.run_id.as_str().to_string()
}
pub struct SystemRepo {
_temp: TempDir,
repo: PathBuf,
}
impl SystemRepo {
pub fn initialized() -> Self {
let temp = tempdir().unwrap();
let repo = repo_path(temp.path());
init_repo(&repo);
Self { _temp: temp, repo }
}
pub fn loaded() -> Self {
let temp = tempdir().unwrap();
let repo = repo_path(temp.path());
init_repo(&repo);
load_fixture(&repo);
Self { _temp: temp, repo }
}
pub fn path(&self) -> &Path {
&self.repo
}
pub fn write_query(&self, name: &str, source: &str) -> PathBuf {
let path = self.repo.parent().unwrap().join(name);
write_query_file(&path, source);
path
}
pub fn write_jsonl(&self, name: &str, rows: &str) -> PathBuf {
let path = self.repo.parent().unwrap().join(name);
write_jsonl(&path, rows);
path
}
pub fn write_config(&self, name: &str, source: &str) -> PathBuf {
let path = self.repo.parent().unwrap().join(name);
write_config(&path, source);
path
}
pub fn spawn_server(&self) -> TestServer {
spawn_server(&self.repo)
}
pub fn spawn_server_with_config(&self, config: &Path) -> TestServer {
spawn_server_with_config(config)
}
pub fn spawn_server_with_config_env(&self, config: &Path, envs: &[(&str, &str)]) -> TestServer {
spawn_server_with_config_env(config, envs)
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,810 @@
mod support;
use std::fs;
use reqwest::blocking::Client;
use serde_json::json;
use support::*;
const REMOTE_POLICY_E2E_YAML: &str = r#"
version: 1
groups:
team: [act-bruno]
admins: [act-ragnor]
protected_branches: [main]
rules:
- id: team-read
allow:
actors: { group: team }
actions: [read]
branch_scope: any
- id: team-branch-create
allow:
actors: { group: team }
actions: [branch_create]
target_branch_scope: unprotected
- id: team-write-unprotected
allow:
actors: { group: team }
actions: [change]
branch_scope: unprotected
- id: admins-promote
allow:
actors: { group: admins }
actions: [branch_merge, run_publish]
target_branch_scope: protected
"#;
fn yaml_string(value: &str) -> String {
format!("'{}'", value.replace('\'', "''"))
}
fn remote_policy_server_config(repo: &SystemRepo) -> String {
format!(
"\
project:
name: remote-policy-e2e
targets:
local:
uri: {}
server:
target: local
policy:
file: ./policy.yaml
",
yaml_string(&repo.path().to_string_lossy())
)
}
fn remote_policy_client_config(url: &str) -> String {
format!(
"\
targets:
dev:
uri: {}
bearer_token_env: POLICY_TEST_TOKEN
cli:
target: dev
branch: main
query:
roots:
- .
auth:
env_file: ./.env.omni
",
yaml_string(url)
)
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_server_and_cli_end_to_end_flow() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
let mutation_file = repo.write_query(
"system-remote-change.gq",
r#"
query insert_person($name: String, $age: I32) {
insert Person { name: $name, age: $age }
}
"#,
);
let client = Client::new();
let health = client
.get(format!("{}/healthz", server.base_url))
.send()
.unwrap()
.error_for_status()
.unwrap()
.json::<serde_json::Value>()
.unwrap();
assert_eq!(health["status"], "ok");
let local_snapshot = parse_stdout_json(&output_success(
cli().arg("snapshot").arg(repo.path()).arg("--json"),
));
let snapshot = parse_stdout_json(&output_success(
cli()
.arg("snapshot")
.arg("--config")
.arg(&config)
.arg("--json"),
));
assert_eq!(snapshot["branch"], "main");
assert_eq!(snapshot["tables"], local_snapshot["tables"]);
let local_read = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg(repo.path())
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--params")
.arg(r#"{"name":"Alice"}"#)
.arg("--json"),
));
let read_payload = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--params")
.arg(r#"{"name":"Alice"}"#)
.arg("--json"),
));
assert_eq!(read_payload, local_read);
assert_eq!(read_payload["row_count"], 1);
assert_eq!(read_payload["rows"][0]["p.name"], "Alice");
let change_payload = parse_stdout_json(&output_success(
cli()
.arg("change")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(&mutation_file)
.arg("--params")
.arg(r#"{"name":"Mina","age":28}"#)
.arg("--json"),
));
assert_eq!(change_payload["affected_nodes"], 1);
let query_source = fs::read_to_string(fixture("test.gq")).unwrap();
let http_read = client
.post(format!("{}/read", server.base_url))
.json(&json!({
"branch": "main",
"query_source": query_source,
"query_name": "get_person",
"params": { "name": "Mina" }
}))
.send()
.unwrap()
.error_for_status()
.unwrap()
.json::<serde_json::Value>()
.unwrap();
assert_eq!(http_read["row_count"], 1);
assert_eq!(http_read["rows"][0]["p.name"], "Mina");
let local_verify = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg(repo.path())
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--params")
.arg(r#"{"name":"Mina"}"#)
.arg("--json"),
));
assert_eq!(local_verify["row_count"], 1);
assert_eq!(local_verify["rows"][0]["p.name"], "Mina");
let manual_run = tokio::runtime::Runtime::new()
.unwrap()
.block_on(begin_manual_run(repo.path(), "main"));
let publish_payload = parse_stdout_json(&output_success(
cli()
.arg("run")
.arg("publish")
.arg("--config")
.arg(&config)
.arg(&manual_run)
.arg("--json"),
));
assert_eq!(publish_payload["run_id"], manual_run);
assert_eq!(publish_payload["status"], "published");
let runs_payload = parse_stdout_json(&output_success(
cli()
.arg("run")
.arg("list")
.arg("--config")
.arg(&config)
.arg("--json"),
));
assert!(runs_payload["runs"].as_array().unwrap().len() >= 2);
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_read_preserves_projection_order_in_json_and_csv() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
let ordered_query = repo.write_query(
"ordered-remote.gq",
r#"
query ordered_person($name: String) {
match {
$p: Person { name: $name }
}
return { $p.age, $p.name }
}
"#,
);
let json_payload = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(&ordered_query)
.arg("--name")
.arg("ordered_person")
.arg("--params")
.arg(r#"{"name":"Alice"}"#)
.arg("--json"),
));
let columns = json_payload["columns"]
.as_array()
.unwrap()
.iter()
.map(|value| value.as_str().unwrap())
.collect::<Vec<_>>();
assert_eq!(columns, vec!["p.age", "p.name"]);
let csv = stdout_string(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(&ordered_query)
.arg("--name")
.arg("ordered_person")
.arg("--params")
.arg(r#"{"name":"Alice"}"#)
.arg("--format")
.arg("csv"),
));
let mut lines = csv.lines();
assert_eq!(lines.next().unwrap(), "p.age,p.name");
assert_eq!(lines.next().unwrap(), "30,Alice");
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_branch_create_list_merge_flow() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
let mutation_file = repo.write_query(
"system-remote-branch-change.gq",
r#"
query insert_person($name: String, $age: I32) {
insert Person { name: $name, age: $age }
}
"#,
);
let initial = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("list")
.arg("--config")
.arg(&config)
.arg("--json"),
));
assert_eq!(initial["branches"], json!(["main"]));
let created = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("create")
.arg("--config")
.arg(&config)
.arg("--from")
.arg("main")
.arg("feature")
.arg("--json"),
));
assert_eq!(created["from"], "main");
assert_eq!(created["name"], "feature");
let listed = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("list")
.arg("--config")
.arg(&config)
.arg("--json"),
));
assert_eq!(listed["branches"], json!(["feature", "main"]));
let changed = parse_stdout_json(&output_success(
cli()
.arg("change")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(&mutation_file)
.arg("--branch")
.arg("feature")
.arg("--params")
.arg(r#"{"name":"Zoe","age":33}"#)
.arg("--json"),
));
assert_eq!(changed["branch"], "feature");
assert_eq!(changed["affected_nodes"], 1);
let merged = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("merge")
.arg("--config")
.arg(&config)
.arg("feature")
.arg("--into")
.arg("main")
.arg("--json"),
));
assert_eq!(merged["source"], "feature");
assert_eq!(merged["target"], "main");
assert_eq!(merged["outcome"], "fast_forward");
let verify = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--params")
.arg(r#"{"name":"Zoe"}"#)
.arg("--json"),
));
assert_eq!(verify["row_count"], 1);
assert_eq!(verify["rows"][0]["p.name"], "Zoe");
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_branch_delete_removes_branch() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("create")
.arg("--config")
.arg(&config)
.arg("--from")
.arg("main")
.arg("feature")
.arg("--json"),
));
let deleted = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("delete")
.arg("--config")
.arg(&config)
.arg("feature")
.arg("--json"),
));
assert_eq!(deleted["name"], "feature");
let listed = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("list")
.arg("--config")
.arg(&config)
.arg("--json"),
));
assert_eq!(listed["branches"], json!(["main"]));
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_export_round_trips_full_branch_graph() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
let mutation_file = repo.write_query(
"system-remote-export-change.gq",
r#"
query insert_person($name: String, $age: I32) {
insert Person { name: $name, age: $age }
}
query add_friend($from: String, $to: String) {
insert Knows { from: $from, to: $to }
}
"#,
);
output_success(
cli()
.arg("branch")
.arg("create")
.arg("--config")
.arg(&config)
.arg("--from")
.arg("main")
.arg("feature"),
);
output_success(
cli()
.arg("change")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(&mutation_file)
.arg("--name")
.arg("insert_person")
.arg("--branch")
.arg("feature")
.arg("--params")
.arg(r#"{"name":"Eve","age":29}"#)
.arg("--json"),
);
output_success(
cli()
.arg("change")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(&mutation_file)
.arg("--name")
.arg("add_friend")
.arg("--branch")
.arg("feature")
.arg("--params")
.arg(r#"{"from":"Alice","to":"Eve"}"#)
.arg("--json"),
);
let exported = stdout_string(&output_success(
cli()
.arg("export")
.arg("--config")
.arg(&config)
.arg("--branch")
.arg("feature")
.arg("--jsonl"),
));
let export_path = repo.write_jsonl("system-remote-exported.jsonl", &exported);
let imported_repo = repo
.path()
.parent()
.unwrap()
.join("imported-remote-export.omni");
output_success(
cli()
.arg("init")
.arg("--schema")
.arg(fixture("test.pg"))
.arg(&imported_repo),
);
output_success(
cli()
.arg("load")
.arg("--data")
.arg(&export_path)
.arg(&imported_repo),
);
let snapshot = parse_stdout_json(&output_success(
cli().arg("snapshot").arg(&imported_repo).arg("--json"),
));
assert_eq!(
snapshot["tables"]
.as_array()
.unwrap()
.iter()
.find(|table| table["table_key"] == "node:Person")
.unwrap()["row_count"],
5
);
assert_eq!(
snapshot["tables"]
.as_array()
.unwrap()
.iter()
.find(|table| table["table_key"] == "edge:Knows")
.unwrap()["row_count"],
4
);
let eve = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg(&imported_repo)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--params")
.arg(r#"{"name":"Eve"}"#)
.arg("--json"),
));
assert_eq!(eve["row_count"], 1);
assert_eq!(eve["rows"][0]["p.name"], "Eve");
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_ingest_creates_review_branch_and_keeps_it_readable() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
let ingest_data = repo.write_jsonl(
"system-remote-ingest.jsonl",
r#"{"type":"Person","data":{"name":"Zoe","age":33}}
{"type":"Person","data":{"name":"Bob","age":26}}"#,
);
let ingest_payload = parse_stdout_json(&output_success(
cli()
.arg("ingest")
.arg("--config")
.arg(&config)
.arg("--data")
.arg(&ingest_data)
.arg("--branch")
.arg("feature-ingest")
.arg("--json"),
));
assert_eq!(ingest_payload["branch"], "feature-ingest");
assert_eq!(ingest_payload["base_branch"], "main");
assert_eq!(ingest_payload["branch_created"], true);
assert_eq!(ingest_payload["mode"], "merge");
assert_eq!(ingest_payload["tables"][0]["table_key"], "node:Person");
assert_eq!(ingest_payload["tables"][0]["rows_loaded"], 2);
let feature_snapshot = parse_stdout_json(&output_success(
cli()
.arg("snapshot")
.arg("--config")
.arg(&config)
.arg("--branch")
.arg("feature-ingest")
.arg("--json"),
));
assert_eq!(feature_snapshot["branch"], "feature-ingest");
let zoe = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--branch")
.arg("feature-ingest")
.arg("--params")
.arg(r#"{"name":"Zoe"}"#)
.arg("--json"),
));
assert_eq!(zoe["row_count"], 1);
assert_eq!(zoe["rows"][0]["p.name"], "Zoe");
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_ingest_reuses_existing_branch_and_merges_updates() {
let repo = SystemRepo::loaded();
let server = repo.spawn_server();
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
output_success(
cli()
.arg("branch")
.arg("create")
.arg("--config")
.arg(&config)
.arg("--from")
.arg("main")
.arg("feature-ingest"),
);
let ingest_data = repo.write_jsonl(
"system-remote-ingest-merge.jsonl",
r#"{"type":"Person","data":{"name":"Bob","age":26}}
{"type":"Person","data":{"name":"Zoe","age":33}}"#,
);
let ingest_payload = parse_stdout_json(&output_success(
cli()
.arg("ingest")
.arg("--config")
.arg(&config)
.arg("--data")
.arg(&ingest_data)
.arg("--branch")
.arg("feature-ingest")
.arg("--from")
.arg("missing-base")
.arg("--json"),
));
assert_eq!(ingest_payload["branch"], "feature-ingest");
assert_eq!(ingest_payload["base_branch"], "missing-base");
assert_eq!(ingest_payload["branch_created"], false);
assert_eq!(ingest_payload["mode"], "merge");
assert_eq!(ingest_payload["tables"][0]["table_key"], "node:Person");
assert_eq!(ingest_payload["tables"][0]["rows_loaded"], 2);
let bob = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--branch")
.arg("feature-ingest")
.arg("--params")
.arg(r#"{"name":"Bob"}"#)
.arg("--json"),
));
assert_eq!(bob["row_count"], 1);
assert_eq!(bob["rows"][0]["p.age"], 26);
let zoe = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&config)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--branch")
.arg("feature-ingest")
.arg("--params")
.arg(r#"{"name":"Zoe"}"#)
.arg("--json"),
));
assert_eq!(zoe["row_count"], 1);
assert_eq!(zoe["rows"][0]["p.name"], "Zoe");
}
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_policy_enforces_branch_first_cli_workflow() {
let repo = SystemRepo::loaded();
let server_config =
repo.write_config("server-policy.yaml", &remote_policy_server_config(&repo));
repo.write_config("policy.yaml", REMOTE_POLICY_E2E_YAML);
let server = repo.spawn_server_with_config_env(
&server_config,
&[(
"OMNIGRAPH_SERVER_BEARER_TOKENS_JSON",
r#"{"act-bruno":"team-token","act-ragnor":"admin-token"}"#,
)],
);
let client_config = repo.write_config(
"omnigraph-policy.yaml",
&remote_policy_client_config(&server.base_url),
);
repo.write_config(".env.omni", "POLICY_TEST_TOKEN=team-token\n");
let mutation_file = repo.write_query(
"system-remote-policy-change.gq",
r#"
query insert_person($name: String, $age: I32) {
insert Person { name: $name, age: $age }
}
"#,
);
let snapshot = parse_stdout_json(&output_success(
cli()
.arg("snapshot")
.arg("--config")
.arg(&client_config)
.arg("--json"),
));
assert_eq!(snapshot["branch"], "main");
let denied_main_change = output_failure(
cli()
.arg("change")
.arg("--config")
.arg(&client_config)
.arg("--query")
.arg(&mutation_file)
.arg("--params")
.arg(r#"{"name":"PolicyRemote","age":41}"#)
.arg("--json"),
);
let denied_main_stderr = String::from_utf8(denied_main_change.stderr).unwrap();
assert!(denied_main_stderr.contains("policy denied action 'change' on branch 'main'"));
let created = parse_stdout_json(&output_success(
cli()
.arg("branch")
.arg("create")
.arg("--config")
.arg(&client_config)
.arg("--from")
.arg("main")
.arg("feature")
.arg("--json"),
));
assert_eq!(created["name"], "feature");
let changed = parse_stdout_json(&output_success(
cli()
.arg("change")
.arg("--config")
.arg(&client_config)
.arg("--query")
.arg(&mutation_file)
.arg("--branch")
.arg("feature")
.arg("--params")
.arg(r#"{"name":"PolicyRemote","age":41}"#)
.arg("--json"),
));
assert_eq!(changed["branch"], "feature");
assert_eq!(changed["affected_nodes"], 1);
let denied_merge = output_failure(
cli()
.arg("branch")
.arg("merge")
.arg("--config")
.arg(&client_config)
.arg("feature")
.arg("--into")
.arg("main")
.arg("--json"),
);
let denied_merge_stderr = String::from_utf8(denied_merge.stderr).unwrap();
assert!(denied_merge_stderr.contains("policy denied action 'branch_merge'"));
let merged = parse_stdout_json(&output_success(
cli()
.env("POLICY_TEST_TOKEN", "admin-token")
.arg("branch")
.arg("merge")
.arg("--config")
.arg(&client_config)
.arg("feature")
.arg("--into")
.arg("main")
.arg("--json"),
));
assert_eq!(merged["target"], "main");
let verify = parse_stdout_json(&output_success(
cli()
.arg("read")
.arg("--config")
.arg(&client_config)
.arg("--query")
.arg(fixture("test.gq"))
.arg("--name")
.arg("get_person")
.arg("--params")
.arg(r#"{"name":"PolicyRemote"}"#)
.arg("--json"),
));
assert_eq!(verify["row_count"], 1);
assert_eq!(verify["rows"][0]["p.name"], "PolicyRemote");
}

View file

@ -0,0 +1,26 @@
[package]
name = "omnigraph-compiler"
version = "0.4.0"
edition = "2024"
description = "Schema/query compiler for Omnigraph. Zero Lance dependency."
license = "MIT"
[dependencies]
arrow-array = { workspace = true }
arrow-ipc = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
arrow-cast = { workspace = true }
arrow-ord = { workspace = true }
pest = { workspace = true }
pest_derive = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
reqwest = { workspace = true }
ahash = { workspace = true }
tokio = { workspace = true }
sha2 = { workspace = true }
[dev-dependencies]
tokio = { workspace = true }

View file

@ -0,0 +1,594 @@
pub mod schema_ir;
pub mod schema_plan;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use crate::error::{NanoError, Result};
use crate::schema::ast::{Cardinality, Constraint, ConstraintBound, SchemaDecl, SchemaFile};
use crate::types::{PropType, ScalarType};
#[derive(Debug, Clone)]
pub struct Catalog {
pub node_types: HashMap<String, NodeType>,
pub edge_types: HashMap<String, EdgeType>,
/// Maps normalized lowercase edge name -> EdgeType key (e.g. "knows" -> "Knows")
pub edge_name_index: HashMap<String, String>,
/// Interface declarations (for Phase 2 polymorphic queries)
pub interfaces: HashMap<String, InterfaceType>,
}
#[derive(Debug, Clone)]
pub struct InterfaceType {
pub name: String,
pub properties: HashMap<String, PropType>,
}
#[derive(Debug, Clone)]
pub struct NodeType {
pub name: String,
/// Interface names this type implements
pub implements: Vec<String>,
pub properties: HashMap<String, PropType>,
/// Key property names (from `@key` or `@key(name)`). Usually 0 or 1 element.
pub key: Option<Vec<String>>,
/// Uniqueness constraints (each entry is a list of column names)
pub unique_constraints: Vec<Vec<String>>,
/// Index declarations (each entry is a list of column names)
pub indices: Vec<Vec<String>>,
/// Value range constraints
pub range_constraints: Vec<RangeConstraint>,
/// Regex check constraints
pub check_constraints: Vec<CheckConstraint>,
/// Maps @embed target property -> source text property
pub embed_sources: HashMap<String, String>,
pub blob_properties: HashSet<String>,
pub arrow_schema: SchemaRef,
}
impl NodeType {
/// Backward-compatible accessor: returns the first (and typically only) key property name.
pub fn key_property(&self) -> Option<&str> {
self.key
.as_ref()
.and_then(|v| v.first())
.map(|s| s.as_str())
}
}
#[derive(Debug, Clone)]
pub struct RangeConstraint {
pub property: String,
pub min: Option<LiteralValue>,
pub max: Option<LiteralValue>,
}
#[derive(Debug, Clone)]
pub enum LiteralValue {
Integer(i64),
Float(f64),
}
#[derive(Debug, Clone)]
pub struct CheckConstraint {
pub property: String,
pub pattern: String,
}
#[derive(Debug, Clone)]
pub struct EdgeType {
pub name: String,
pub from_type: String,
pub to_type: String,
pub cardinality: Cardinality,
pub properties: HashMap<String, PropType>,
/// Uniqueness constraints on edge columns (e.g. `@unique(src, dst)`)
pub unique_constraints: Vec<Vec<String>>,
/// Index declarations on edge properties
pub indices: Vec<Vec<String>>,
pub blob_properties: HashSet<String>,
pub arrow_schema: SchemaRef,
}
impl Catalog {
pub fn lookup_edge_by_name(&self, name: &str) -> Option<&EdgeType> {
if let Some(et) = self.edge_types.get(name) {
return Some(et);
}
if let Some(key) = self.edge_name_index.get(&normalize_edge_name(name)) {
return self.edge_types.get(key);
}
None
}
}
fn normalize_edge_name(name: &str) -> String {
name.to_lowercase()
}
fn bound_to_literal(b: &ConstraintBound) -> LiteralValue {
match b {
ConstraintBound::Integer(n) => LiteralValue::Integer(*n),
ConstraintBound::Float(f) => LiteralValue::Float(*f),
}
}
pub fn build_catalog(schema: &SchemaFile) -> Result<Catalog> {
let mut node_types = HashMap::new();
let mut edge_types = HashMap::new();
let mut edge_name_index = HashMap::new();
let mut interfaces = HashMap::new();
// Pass 0: collect interfaces
for decl in &schema.declarations {
if let SchemaDecl::Interface(iface) = decl {
let mut properties = HashMap::new();
for prop in &iface.properties {
properties.insert(prop.name.clone(), prop.prop_type.clone());
}
interfaces.insert(
iface.name.clone(),
InterfaceType {
name: iface.name.clone(),
properties,
},
);
}
}
// Pass 1: collect node types
for decl in &schema.declarations {
if let SchemaDecl::Node(node) = decl {
if node_types.contains_key(&node.name) {
return Err(NanoError::Catalog(format!(
"duplicate node type: {}",
node.name
)));
}
let mut properties = HashMap::new();
let mut embed_sources = HashMap::new();
let mut blob_properties = HashSet::new();
for prop in &node.properties {
properties.insert(prop.name.clone(), prop.prop_type.clone());
if matches!(prop.prop_type.scalar, ScalarType::Blob) {
blob_properties.insert(prop.name.clone());
}
// Extract @embed from property annotations (stays as annotation)
if let Some(source_prop) = prop
.annotations
.iter()
.find(|ann| ann.name == "embed")
.and_then(|ann| ann.value.clone())
{
embed_sources.insert(prop.name.clone(), source_prop);
}
}
// Extract constraints from the typed Constraint enum
let mut key: Option<Vec<String>> = None;
let mut unique_constraints = Vec::new();
let mut indices = Vec::new();
let mut range_constraints = Vec::new();
let mut check_constraints = Vec::new();
for constraint in &node.constraints {
match constraint {
Constraint::Key(cols) => {
key = Some(cols.clone());
// @key implies index on key columns
indices.push(cols.clone());
}
Constraint::Unique(cols) => {
unique_constraints.push(cols.clone());
}
Constraint::Index(cols) => {
indices.push(cols.clone());
}
Constraint::Range { property, min, max } => {
range_constraints.push(RangeConstraint {
property: property.clone(),
min: min.as_ref().map(bound_to_literal),
max: max.as_ref().map(bound_to_literal),
});
}
Constraint::Check { property, pattern } => {
check_constraints.push(CheckConstraint {
property: property.clone(),
pattern: pattern.clone(),
});
}
}
}
// Build Arrow schema: id: Utf8 + all properties
let mut fields = vec![Field::new("id", DataType::Utf8, false)];
for prop in &node.properties {
fields.push(Field::new(
&prop.name,
prop.prop_type.to_arrow(),
prop.prop_type.nullable,
));
}
let arrow_schema = Arc::new(Schema::new(fields));
node_types.insert(
node.name.clone(),
NodeType {
name: node.name.clone(),
implements: node.implements.clone(),
properties,
key,
unique_constraints,
indices,
range_constraints,
check_constraints,
embed_sources,
blob_properties,
arrow_schema,
},
);
}
}
// Pass 2: collect edge types, validate endpoints
for decl in &schema.declarations {
if let SchemaDecl::Edge(edge) = decl {
if edge_types.contains_key(&edge.name) {
return Err(NanoError::Catalog(format!(
"duplicate edge type: {}",
edge.name
)));
}
if !node_types.contains_key(&edge.from_type) {
return Err(NanoError::Catalog(format!(
"edge {} references unknown source type: {}",
edge.name, edge.from_type
)));
}
if !node_types.contains_key(&edge.to_type) {
return Err(NanoError::Catalog(format!(
"edge {} references unknown target type: {}",
edge.name, edge.to_type
)));
}
let mut properties = HashMap::new();
let mut blob_properties = HashSet::new();
let mut fields = vec![
Field::new("id", DataType::Utf8, false),
Field::new("src", DataType::Utf8, false),
Field::new("dst", DataType::Utf8, false),
];
for prop in &edge.properties {
properties.insert(prop.name.clone(), prop.prop_type.clone());
if matches!(prop.prop_type.scalar, ScalarType::Blob) {
blob_properties.insert(prop.name.clone());
}
fields.push(Field::new(
&prop.name,
prop.prop_type.to_arrow(),
prop.prop_type.nullable,
));
}
// Extract edge constraints
let mut unique_constraints = Vec::new();
let mut edge_indices = Vec::new();
for constraint in &edge.constraints {
match constraint {
Constraint::Unique(cols) => unique_constraints.push(cols.clone()),
Constraint::Index(cols) => edge_indices.push(cols.clone()),
_ => {} // Key/Range/Check validated at parse time to not appear on edges
}
}
let normalized_name = normalize_edge_name(&edge.name);
if let Some(existing) = edge_name_index.get(&normalized_name)
&& existing != &edge.name
{
return Err(NanoError::Catalog(format!(
"edge name collision after case folding: '{}' conflicts with '{}'",
edge.name, existing
)));
}
edge_name_index.insert(normalized_name, edge.name.clone());
edge_types.insert(
edge.name.clone(),
EdgeType {
name: edge.name.clone(),
from_type: edge.from_type.clone(),
to_type: edge.to_type.clone(),
cardinality: edge.cardinality.clone(),
properties,
unique_constraints,
indices: edge_indices,
blob_properties,
arrow_schema: Arc::new(Schema::new(fields)),
},
);
}
}
Ok(Catalog {
node_types,
edge_types,
edge_name_index,
interfaces,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::ast::{EdgeDecl, NodeDecl};
use crate::schema::parser::parse_schema;
use crate::types::PropType;
fn test_schema() -> &'static str {
r#"
node Person {
name: String
age: I32?
}
node Company {
name: String
}
edge Knows: Person -> Person {
since: Date?
}
edge WorksAt: Person -> Company {
title: String?
}
"#
}
#[test]
fn test_build_catalog() {
let schema = parse_schema(test_schema()).unwrap();
let catalog = build_catalog(&schema).unwrap();
assert_eq!(catalog.node_types.len(), 2);
assert_eq!(catalog.edge_types.len(), 2);
assert!(catalog.node_types.contains_key("Person"));
assert!(catalog.node_types.contains_key("Company"));
}
#[test]
fn test_edge_lookup() {
let schema = parse_schema(test_schema()).unwrap();
let catalog = build_catalog(&schema).unwrap();
let edge = catalog.lookup_edge_by_name("knows").unwrap();
assert_eq!(edge.from_type, "Person");
assert_eq!(edge.to_type, "Person");
let upper = catalog.lookup_edge_by_name("KNOWS").unwrap();
assert_eq!(upper.name, "Knows");
}
#[test]
fn test_node_arrow_schema() {
let schema = parse_schema(test_schema()).unwrap();
let catalog = build_catalog(&schema).unwrap();
let person = &catalog.node_types["Person"];
assert_eq!(person.arrow_schema.fields().len(), 3); // id, name, age
}
#[test]
fn test_duplicate_node_error() {
let input = r#"
node Person { name: String }
node Person { age: I32 }
"#;
let schema = parse_schema(input).unwrap();
assert!(build_catalog(&schema).is_err());
}
#[test]
fn test_bad_edge_endpoint() {
let input = r#"
node Person { name: String }
edge Knows: Person -> Alien
"#;
let schema = parse_schema(input).unwrap();
assert!(build_catalog(&schema).is_err());
}
#[test]
fn test_id_fields_are_utf8() {
let schema = parse_schema(test_schema()).unwrap();
let catalog = build_catalog(&schema).unwrap();
let person = &catalog.node_types["Person"];
assert_eq!(
person
.arrow_schema
.field_with_name("id")
.unwrap()
.data_type(),
&DataType::Utf8
);
let knows = &catalog.edge_types["Knows"];
assert_eq!(
knows
.arrow_schema
.field_with_name("id")
.unwrap()
.data_type(),
&DataType::Utf8
);
assert_eq!(
knows
.arrow_schema
.field_with_name("src")
.unwrap()
.data_type(),
&DataType::Utf8
);
assert_eq!(
knows
.arrow_schema
.field_with_name("dst")
.unwrap()
.data_type(),
&DataType::Utf8
);
}
#[test]
fn test_key_property_tracking() {
let input = r#"
node Signal {
slug: String @key
title: String
}
node Person {
name: String
}
edge Emits: Person -> Signal
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
assert_eq!(catalog.node_types["Signal"].key_property(), Some("slug"));
assert_eq!(catalog.node_types["Person"].key_property(), None);
}
#[test]
fn test_edge_lookup_handles_non_ascii_leading_character() {
let schema = SchemaFile {
declarations: vec![
SchemaDecl::Node(NodeDecl {
name: "Person".to_string(),
annotations: vec![],
implements: vec![],
properties: vec![crate::schema::ast::PropDecl {
name: "name".to_string(),
prop_type: PropType::scalar(ScalarType::String, false),
annotations: vec![],
}],
constraints: vec![],
}),
SchemaDecl::Edge(EdgeDecl {
name: "Édges".to_string(),
from_type: "Person".to_string(),
to_type: "Person".to_string(),
cardinality: Default::default(),
annotations: vec![],
properties: vec![],
constraints: vec![],
}),
],
};
let catalog = build_catalog(&schema).unwrap();
assert!(catalog.lookup_edge_by_name("édges").is_some());
}
#[test]
fn test_edge_lookup_rejects_case_fold_collisions() {
let input = r#"
node Person { name: String }
edge Knows: Person -> Person
edge KNOWS: Person -> Person
"#;
let schema = parse_schema(input).unwrap();
let err = build_catalog(&schema).unwrap_err();
assert!(err.to_string().contains("case folding"));
}
#[test]
fn test_catalog_composite_unique() {
let input = r#"
node Person {
first: String
last: String
@unique(first, last)
}
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
let person = &catalog.node_types["Person"];
assert!(
person
.unique_constraints
.contains(&vec!["first".to_string(), "last".to_string()])
);
}
#[test]
fn test_catalog_composite_index() {
let input = r#"
node Event {
category: String
date: Date
@index(category, date)
}
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
let event = &catalog.node_types["Event"];
assert!(
event
.indices
.contains(&vec!["category".to_string(), "date".to_string()])
);
}
#[test]
fn test_catalog_edge_cardinality() {
let input = r#"
node Person { name: String }
node Company { name: String }
edge WorksAt: Person -> Company @card(0..1)
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
let edge = &catalog.edge_types["WorksAt"];
assert_eq!(edge.cardinality.min, 0);
assert_eq!(edge.cardinality.max, Some(1));
}
#[test]
fn test_catalog_interfaces_stored() {
let input = r#"
interface Named {
name: String
}
node Person implements Named {
age: I32?
}
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
assert!(catalog.interfaces.contains_key("Named"));
assert!(catalog.interfaces["Named"].properties.contains_key("name"));
}
#[test]
fn test_catalog_node_implements() {
let input = r#"
interface Named {
name: String
}
node Person implements Named {
age: I32?
}
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
assert_eq!(catalog.node_types["Person"].implements, vec!["Named"]);
}
#[test]
fn test_key_implies_index() {
let input = r#"
node Signal {
slug: String @key
title: String
}
"#;
let schema = parse_schema(input).unwrap();
let catalog = build_catalog(&schema).unwrap();
let signal = &catalog.node_types["Signal"];
assert!(signal.indices.contains(&vec!["slug".to_string()]));
}
}

View file

@ -0,0 +1,393 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::catalog::{Catalog, build_catalog};
use crate::error::{NanoError, Result};
use crate::schema::ast::{Annotation, Cardinality, Constraint, PropDecl, SchemaDecl, SchemaFile};
use crate::types::PropType;
const SCHEMA_IR_VERSION: u32 = 1;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchemaIR {
pub ir_version: u32,
pub interfaces: Vec<InterfaceIR>,
pub nodes: Vec<NodeIR>,
pub edges: Vec<EdgeIR>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InterfaceIR {
pub name: String,
pub type_id: u32,
pub properties: Vec<PropertyIR>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NodeIR {
pub name: String,
pub type_id: u32,
pub annotations: Vec<Annotation>,
pub implements: Vec<String>,
pub properties: Vec<PropertyIR>,
pub constraints: Vec<Constraint>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EdgeIR {
pub name: String,
pub type_id: u32,
pub from_type: String,
pub to_type: String,
pub cardinality: Cardinality,
pub annotations: Vec<Annotation>,
pub properties: Vec<PropertyIR>,
pub constraints: Vec<Constraint>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PropertyIR {
pub name: String,
pub prop_id: u32,
pub prop_type: PropType,
pub annotations: Vec<Annotation>,
}
pub fn build_schema_ir(schema: &SchemaFile) -> Result<SchemaIR> {
let mut seen_type_ids = HashMap::<u32, String>::new();
let mut interfaces = Vec::new();
let mut nodes = Vec::new();
let mut edges = Vec::new();
for decl in &schema.declarations {
match decl {
SchemaDecl::Interface(interface) => {
let type_id = stable_type_id("interface", &interface.name);
check_type_id_collision(&mut seen_type_ids, type_id, &interface.name)?;
interfaces.push(InterfaceIR {
name: interface.name.clone(),
type_id,
properties: canonical_properties(
"interface",
&interface.name,
&interface.properties,
)?,
});
}
SchemaDecl::Node(node) => {
let type_id = stable_type_id("node", &node.name);
check_type_id_collision(&mut seen_type_ids, type_id, &node.name)?;
nodes.push(NodeIR {
name: node.name.clone(),
type_id,
annotations: canonical_annotations(&node.annotations),
implements: canonical_strings(&node.implements),
properties: canonical_properties("node", &node.name, &node.properties)?,
constraints: canonical_constraints(&node.constraints),
});
}
SchemaDecl::Edge(edge) => {
let type_id = stable_type_id("edge", &edge.name);
check_type_id_collision(&mut seen_type_ids, type_id, &edge.name)?;
edges.push(EdgeIR {
name: edge.name.clone(),
type_id,
from_type: edge.from_type.clone(),
to_type: edge.to_type.clone(),
cardinality: edge.cardinality.clone(),
annotations: canonical_annotations(&edge.annotations),
properties: canonical_properties("edge", &edge.name, &edge.properties)?,
constraints: canonical_constraints(&edge.constraints),
});
}
}
}
interfaces.sort_by(|a, b| a.name.cmp(&b.name));
nodes.sort_by(|a, b| a.name.cmp(&b.name));
edges.sort_by(|a, b| a.name.cmp(&b.name));
Ok(SchemaIR {
ir_version: SCHEMA_IR_VERSION,
interfaces,
nodes,
edges,
})
}
pub fn build_catalog_from_ir(ir: &SchemaIR) -> Result<Catalog> {
if ir.ir_version != SCHEMA_IR_VERSION {
return Err(NanoError::Catalog(format!(
"unsupported schema ir_version {} (expected {})",
ir.ir_version, SCHEMA_IR_VERSION
)));
}
let schema = SchemaFile {
declarations: ir
.interfaces
.iter()
.map(|interface| {
SchemaDecl::Interface(crate::schema::ast::InterfaceDecl {
name: interface.name.clone(),
properties: interface
.properties
.iter()
.map(property_decl_from_ir)
.collect(),
})
})
.chain(ir.nodes.iter().map(|node| {
SchemaDecl::Node(crate::schema::ast::NodeDecl {
name: node.name.clone(),
annotations: node.annotations.clone(),
implements: node.implements.clone(),
properties: node.properties.iter().map(property_decl_from_ir).collect(),
constraints: node.constraints.clone(),
})
}))
.chain(ir.edges.iter().map(|edge| {
SchemaDecl::Edge(crate::schema::ast::EdgeDecl {
name: edge.name.clone(),
from_type: edge.from_type.clone(),
to_type: edge.to_type.clone(),
cardinality: edge.cardinality.clone(),
annotations: edge.annotations.clone(),
properties: edge.properties.iter().map(property_decl_from_ir).collect(),
constraints: edge.constraints.clone(),
})
}))
.collect(),
};
build_catalog(&schema)
}
pub fn schema_ir_json(ir: &SchemaIR) -> Result<String> {
serde_json::to_string(ir)
.map_err(|err| NanoError::Catalog(format!("serialize schema ir error: {}", err)))
}
pub fn schema_ir_pretty_json(ir: &SchemaIR) -> Result<String> {
serde_json::to_string_pretty(ir)
.map_err(|err| NanoError::Catalog(format!("serialize schema ir error: {}", err)))
}
pub fn schema_ir_hash(ir: &SchemaIR) -> Result<String> {
let json = schema_ir_json(ir)?;
let mut hasher = Sha256::new();
hasher.update(json.as_bytes());
Ok(format!("sha256:{:x}", hasher.finalize()))
}
fn property_decl_from_ir(property: &PropertyIR) -> PropDecl {
PropDecl {
name: property.name.clone(),
prop_type: property.prop_type.clone(),
annotations: property.annotations.clone(),
}
}
fn canonical_strings(values: &[String]) -> Vec<String> {
let mut values = values.to_vec();
values.sort();
values.dedup();
values
}
fn canonical_annotations(annotations: &[Annotation]) -> Vec<Annotation> {
let mut annotations = annotations.to_vec();
annotations.sort_by(|left, right| {
left.name
.cmp(&right.name)
.then_with(|| left.value.cmp(&right.value))
});
annotations
}
fn canonical_prop_type(prop_type: &PropType) -> PropType {
let mut normalized = prop_type.clone();
if let Some(values) = &mut normalized.enum_values {
values.sort();
values.dedup();
}
normalized
}
fn canonical_properties(
kind: &str,
owner_name: &str,
properties: &[PropDecl],
) -> Result<Vec<PropertyIR>> {
let mut seen_prop_ids = HashMap::<u32, String>::new();
let owner_key = format!("{}:{}", kind, owner_name);
let mut canonical = properties
.iter()
.map(|property| {
let prop_id = stable_prop_id(&owner_key, &property.name);
if let Some(previous) = seen_prop_ids.insert(prop_id, property.name.clone()) {
return Err(NanoError::Catalog(format!(
"property id collision on {}: '{}' and '{}' both hash to {}",
owner_name, previous, property.name, prop_id
)));
}
Ok(PropertyIR {
name: property.name.clone(),
prop_id,
prop_type: canonical_prop_type(&property.prop_type),
annotations: canonical_annotations(&property.annotations),
})
})
.collect::<Result<Vec<_>>>()?;
canonical.sort_by(|a, b| a.name.cmp(&b.name));
Ok(canonical)
}
fn canonical_constraints(constraints: &[Constraint]) -> Vec<Constraint> {
let mut constraints = constraints
.iter()
.cloned()
.map(normalize_constraint)
.collect::<Vec<_>>();
constraints.sort_by_key(constraint_sort_key);
constraints
}
fn normalize_constraint(constraint: Constraint) -> Constraint {
match constraint {
Constraint::Key(mut columns) => {
columns.sort();
Constraint::Key(columns)
}
Constraint::Unique(mut columns) => {
columns.sort();
Constraint::Unique(columns)
}
Constraint::Index(mut columns) => {
columns.sort();
Constraint::Index(columns)
}
other => other,
}
}
fn constraint_sort_key(constraint: &Constraint) -> String {
match constraint {
Constraint::Key(columns) => format!("key:{}", columns.join(",")),
Constraint::Unique(columns) => format!("unique:{}", columns.join(",")),
Constraint::Index(columns) => format!("index:{}", columns.join(",")),
Constraint::Range { property, min, max } => {
format!("range:{}:{:?}:{:?}", property, min, max)
}
Constraint::Check { property, pattern } => format!("check:{}:{}", property, pattern),
}
}
fn stable_type_id(kind: &str, name: &str) -> u32 {
fnv1a_u32(&format!("{}:{}", kind, name))
}
fn stable_prop_id(owner: &str, name: &str) -> u32 {
fnv1a_u32(&format!("{}:{}", owner, name))
}
fn fnv1a_u32(value: &str) -> u32 {
let mut hash: u32 = 2_166_136_261;
for byte in value.bytes() {
hash ^= u32::from(byte);
hash = hash.wrapping_mul(16_777_619);
}
if hash == 0 { 1 } else { hash }
}
fn check_type_id_collision(
seen_type_ids: &mut HashMap<u32, String>,
type_id: u32,
name: &str,
) -> Result<()> {
if let Some(previous) = seen_type_ids.insert(type_id, name.to_string()) {
return Err(NanoError::Catalog(format!(
"type id collision: '{}' and '{}' both hash to {}",
previous, name, type_id
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::catalog::build_catalog;
use crate::schema::parser::parse_schema;
#[test]
fn schema_ir_hash_is_stable_across_source_ordering_noise() {
let schema_a = parse_schema(
r#"
node Person {
age: I32?
name: String @key
}
edge Knows: Person -> Person {
since: Date?
}
"#,
)
.unwrap();
let schema_b = parse_schema(
r#"
edge Knows: Person -> Person {
since: Date?
}
node Person {
name: String @key
age: I32?
}
"#,
)
.unwrap();
let ir_a = build_schema_ir(&schema_a).unwrap();
let ir_b = build_schema_ir(&schema_b).unwrap();
assert_eq!(ir_a, ir_b);
assert_eq!(
schema_ir_hash(&ir_a).unwrap(),
schema_ir_hash(&ir_b).unwrap()
);
}
#[test]
fn build_catalog_from_ir_round_trips_core_catalog_fields() {
let schema = parse_schema(
r#"
node Person @description("person") {
name: String @key
age: I32? @description("age")
}
edge Knows: Person -> Person @instruction("friendship") {
since: Date?
}
"#,
)
.unwrap();
let direct = build_catalog(&schema).unwrap();
let ir = build_schema_ir(&schema).unwrap();
let rebuilt = build_catalog_from_ir(&ir).unwrap();
assert_eq!(direct.node_types.len(), rebuilt.node_types.len());
assert_eq!(direct.edge_types.len(), rebuilt.edge_types.len());
assert_eq!(
direct.node_types["Person"].key_property(),
rebuilt.node_types["Person"].key_property()
);
assert_eq!(
direct.edge_types["Knows"].cardinality,
rebuilt.edge_types["Knows"].cardinality
);
}
}

View file

@ -0,0 +1,895 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::schema::ast::{Annotation, Constraint};
use crate::types::PropType;
use super::schema_ir::{EdgeIR, InterfaceIR, NodeIR, PropertyIR, SchemaIR};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SchemaTypeKind {
Interface,
Node,
Edge,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchemaMigrationPlan {
pub supported: bool,
pub steps: Vec<SchemaMigrationStep>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SchemaMigrationStep {
AddType {
type_kind: SchemaTypeKind,
name: String,
},
RenameType {
type_kind: SchemaTypeKind,
from: String,
to: String,
},
AddProperty {
type_kind: SchemaTypeKind,
type_name: String,
property_name: String,
property_type: PropType,
},
RenameProperty {
type_kind: SchemaTypeKind,
type_name: String,
from: String,
to: String,
},
AddConstraint {
type_kind: SchemaTypeKind,
type_name: String,
constraint: Constraint,
},
UpdateTypeMetadata {
type_kind: SchemaTypeKind,
name: String,
annotations: Vec<Annotation>,
},
UpdatePropertyMetadata {
type_kind: SchemaTypeKind,
type_name: String,
property_name: String,
annotations: Vec<Annotation>,
},
UnsupportedChange {
entity: String,
reason: String,
},
}
pub fn plan_schema_migration(
accepted: &SchemaIR,
desired: &SchemaIR,
) -> Result<SchemaMigrationPlan> {
let mut steps = Vec::new();
let interface_renames = plan_interfaces(&accepted.interfaces, &desired.interfaces, &mut steps);
let node_renames = plan_nodes(
&accepted.nodes,
&desired.nodes,
&interface_renames,
&mut steps,
);
plan_edges(&accepted.edges, &desired.edges, &node_renames, &mut steps);
Ok(SchemaMigrationPlan {
supported: !steps
.iter()
.any(|step| matches!(step, SchemaMigrationStep::UnsupportedChange { .. })),
steps,
})
}
fn plan_interfaces(
accepted: &[InterfaceIR],
desired: &[InterfaceIR],
steps: &mut Vec<SchemaMigrationStep>,
) -> HashMap<String, String> {
let accepted_by_name = accepted
.iter()
.map(|interface| (interface.name.as_str(), interface))
.collect::<HashMap<_, _>>();
let mut consumed = HashSet::new();
for interface in desired {
if let Some(existing) = accepted_by_name.get(interface.name.as_str()) {
consumed.insert(existing.name.clone());
let _property_renames = plan_properties(
SchemaTypeKind::Interface,
&interface.name,
&existing.properties,
&interface.properties,
steps,
);
continue;
}
steps.push(SchemaMigrationStep::AddType {
type_kind: SchemaTypeKind::Interface,
name: interface.name.clone(),
});
}
for leftover in accepted
.iter()
.filter(|interface| !consumed.contains(&interface.name))
{
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("interface:{}", leftover.name),
reason: format!(
"removing interface '{}' is not supported in schema migration v1",
leftover.name
),
});
}
HashMap::new()
}
fn plan_nodes(
accepted: &[NodeIR],
desired: &[NodeIR],
interface_renames: &HashMap<String, String>,
steps: &mut Vec<SchemaMigrationStep>,
) -> HashMap<String, String> {
let accepted_by_name = accepted
.iter()
.map(|node| (node.name.as_str(), node))
.collect::<HashMap<_, _>>();
let mut consumed = HashSet::new();
let mut renames = HashMap::new();
for node in desired {
let rename_from = rename_from_value(&node.annotations);
let matched = accepted_by_name
.get(node.name.as_str())
.copied()
.or_else(|| {
rename_from.and_then(|from| {
accepted_by_name
.get(from)
.copied()
.filter(|candidate| candidate.name != node.name)
})
});
let Some(existing) = matched else {
if let Some(from) = rename_from {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("node:{}", node.name),
reason: format!(
"node '{}' declares @rename_from(\"{}\") but no accepted node with that name exists",
node.name, from
),
});
} else {
steps.push(SchemaMigrationStep::AddType {
type_kind: SchemaTypeKind::Node,
name: node.name.clone(),
});
}
continue;
};
consumed.insert(existing.name.clone());
if existing.name != node.name {
renames.insert(existing.name.clone(), node.name.clone());
steps.push(SchemaMigrationStep::RenameType {
type_kind: SchemaTypeKind::Node,
from: existing.name.clone(),
to: node.name.clone(),
});
}
if normalize_strings(&existing.implements, interface_renames)
!= normalize_strings(&node.implements, &HashMap::new())
{
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("node:{}", node.name),
reason: format!(
"changing implemented interfaces on node '{}' is not supported in schema migration v1",
node.name
),
});
}
plan_type_metadata(
SchemaTypeKind::Node,
&node.name,
&existing.annotations,
&node.annotations,
steps,
);
let property_renames = plan_properties(
SchemaTypeKind::Node,
&node.name,
&existing.properties,
&node.properties,
steps,
);
plan_constraints(
SchemaTypeKind::Node,
&node.name,
&existing.constraints,
&node.constraints,
&property_renames,
steps,
);
}
for leftover in accepted
.iter()
.filter(|node| !consumed.contains(&node.name))
{
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("node:{}", leftover.name),
reason: format!(
"removing node type '{}' is not supported in schema migration v1",
leftover.name
),
});
}
renames
}
fn plan_edges(
accepted: &[EdgeIR],
desired: &[EdgeIR],
node_renames: &HashMap<String, String>,
steps: &mut Vec<SchemaMigrationStep>,
) {
let accepted_by_name = accepted
.iter()
.map(|edge| (edge.name.as_str(), edge))
.collect::<HashMap<_, _>>();
let mut consumed = HashSet::new();
for edge in desired {
let rename_from = rename_from_value(&edge.annotations);
let matched = accepted_by_name
.get(edge.name.as_str())
.copied()
.or_else(|| {
rename_from.and_then(|from| {
accepted_by_name
.get(from)
.copied()
.filter(|candidate| candidate.name != edge.name)
})
});
let Some(existing) = matched else {
if let Some(from) = rename_from {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("edge:{}", edge.name),
reason: format!(
"edge '{}' declares @rename_from(\"{}\") but no accepted edge with that name exists",
edge.name, from
),
});
} else {
steps.push(SchemaMigrationStep::AddType {
type_kind: SchemaTypeKind::Edge,
name: edge.name.clone(),
});
}
continue;
};
consumed.insert(existing.name.clone());
if existing.name != edge.name {
steps.push(SchemaMigrationStep::RenameType {
type_kind: SchemaTypeKind::Edge,
from: existing.name.clone(),
to: edge.name.clone(),
});
}
let normalized_from = normalize_type_ref(&existing.from_type, node_renames);
let normalized_to = normalize_type_ref(&existing.to_type, node_renames);
if normalized_from != edge.from_type || normalized_to != edge.to_type {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("edge:{}", edge.name),
reason: format!(
"changing edge endpoints on '{}' is not supported in schema migration v1",
edge.name
),
});
}
if existing.cardinality != edge.cardinality {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("edge:{}", edge.name),
reason: format!(
"changing cardinality on edge '{}' is not supported in schema migration v1",
edge.name
),
});
}
plan_type_metadata(
SchemaTypeKind::Edge,
&edge.name,
&existing.annotations,
&edge.annotations,
steps,
);
let property_renames = plan_properties(
SchemaTypeKind::Edge,
&edge.name,
&existing.properties,
&edge.properties,
steps,
);
plan_constraints(
SchemaTypeKind::Edge,
&edge.name,
&existing.constraints,
&edge.constraints,
&property_renames,
steps,
);
}
for leftover in accepted
.iter()
.filter(|edge| !consumed.contains(&edge.name))
{
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("edge:{}", leftover.name),
reason: format!(
"removing edge type '{}' is not supported in schema migration v1",
leftover.name
),
});
}
}
fn plan_properties(
type_kind: SchemaTypeKind,
type_name: &str,
accepted: &[PropertyIR],
desired: &[PropertyIR],
steps: &mut Vec<SchemaMigrationStep>,
) -> HashMap<String, String> {
let accepted_by_name = accepted
.iter()
.map(|property| (property.name.as_str(), property))
.collect::<HashMap<_, _>>();
let mut consumed = HashSet::new();
let mut renames = HashMap::new();
for property in desired {
let rename_from = rename_from_value(&property.annotations);
let matched = accepted_by_name
.get(property.name.as_str())
.copied()
.or_else(|| {
rename_from.and_then(|from| {
accepted_by_name
.get(from)
.copied()
.filter(|candidate| candidate.name != property.name)
})
});
let Some(existing) = matched else {
if let Some(from) = rename_from {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!(
"{}:{}.{}",
schema_type_kind_key(type_kind),
type_name,
property.name
),
reason: format!(
"property '{}.{}' declares @rename_from(\"{}\") but no accepted property with that name exists",
type_name, property.name, from
),
});
} else if property.prop_type.nullable {
steps.push(SchemaMigrationStep::AddProperty {
type_kind,
type_name: type_name.to_string(),
property_name: property.name.clone(),
property_type: property.prop_type.clone(),
});
} else {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!(
"{}:{}.{}",
schema_type_kind_key(type_kind),
type_name,
property.name
),
reason: format!(
"adding required property '{}.{}' requires a backfill and is not supported in schema migration v1",
type_name, property.name
),
});
}
continue;
};
consumed.insert(existing.name.clone());
if existing.name != property.name {
renames.insert(existing.name.clone(), property.name.clone());
steps.push(SchemaMigrationStep::RenameProperty {
type_kind,
type_name: type_name.to_string(),
from: existing.name.clone(),
to: property.name.clone(),
});
}
if existing.prop_type != property.prop_type {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!(
"{}:{}.{}",
schema_type_kind_key(type_kind),
type_name,
property.name
),
reason: format!(
"changing property type for '{}.{}' is not supported in schema migration v1",
type_name, property.name
),
});
}
plan_property_metadata(
type_kind,
type_name,
&property.name,
&existing.annotations,
&property.annotations,
steps,
);
}
for leftover in accepted
.iter()
.filter(|property| !consumed.contains(&property.name))
{
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!(
"{}:{}.{}",
schema_type_kind_key(type_kind),
type_name,
leftover.name
),
reason: format!(
"removing property '{}.{}' is not supported in schema migration v1",
type_name, leftover.name
),
});
}
renames
}
fn plan_constraints(
type_kind: SchemaTypeKind,
type_name: &str,
accepted: &[Constraint],
desired: &[Constraint],
property_renames: &HashMap<String, String>,
steps: &mut Vec<SchemaMigrationStep>,
) {
let accepted = accepted
.iter()
.cloned()
.map(|constraint| rename_constraint_properties(constraint, property_renames))
.collect::<Vec<_>>();
let desired_map = desired
.iter()
.cloned()
.map(|constraint| (constraint_key(&constraint), constraint))
.collect::<BTreeMap<_, _>>();
let accepted_map = accepted
.into_iter()
.map(|constraint| (constraint_key(&constraint), constraint))
.collect::<BTreeMap<_, _>>();
let removed = accepted_map
.keys()
.filter(|key| !desired_map.contains_key(*key))
.cloned()
.collect::<Vec<_>>();
if !removed.is_empty() {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("{}:{}", schema_type_kind_key(type_kind), type_name),
reason: format!(
"removing constraints from '{}' is not supported in schema migration v1",
type_name
),
});
}
for (key, constraint) in desired_map {
if accepted_map.contains_key(&key) {
continue;
}
match constraint {
Constraint::Index(_) => steps.push(SchemaMigrationStep::AddConstraint {
type_kind,
type_name: type_name.to_string(),
constraint,
}),
_ => steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("{}:{}", schema_type_kind_key(type_kind), type_name),
reason: format!(
"adding constraint '{}' to '{}' is not supported in schema migration v1",
key, type_name
),
}),
}
}
}
fn plan_type_metadata(
type_kind: SchemaTypeKind,
name: &str,
accepted: &[Annotation],
desired: &[Annotation],
steps: &mut Vec<SchemaMigrationStep>,
) {
match annotation_change_kind(accepted, desired) {
AnnotationChangeKind::None => {}
AnnotationChangeKind::MetadataOnly(metadata) => {
steps.push(SchemaMigrationStep::UpdateTypeMetadata {
type_kind,
name: name.to_string(),
annotations: metadata,
});
}
AnnotationChangeKind::Unsupported(reason) => {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!("{}:{}", schema_type_kind_key(type_kind), name),
reason,
});
}
}
}
fn plan_property_metadata(
type_kind: SchemaTypeKind,
type_name: &str,
property_name: &str,
accepted: &[Annotation],
desired: &[Annotation],
steps: &mut Vec<SchemaMigrationStep>,
) {
match annotation_change_kind(accepted, desired) {
AnnotationChangeKind::None => {}
AnnotationChangeKind::MetadataOnly(metadata) => {
steps.push(SchemaMigrationStep::UpdatePropertyMetadata {
type_kind,
type_name: type_name.to_string(),
property_name: property_name.to_string(),
annotations: metadata,
});
}
AnnotationChangeKind::Unsupported(reason) => {
steps.push(SchemaMigrationStep::UnsupportedChange {
entity: format!(
"{}:{}.{}",
schema_type_kind_key(type_kind),
type_name,
property_name
),
reason,
});
}
}
}
enum AnnotationChangeKind {
None,
MetadataOnly(Vec<Annotation>),
Unsupported(String),
}
fn annotation_change_kind(accepted: &[Annotation], desired: &[Annotation]) -> AnnotationChangeKind {
let accepted_non_metadata = strip_metadata_annotations(accepted);
let desired_non_metadata = strip_metadata_annotations(desired);
if accepted_non_metadata != desired_non_metadata {
return AnnotationChangeKind::Unsupported(
"changing annotations beyond @description/@instruction is not supported in schema migration v1"
.to_string(),
);
}
let accepted_metadata = metadata_annotations(accepted);
let desired_metadata = metadata_annotations(desired);
if accepted_metadata == desired_metadata {
AnnotationChangeKind::None
} else {
AnnotationChangeKind::MetadataOnly(desired_metadata)
}
}
fn strip_metadata_annotations(annotations: &[Annotation]) -> Vec<Annotation> {
annotations
.iter()
.filter(|annotation| {
!matches!(
annotation.name.as_str(),
"description" | "instruction" | "rename_from" | "key" | "unique" | "index"
)
})
.cloned()
.collect()
}
fn metadata_annotations(annotations: &[Annotation]) -> Vec<Annotation> {
annotations
.iter()
.filter(|annotation| matches!(annotation.name.as_str(), "description" | "instruction"))
.cloned()
.collect()
}
fn normalize_strings(values: &[String], renames: &HashMap<String, String>) -> BTreeSet<String> {
values
.iter()
.map(|value| normalize_type_ref(value, renames))
.collect()
}
fn normalize_type_ref(value: &str, renames: &HashMap<String, String>) -> String {
renames
.get(value)
.cloned()
.unwrap_or_else(|| value.to_string())
}
fn rename_constraint_properties(
constraint: Constraint,
property_renames: &HashMap<String, String>,
) -> Constraint {
match constraint {
Constraint::Key(columns) => {
Constraint::Key(rename_constraint_columns(columns, property_renames))
}
Constraint::Unique(columns) => {
Constraint::Unique(rename_constraint_columns(columns, property_renames))
}
Constraint::Index(columns) => {
Constraint::Index(rename_constraint_columns(columns, property_renames))
}
Constraint::Range { property, min, max } => Constraint::Range {
property: normalize_property_ref(&property, property_renames),
min,
max,
},
Constraint::Check { property, pattern } => Constraint::Check {
property: normalize_property_ref(&property, property_renames),
pattern,
},
}
}
fn rename_constraint_columns(
columns: Vec<String>,
property_renames: &HashMap<String, String>,
) -> Vec<String> {
let mut columns = columns
.into_iter()
.map(|column| normalize_property_ref(&column, property_renames))
.collect::<Vec<_>>();
columns.sort();
columns
}
fn normalize_property_ref(value: &str, renames: &HashMap<String, String>) -> String {
renames
.get(value)
.cloned()
.unwrap_or_else(|| value.to_string())
}
fn constraint_key(constraint: &Constraint) -> String {
match constraint {
Constraint::Key(columns) => format!("key:{}", columns.join(",")),
Constraint::Unique(columns) => format!("unique:{}", columns.join(",")),
Constraint::Index(columns) => format!("index:{}", columns.join(",")),
Constraint::Range { property, min, max } => {
format!("range:{}:{:?}:{:?}", property, min, max)
}
Constraint::Check { property, pattern } => format!("check:{}:{}", property, pattern),
}
}
fn rename_from_value(annotations: &[Annotation]) -> Option<&str> {
annotations
.iter()
.find(|annotation| annotation.name == "rename_from")
.and_then(|annotation| annotation.value.as_deref())
}
fn schema_type_kind_key(kind: SchemaTypeKind) -> &'static str {
match kind {
SchemaTypeKind::Interface => "interface",
SchemaTypeKind::Node => "node",
SchemaTypeKind::Edge => "edge",
}
}
#[cfg(test)]
mod tests {
use crate::catalog::schema_ir::build_schema_ir;
use crate::schema::parser::parse_schema;
use super::SchemaMigrationStep::{
AddConstraint, AddProperty, RenameProperty, RenameType, UnsupportedChange,
UpdateTypeMetadata,
};
use super::*;
#[test]
fn plan_supports_additive_nullable_property_and_index() {
let accepted = build_schema_ir(
&parse_schema(
r#"
node Person {
name: String @key
age: I32?
}
"#,
)
.unwrap(),
)
.unwrap();
let desired = build_schema_ir(
&parse_schema(
r#"
node Person {
name: String @key
age: I32? @index
nickname: String?
}
"#,
)
.unwrap(),
)
.unwrap();
let plan = plan_schema_migration(&accepted, &desired).unwrap();
assert!(plan.supported);
assert!(plan.steps.contains(&AddProperty {
type_kind: SchemaTypeKind::Node,
type_name: "Person".to_string(),
property_name: "nickname".to_string(),
property_type: PropType::scalar(crate::types::ScalarType::String, true),
}));
assert!(plan.steps.contains(&AddConstraint {
type_kind: SchemaTypeKind::Node,
type_name: "Person".to_string(),
constraint: Constraint::Index(vec!["age".to_string()]),
}));
}
#[test]
fn plan_supports_explicit_type_and_property_rename() {
let accepted = build_schema_ir(
&parse_schema(
r#"
node User {
name: String @key
}
"#,
)
.unwrap(),
)
.unwrap();
let desired = build_schema_ir(
&parse_schema(
r#"
node Account @rename_from("User") {
full_name: String @key @rename_from("name")
}
"#,
)
.unwrap(),
)
.unwrap();
let plan = plan_schema_migration(&accepted, &desired).unwrap();
assert!(plan.supported);
assert!(plan.steps.contains(&RenameType {
type_kind: SchemaTypeKind::Node,
from: "User".to_string(),
to: "Account".to_string(),
}));
assert!(plan.steps.contains(&RenameProperty {
type_kind: SchemaTypeKind::Node,
type_name: "Account".to_string(),
from: "name".to_string(),
to: "full_name".to_string(),
}));
}
#[test]
fn plan_rejects_required_property_addition() {
let accepted = build_schema_ir(
&parse_schema(
r#"
node Person {
name: String @key
}
"#,
)
.unwrap(),
)
.unwrap();
let desired = build_schema_ir(
&parse_schema(
r#"
node Person {
name: String @key
age: I32
}
"#,
)
.unwrap(),
)
.unwrap();
let plan = plan_schema_migration(&accepted, &desired).unwrap();
assert!(!plan.supported);
assert!(plan.steps.iter().any(|step| matches!(
step,
UnsupportedChange { entity, reason }
if entity.contains("Person.age")
&& reason.contains("adding required property")
)));
}
#[test]
fn plan_supports_metadata_only_annotation_changes() {
let accepted = build_schema_ir(
&parse_schema(
r#"
node Person @description("old") {
name: String @key
}
"#,
)
.unwrap(),
)
.unwrap();
let desired = build_schema_ir(
&parse_schema(
r#"
node Person @description("new") {
name: String @key
}
"#,
)
.unwrap(),
)
.unwrap();
let plan = plan_schema_migration(&accepted, &desired).unwrap();
assert!(plan.supported);
assert!(plan.steps.contains(&UpdateTypeMetadata {
type_kind: SchemaTypeKind::Node,
name: "Person".to_string(),
annotations: vec![Annotation {
name: "description".to_string(),
value: Some("new".to_string()),
}],
}));
}
}

View file

@ -0,0 +1,379 @@
#![allow(dead_code)]
use std::time::Duration;
use reqwest::Client;
use serde::Deserialize;
use tokio::time::sleep;
use crate::error::{NanoError, Result};
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
#[derive(Clone)]
enum EmbeddingTransport {
Mock,
OpenAi {
api_key: String,
base_url: String,
http: Client,
},
}
#[derive(Clone)]
pub(crate) struct EmbeddingClient {
model: String,
retry_attempts: usize,
retry_backoff_ms: u64,
transport: EmbeddingTransport,
}
struct EmbedCallError {
message: String,
retryable: bool,
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
data: Vec<OpenAiEmbeddingDatum>,
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingDatum {
index: usize,
embedding: Vec<f32>,
}
#[derive(Debug, Deserialize)]
struct OpenAiErrorEnvelope {
error: OpenAiErrorBody,
}
#[derive(Debug, Deserialize)]
struct OpenAiErrorBody {
message: String,
}
impl EmbeddingClient {
pub(crate) fn from_env() -> Result<Self> {
let model = std::env::var("NANOGRAPH_EMBED_MODEL")
.ok()
.map(|v| v.trim().to_string())
.filter(|v| !v.is_empty())
.unwrap_or_else(|| DEFAULT_EMBED_MODEL.to_string());
let retry_attempts =
parse_env_usize("NANOGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
let retry_backoff_ms =
parse_env_u64("NANOGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
if env_flag("NANOGRAPH_EMBEDDINGS_MOCK") {
return Ok(Self {
model,
retry_attempts,
retry_backoff_ms,
transport: EmbeddingTransport::Mock,
});
}
let api_key = std::env::var("OPENAI_API_KEY")
.ok()
.map(|v| v.trim().to_string())
.filter(|v| !v.is_empty())
.ok_or_else(|| {
NanoError::Execution(
"OPENAI_API_KEY is required when an embedding call is needed".to_string(),
)
})?;
let base_url = std::env::var("OPENAI_BASE_URL")
.ok()
.map(|v| v.trim_end_matches('/').to_string())
.filter(|v| !v.is_empty())
.unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
let timeout_ms = parse_env_u64("NANOGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
let http = Client::builder()
.timeout(Duration::from_millis(timeout_ms))
.build()
.map_err(|e| {
NanoError::Execution(format!("failed to initialize HTTP client: {}", e))
})?;
Ok(Self {
model,
retry_attempts,
retry_backoff_ms,
transport: EmbeddingTransport::OpenAi {
api_key,
base_url,
http,
},
})
}
#[cfg(test)]
pub(crate) fn mock_for_tests() -> Self {
Self {
model: DEFAULT_EMBED_MODEL.to_string(),
retry_attempts: DEFAULT_RETRY_ATTEMPTS,
retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
transport: EmbeddingTransport::Mock,
}
}
pub(crate) fn model(&self) -> &str {
&self.model
}
pub(crate) async fn embed_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
let mut vectors = self.embed_texts(&[input.to_string()], expected_dim).await?;
vectors.pop().ok_or_else(|| {
NanoError::Execution("embedding provider returned no vector".to_string())
})
}
pub(crate) async fn embed_texts(
&self,
inputs: &[String],
expected_dim: usize,
) -> Result<Vec<Vec<f32>>> {
if expected_dim == 0 {
return Err(NanoError::Execution(
"embedding dimension must be greater than zero".to_string(),
));
}
if inputs.is_empty() {
return Ok(Vec::new());
}
match &self.transport {
EmbeddingTransport::Mock => Ok(inputs
.iter()
.map(|input| mock_embedding(input, expected_dim))
.collect()),
EmbeddingTransport::OpenAi { .. } => {
self.embed_texts_openai_with_retry(inputs, expected_dim)
.await
}
}
}
async fn embed_texts_openai_with_retry(
&self,
inputs: &[String],
expected_dim: usize,
) -> Result<Vec<Vec<f32>>> {
let max_attempt = self.retry_attempts.max(1);
let mut attempt = 0usize;
loop {
attempt += 1;
match self.embed_texts_openai_once(inputs, expected_dim).await {
Ok(vectors) => return Ok(vectors),
Err(err) => {
if !err.retryable || attempt >= max_attempt {
return Err(NanoError::Execution(err.message));
}
let shift = (attempt - 1).min(10) as u32;
let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
sleep(Duration::from_millis(delay)).await;
}
}
}
}
async fn embed_texts_openai_once(
&self,
inputs: &[String],
expected_dim: usize,
) -> std::result::Result<Vec<Vec<f32>>, EmbedCallError> {
let (api_key, base_url, http) = match &self.transport {
EmbeddingTransport::OpenAi {
api_key,
base_url,
http,
} => (api_key, base_url, http),
EmbeddingTransport::Mock => unreachable!("mock transport should not call OpenAI"),
};
let request = serde_json::json!({
"model": self.model,
"input": inputs,
"dimensions": expected_dim,
});
let url = format!("{}/embeddings", base_url);
let response = http
.post(&url)
.bearer_auth(api_key)
.json(&request)
.send()
.await;
let response = match response {
Ok(resp) => resp,
Err(err) => {
let retryable = err.is_timeout() || err.is_connect() || err.is_request();
return Err(EmbedCallError {
message: format!("embedding request failed: {}", err),
retryable,
});
}
};
let status = response.status();
let body = match response.text().await {
Ok(body) => body,
Err(err) => {
return Err(EmbedCallError {
message: format!(
"embedding response read failed (status {}): {}",
status, err
),
retryable: status.is_server_error() || status.as_u16() == 429,
});
}
};
if !status.is_success() {
let message = parse_openai_error_message(&body).unwrap_or_else(|| body.clone());
return Err(EmbedCallError {
message: format!(
"embedding request failed with status {}: {}",
status, message
),
retryable: status.is_server_error() || status.as_u16() == 429,
});
}
let mut parsed: OpenAiEmbeddingResponse =
serde_json::from_str(&body).map_err(|err| EmbedCallError {
message: format!("embedding response decode failed: {}", err),
retryable: false,
})?;
if parsed.data.len() != inputs.len() {
return Err(EmbedCallError {
message: format!(
"embedding response size mismatch: expected {}, got {}",
inputs.len(),
parsed.data.len()
),
retryable: false,
});
}
parsed.data.sort_by_key(|item| item.index);
let mut vectors = Vec::with_capacity(parsed.data.len());
for (idx, item) in parsed.data.into_iter().enumerate() {
if item.index != idx {
return Err(EmbedCallError {
message: format!(
"embedding response index mismatch at position {}: got {}",
idx, item.index
),
retryable: false,
});
}
if item.embedding.len() != expected_dim {
return Err(EmbedCallError {
message: format!(
"embedding dimension mismatch: expected {}, got {}",
expected_dim,
item.embedding.len()
),
retryable: false,
});
}
vectors.push(item.embedding);
}
Ok(vectors)
}
}
fn parse_openai_error_message(body: &str) -> Option<String> {
serde_json::from_str::<OpenAiErrorEnvelope>(body)
.ok()
.map(|e| e.error.message)
.filter(|msg| !msg.trim().is_empty())
}
fn parse_env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|v| *v > 0)
.unwrap_or(default)
}
fn parse_env_u64(name: &str, default: u64) -> u64 {
std::env::var(name)
.ok()
.and_then(|v| v.parse::<u64>().ok())
.filter(|v| *v > 0)
.unwrap_or(default)
}
fn env_flag(name: &str) -> bool {
std::env::var(name)
.ok()
.map(|v| {
let s = v.trim().to_ascii_lowercase();
s == "1" || s == "true" || s == "yes" || s == "on"
})
.unwrap_or(false)
}
fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
let mut seed = fnv1a64(input.as_bytes());
let mut out = Vec::with_capacity(dim);
for _ in 0..dim {
seed = xorshift64(seed);
let ratio = (seed as f64 / u64::MAX as f64) as f32;
out.push((ratio * 2.0) - 1.0);
}
let norm = out
.iter()
.map(|v| (*v as f64) * (*v as f64))
.sum::<f64>()
.sqrt() as f32;
if norm > f32::EPSILON {
for value in &mut out {
*value /= norm;
}
}
out
}
fn fnv1a64(bytes: &[u8]) -> u64 {
let mut hash = 14695981039346656037u64;
for byte in bytes {
hash ^= *byte as u64;
hash = hash.wrapping_mul(1099511628211u64);
}
hash
}
fn xorshift64(mut x: u64) -> u64 {
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
x
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn mock_embeddings_are_deterministic() {
let client = EmbeddingClient::mock_for_tests();
let a = client.embed_text("alpha", 8).await.unwrap();
let b = client.embed_text("alpha", 8).await.unwrap();
let c = client.embed_text("beta", 8).await.unwrap();
assert_eq!(a, b);
assert_ne!(a, c);
assert_eq!(a.len(), 8);
}
}

View file

@ -0,0 +1,146 @@
use thiserror::Error;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SourceSpan {
pub start: usize,
pub end: usize,
}
impl SourceSpan {
pub fn new(start: usize, end: usize) -> Self {
Self { start, end }
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseDiagnostic {
pub message: String,
pub span: Option<SourceSpan>,
}
impl ParseDiagnostic {
pub fn new(message: String, span: Option<SourceSpan>) -> Self {
Self { message, span }
}
}
impl std::fmt::Display for ParseDiagnostic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for ParseDiagnostic {}
pub fn render_span(span: SourceSpan) -> SourceSpan {
SourceSpan {
start: span.start,
end: span.end.max(span.start.saturating_add(1)),
}
}
pub fn decode_string_literal(raw: &str) -> Result<String> {
let inner = raw
.strip_prefix('"')
.and_then(|inner| inner.strip_suffix('"'))
.unwrap_or(raw);
let mut decoded = String::with_capacity(inner.len());
let mut chars = inner.chars();
while let Some(ch) = chars.next() {
if ch != '\\' {
decoded.push(ch);
continue;
}
let escaped = chars
.next()
.ok_or_else(|| NanoError::Parse("unterminated escape sequence".to_string()))?;
match escaped {
'"' => decoded.push('"'),
'\\' => decoded.push('\\'),
'n' => decoded.push('\n'),
'r' => decoded.push('\r'),
't' => decoded.push('\t'),
other => {
return Err(NanoError::Parse(format!(
"unsupported escape sequence: \\{}",
other
)));
}
}
}
Ok(decoded)
}
#[derive(Debug, Error)]
pub enum NanoError {
#[error("parse error: {0}")]
Parse(String),
#[error("catalog error: {0}")]
Catalog(String),
#[error("type error: {0}")]
Type(String),
#[error("storage error: {0}")]
Storage(String),
#[error(
"@unique constraint violation on {type_name}.{property}: duplicate value '{value}' at rows {first_row} and {second_row}"
)]
UniqueConstraint {
type_name: String,
property: String,
value: String,
first_row: usize,
second_row: usize,
},
#[error("plan error: {0}")]
Plan(String),
#[error("execution error: {0}")]
Execution(String),
#[error(transparent)]
Arrow(#[from] arrow_schema::ArrowError),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("lance error: {0}")]
Lance(String),
#[error("manifest error: {0}")]
Manifest(String),
}
pub type Result<T> = std::result::Result<T, NanoError>;
#[cfg(test)]
mod tests {
use super::{SourceSpan, decode_string_literal, render_span};
#[test]
fn source_span_preserves_zero_width() {
let span = SourceSpan::new(7, 7);
assert_eq!(span.start, 7);
assert_eq!(span.end, 7);
}
#[test]
fn render_span_widens_zero_width_for_diagnostics() {
let rendered = render_span(SourceSpan::new(7, 7));
assert_eq!(rendered.start, 7);
assert_eq!(rendered.end, 8);
}
#[test]
fn decode_string_literal_supports_common_escapes() {
let decoded = decode_string_literal("\"a\\n\\r\\t\\\\\\\"b\"").unwrap();
assert_eq!(decoded, "a\n\r\t\\\"b");
}
}

View file

@ -0,0 +1,657 @@
use std::collections::HashSet;
use crate::catalog::Catalog;
use crate::error::Result;
use crate::query::ast::*;
use crate::query::typecheck::TypeContext;
use crate::types::Direction;
use super::*;
pub fn lower_query(
catalog: &Catalog,
query: &QueryDecl,
type_ctx: &TypeContext,
) -> Result<QueryIR> {
if query.mutation.is_some() {
return Err(crate::error::NanoError::Plan(
"cannot lower mutation query with read-query lowerer".to_string(),
));
}
let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
let mut pipeline = Vec::new();
let mut bound_vars = HashSet::new();
lower_clauses(
catalog,
&query.match_clause,
type_ctx,
&mut pipeline,
&mut bound_vars,
&param_names,
)?;
let return_exprs: Vec<IRProjection> = query
.return_clause
.iter()
.map(|p| IRProjection {
expr: lower_expr(&p.expr, &param_names),
alias: p.alias.clone(),
})
.collect();
let order_by: Vec<IROrdering> = query
.order_clause
.iter()
.map(|o| IROrdering {
expr: lower_expr(&o.expr, &param_names),
descending: o.descending,
})
.collect();
Ok(QueryIR {
name: query.name.clone(),
params: query.params.clone(),
pipeline,
return_exprs,
order_by,
limit: query.limit,
})
}
pub fn lower_mutation_query(query: &QueryDecl) -> Result<MutationIR> {
let mutation = query.mutation.as_ref().ok_or_else(|| {
crate::error::NanoError::Plan("query does not contain a mutation body".to_string())
})?;
let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
let op = match mutation {
Mutation::Insert(insert) => MutationOpIR::Insert {
type_name: insert.type_name.clone(),
assignments: insert
.assignments
.iter()
.map(|a| IRAssignment {
property: a.property.clone(),
value: lower_match_value(&a.value, &param_names),
})
.collect(),
},
Mutation::Update(update) => MutationOpIR::Update {
type_name: update.type_name.clone(),
assignments: update
.assignments
.iter()
.map(|a| IRAssignment {
property: a.property.clone(),
value: lower_match_value(&a.value, &param_names),
})
.collect(),
predicate: IRMutationPredicate {
property: update.predicate.property.clone(),
op: update.predicate.op,
value: lower_match_value(&update.predicate.value, &param_names),
},
},
Mutation::Delete(delete) => MutationOpIR::Delete {
type_name: delete.type_name.clone(),
predicate: IRMutationPredicate {
property: delete.predicate.property.clone(),
op: delete.predicate.op,
value: lower_match_value(&delete.predicate.value, &param_names),
},
},
};
Ok(MutationIR {
name: query.name.clone(),
params: query.params.clone(),
op,
})
}
fn lower_clauses(
catalog: &Catalog,
clauses: &[Clause],
type_ctx: &TypeContext,
pipeline: &mut Vec<IROp>,
bound_vars: &mut HashSet<String>,
param_names: &HashSet<String>,
) -> Result<()> {
// Separate clause types for ordering: bindings first, then traversals, then filters
let mut bindings = Vec::new();
let mut traversals = Vec::new();
let mut filters = Vec::new();
let mut negations = Vec::new();
for clause in clauses {
match clause {
Clause::Binding(b) => bindings.push(b),
Clause::Traversal(t) => traversals.push(t),
Clause::Filter(f) => filters.push(f),
Clause::Negation(inner) => negations.push(inner),
}
}
// Lower bindings into NodeScan ops
for binding in &bindings {
let node_type = catalog
.node_types
.get(&binding.type_name)
.expect("binding type was validated during typecheck");
// Collect inline filters from prop matches
let mut scan_filters = Vec::new();
for pm in &binding.prop_matches {
let prop = node_type
.properties
.get(&pm.prop_name)
.expect("binding property was validated during typecheck");
let op = if prop.list {
CompOp::Contains
} else {
CompOp::Eq
};
match &pm.value {
MatchValue::Literal(lit) => {
scan_filters.push(IRFilter {
left: IRExpr::PropAccess {
variable: binding.variable.clone(),
property: pm.prop_name.clone(),
},
op,
right: IRExpr::Literal(lit.clone()),
});
}
MatchValue::Now => {
scan_filters.push(IRFilter {
left: IRExpr::PropAccess {
variable: binding.variable.clone(),
property: pm.prop_name.clone(),
},
op,
right: IRExpr::Param(NOW_PARAM_NAME.to_string()),
});
}
MatchValue::Variable(v) => {
let right = if param_names.contains(v) {
IRExpr::Param(v.clone())
} else {
IRExpr::Variable(v.clone())
};
scan_filters.push(IRFilter {
left: IRExpr::PropAccess {
variable: binding.variable.clone(),
property: pm.prop_name.clone(),
},
op,
right,
});
}
}
}
pipeline.push(IROp::NodeScan {
variable: binding.variable.clone(),
type_name: binding.type_name.clone(),
filters: scan_filters,
});
bound_vars.insert(binding.variable.clone());
}
// Lower traversals into Expand ops
// Handle "cycle closing" — if both src and dst are already bound, use a filter
for traversal in &traversals {
let edge = catalog
.lookup_edge_by_name(&traversal.edge_name)
.ok_or_else(|| {
crate::error::NanoError::Plan(format!(
"lowering traversal referenced missing edge '{}' after typecheck",
traversal.edge_name
))
})?;
// Determine direction from type context
let direction = type_ctx
.traversals
.iter()
.find(|rt| {
rt.src == traversal.src && rt.dst == traversal.dst && rt.edge_type == edge.name
})
.map(|rt| rt.direction)
.unwrap_or(Direction::Out);
let dst_type = match direction {
Direction::Out => edge.to_type.clone(),
Direction::In => edge.from_type.clone(),
};
if bound_vars.contains(&traversal.src) && bound_vars.contains(&traversal.dst) {
// Cycle closing: emit expand to a temp var, then filter temp.id = dst.id
let temp_var = format!("__temp_{}", traversal.dst);
pipeline.push(IROp::Expand {
src_var: traversal.src.clone(),
dst_var: temp_var.clone(),
edge_type: edge.name.clone(),
direction,
dst_type,
min_hops: traversal.min_hops,
max_hops: traversal.max_hops,
});
pipeline.push(IROp::Filter(IRFilter {
left: IRExpr::PropAccess {
variable: temp_var,
property: "id".to_string(),
},
op: CompOp::Eq,
right: IRExpr::PropAccess {
variable: traversal.dst.clone(),
property: "id".to_string(),
},
}));
} else if !bound_vars.contains(&traversal.src) && bound_vars.contains(&traversal.dst) {
// Reverse expand: dst is bound, src is not.
// Swap direction and expand from dst to discover src.
let reverse_dir = match direction {
Direction::Out => Direction::In,
Direction::In => Direction::Out,
};
let src_type = match direction {
Direction::Out => edge.from_type.clone(),
Direction::In => edge.to_type.clone(),
};
pipeline.push(IROp::Expand {
src_var: traversal.dst.clone(),
dst_var: traversal.src.clone(),
edge_type: edge.name.clone(),
direction: reverse_dir,
dst_type: src_type,
min_hops: traversal.min_hops,
max_hops: traversal.max_hops,
});
if traversal.src != "_" {
bound_vars.insert(traversal.src.clone());
}
} else {
pipeline.push(IROp::Expand {
src_var: traversal.src.clone(),
dst_var: traversal.dst.clone(),
edge_type: edge.name.clone(),
direction,
dst_type,
min_hops: traversal.min_hops,
max_hops: traversal.max_hops,
});
if traversal.dst != "_" {
bound_vars.insert(traversal.dst.clone());
}
}
}
// Lower explicit filters
for filter in &filters {
pipeline.push(IROp::Filter(IRFilter {
left: lower_expr(&filter.left, param_names),
op: filter.op,
right: lower_expr(&filter.right, param_names),
}));
}
// Lower negations into AntiJoin ops
for neg_clauses in &negations {
// Find outer-bound variable referenced in the negation
let outer_var = find_outer_var(neg_clauses, bound_vars);
let mut inner_pipeline = Vec::new();
let mut inner_bound = bound_vars.clone();
lower_clauses(
catalog,
neg_clauses,
type_ctx,
&mut inner_pipeline,
&mut inner_bound,
param_names,
)?;
pipeline.push(IROp::AntiJoin {
outer_var: outer_var.unwrap_or_default(),
inner: inner_pipeline,
});
}
Ok(())
}
fn find_outer_var(clauses: &[Clause], outer_bound: &HashSet<String>) -> Option<String> {
for clause in clauses {
match clause {
Clause::Traversal(t) => {
if outer_bound.contains(&t.src) {
return Some(t.src.clone());
}
if outer_bound.contains(&t.dst) {
return Some(t.dst.clone());
}
}
Clause::Filter(f) => {
if let Some(v) = expr_var(&f.left)
&& outer_bound.contains(&v)
{
return Some(v);
}
if let Some(v) = expr_var(&f.right)
&& outer_bound.contains(&v)
{
return Some(v);
}
}
Clause::Binding(b) => {
if outer_bound.contains(&b.variable) {
return Some(b.variable.clone());
}
}
_ => {}
}
}
None
}
fn expr_var(expr: &Expr) -> Option<String> {
match expr {
Expr::Now => None,
Expr::PropAccess { variable, .. } => Some(variable.clone()),
Expr::Variable(v) => Some(v.clone()),
Expr::Nearest { variable, .. } => Some(variable.clone()),
Expr::Search { field, query } => expr_var(field).or_else(|| expr_var(query)),
Expr::Fuzzy {
field,
query,
max_edits,
} => expr_var(field)
.or_else(|| expr_var(query))
.or_else(|| max_edits.as_deref().and_then(expr_var)),
Expr::MatchText { field, query } => expr_var(field).or_else(|| expr_var(query)),
Expr::Bm25 { field, query } => expr_var(field).or_else(|| expr_var(query)),
Expr::Rrf {
primary,
secondary,
k,
} => expr_var(primary)
.or_else(|| expr_var(secondary))
.or_else(|| k.as_deref().and_then(expr_var)),
Expr::Aggregate { arg, .. } => expr_var(arg),
_ => None,
}
}
fn lower_expr(expr: &Expr, param_names: &HashSet<String>) -> IRExpr {
match expr {
Expr::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
Expr::PropAccess { variable, property } => IRExpr::PropAccess {
variable: variable.clone(),
property: property.clone(),
},
Expr::Nearest {
variable,
property,
query,
} => IRExpr::Nearest {
variable: variable.clone(),
property: property.clone(),
query: Box::new(lower_expr(query, param_names)),
},
Expr::Search { field, query } => IRExpr::Search {
field: Box::new(lower_expr(field, param_names)),
query: Box::new(lower_expr(query, param_names)),
},
Expr::Fuzzy {
field,
query,
max_edits,
} => IRExpr::Fuzzy {
field: Box::new(lower_expr(field, param_names)),
query: Box::new(lower_expr(query, param_names)),
max_edits: max_edits
.as_ref()
.map(|expr| Box::new(lower_expr(expr, param_names))),
},
Expr::MatchText { field, query } => IRExpr::MatchText {
field: Box::new(lower_expr(field, param_names)),
query: Box::new(lower_expr(query, param_names)),
},
Expr::Bm25 { field, query } => IRExpr::Bm25 {
field: Box::new(lower_expr(field, param_names)),
query: Box::new(lower_expr(query, param_names)),
},
Expr::Rrf {
primary,
secondary,
k,
} => IRExpr::Rrf {
primary: Box::new(lower_expr(primary, param_names)),
secondary: Box::new(lower_expr(secondary, param_names)),
k: k.as_ref()
.map(|expr| Box::new(lower_expr(expr, param_names))),
},
Expr::Variable(v) => {
if param_names.contains(v) {
IRExpr::Param(v.clone())
} else {
IRExpr::Variable(v.clone())
}
}
Expr::Literal(l) => IRExpr::Literal(l.clone()),
Expr::Aggregate { func, arg } => IRExpr::Aggregate {
func: *func,
arg: Box::new(lower_expr(arg, param_names)),
},
Expr::AliasRef(name) => IRExpr::AliasRef(name.clone()),
}
}
fn lower_match_value(value: &MatchValue, param_names: &HashSet<String>) -> IRExpr {
match value {
MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
MatchValue::Literal(l) => IRExpr::Literal(l.clone()),
MatchValue::Variable(v) => {
if param_names.contains(v) {
IRExpr::Param(v.clone())
} else {
IRExpr::Variable(v.clone())
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::catalog::build_catalog;
use crate::query::parser::parse_query;
use crate::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl};
use crate::schema::parser::parse_schema;
fn setup() -> Catalog {
let schema = parse_schema(
r#"
node Person { name: String age: I32? }
node Company { name: String }
edge Knows: Person -> Person { since: Date? }
edge WorksAt: Person -> Company
"#,
)
.unwrap();
build_catalog(&schema).unwrap()
}
#[test]
fn test_lower_basic() {
let catalog = setup();
let qf = parse_query(
r#"
query q($name: String) {
match {
$p: Person { name: $name }
$p knows $f
}
return { $f.name, $f.age }
}
"#,
)
.unwrap();
let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap();
let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap();
assert_eq!(ir.pipeline.len(), 2); // NodeScan + Expand
assert_eq!(ir.return_exprs.len(), 2);
}
#[test]
fn test_lower_negation() {
let catalog = setup();
let qf = parse_query(
r#"
query q() {
match {
$p: Person
not { $p worksAt $_ }
}
return { $p.name }
}
"#,
)
.unwrap();
let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap();
let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap();
assert_eq!(ir.pipeline.len(), 2); // NodeScan + AntiJoin
assert!(matches!(&ir.pipeline[1], IROp::AntiJoin { .. }));
}
#[test]
fn test_lower_mutation_update() {
let catalog = setup();
let qf = parse_query(
r#"
query q($name: String, $age: I32) {
update Person set { age: $age } where name = $name
}
"#,
)
.unwrap();
let checked = typecheck_query_decl(&catalog, &qf.queries[0]).unwrap();
assert!(matches!(checked, CheckedQuery::Mutation(_)));
let ir = lower_mutation_query(&qf.queries[0]).unwrap();
match ir.op {
MutationOpIR::Update {
type_name,
assignments,
predicate,
} => {
assert_eq!(type_name, "Person");
assert_eq!(assignments.len(), 1);
assert_eq!(assignments[0].property, "age");
assert_eq!(predicate.property, "name");
}
_ => panic!("expected update mutation op"),
}
}
#[test]
fn test_lower_bounded_traversal() {
let catalog = setup();
let qf = parse_query(
r#"
query q() {
match {
$p: Person
$p knows{1,3} $f
}
return { $f.name }
}
"#,
)
.unwrap();
let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap();
let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap();
let expand = ir
.pipeline
.iter()
.find_map(|op| match op {
IROp::Expand {
min_hops, max_hops, ..
} => Some((*min_hops, *max_hops)),
_ => None,
})
.expect("expected expand op");
assert_eq!(expand.0, 1);
assert_eq!(expand.1, Some(3));
}
#[test]
fn test_lower_now_uses_reserved_runtime_param() {
let catalog = setup();
let qf = parse_query(
r#"
query stamp() {
match { $p: Person }
return { now() as ts }
}
"#,
)
.unwrap();
let tc = typecheck_query(&catalog, &qf.queries[0]).unwrap();
let ir = lower_query(&catalog, &qf.queries[0], &tc).unwrap();
assert!(matches!(
ir.return_exprs[0].expr,
IRExpr::Param(ref name) if name == NOW_PARAM_NAME
));
}
#[test]
fn test_lower_mutation_now_uses_reserved_runtime_param() {
let catalog = build_catalog(
&parse_schema(
r#"
node Event {
slug: String @key
updated_at: DateTime?
}
"#,
)
.unwrap(),
)
.unwrap();
let qf = parse_query(
r#"
query stamp() {
update Event set { updated_at: now() } where updated_at = now()
}
"#,
)
.unwrap();
let checked = typecheck_query_decl(&catalog, &qf.queries[0]).unwrap();
assert!(matches!(checked, CheckedQuery::Mutation(_)));
let ir = lower_mutation_query(&qf.queries[0]).unwrap();
match ir.op {
MutationOpIR::Update {
assignments,
predicate,
..
} => {
assert!(matches!(
assignments[0].value,
IRExpr::Param(ref name) if name == NOW_PARAM_NAME
));
assert!(matches!(
predicate.value,
IRExpr::Param(ref name) if name == NOW_PARAM_NAME
));
}
_ => panic!("expected update mutation op"),
}
}
}

View file

@ -0,0 +1,143 @@
pub(crate) mod lower;
use std::collections::HashMap;
use crate::query::ast::{AggFunc, CompOp, Literal, Param};
use crate::types::Direction;
#[derive(Debug, Clone)]
pub struct QueryIR {
pub name: String,
pub params: Vec<Param>,
pub pipeline: Vec<IROp>,
pub return_exprs: Vec<IRProjection>,
pub order_by: Vec<IROrdering>,
pub limit: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct MutationIR {
pub name: String,
pub params: Vec<Param>,
pub op: MutationOpIR,
}
#[derive(Debug, Clone)]
pub enum MutationOpIR {
Insert {
type_name: String,
assignments: Vec<IRAssignment>,
},
Update {
type_name: String,
assignments: Vec<IRAssignment>,
predicate: IRMutationPredicate,
},
Delete {
type_name: String,
predicate: IRMutationPredicate,
},
}
#[derive(Debug, Clone)]
pub struct IRAssignment {
pub property: String,
pub value: IRExpr,
}
#[derive(Debug, Clone)]
pub struct IRMutationPredicate {
pub property: String,
pub op: CompOp,
pub value: IRExpr,
}
/// Resolved runtime parameters: param name → literal value.
pub type ParamMap = HashMap<String, Literal>;
#[derive(Debug, Clone)]
pub enum IROp {
NodeScan {
variable: String,
type_name: String,
filters: Vec<IRFilter>,
},
Expand {
src_var: String,
dst_var: String,
edge_type: String,
direction: Direction,
dst_type: String,
min_hops: u32,
max_hops: Option<u32>,
},
Filter(IRFilter),
AntiJoin {
/// The outer variable whose id is used for the join key
outer_var: String,
/// The inner pipeline that produces rows to anti-join against
inner: Vec<IROp>,
},
}
#[derive(Debug, Clone)]
pub struct IRFilter {
pub left: IRExpr,
pub op: CompOp,
pub right: IRExpr,
}
#[derive(Debug, Clone)]
pub enum IRExpr {
PropAccess {
variable: String,
property: String,
},
Nearest {
variable: String,
property: String,
query: Box<IRExpr>,
},
Search {
field: Box<IRExpr>,
query: Box<IRExpr>,
},
Fuzzy {
field: Box<IRExpr>,
query: Box<IRExpr>,
max_edits: Option<Box<IRExpr>>,
},
MatchText {
field: Box<IRExpr>,
query: Box<IRExpr>,
},
Bm25 {
field: Box<IRExpr>,
query: Box<IRExpr>,
},
Rrf {
primary: Box<IRExpr>,
secondary: Box<IRExpr>,
k: Option<Box<IRExpr>>,
},
Variable(String),
Param(String),
Literal(Literal),
Aggregate {
func: AggFunc,
arg: Box<IRExpr>,
},
AliasRef(String),
}
#[derive(Debug, Clone)]
pub struct IRProjection {
pub expr: IRExpr,
pub alias: Option<String>,
}
#[derive(Debug, Clone)]
pub struct IROrdering {
pub expr: IRExpr,
pub descending: bool,
}

View file

@ -0,0 +1,352 @@
use arrow_array::{
Array, ArrayRef, BooleanArray, Date32Array, Date64Array, FixedSizeListArray, Float32Array,
Float64Array, Int32Array, Int64Array, ListArray, RecordBatch, StringArray, StructArray,
UInt32Array, UInt64Array,
};
use arrow_schema::DataType;
pub const JS_MAX_SAFE_INTEGER_I64: i64 = 9_007_199_254_740_991;
pub const JS_MAX_SAFE_INTEGER_U64: u64 = 9_007_199_254_740_991;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JsonIntegerMode {
JavaScript,
Native,
}
pub fn is_js_safe_integer_i64(value: i64) -> bool {
(-JS_MAX_SAFE_INTEGER_I64..=JS_MAX_SAFE_INTEGER_I64).contains(&value)
}
/// Convert Arrow RecordBatches into a Vec of JSON objects (one per row).
pub fn record_batches_to_json_rows(results: &[RecordBatch]) -> Vec<serde_json::Value> {
record_batches_to_json_rows_with_mode(results, JsonIntegerMode::JavaScript)
}
/// Convert Arrow RecordBatches into JSON rows without JS-safe integer coercion.
pub fn record_batches_to_rust_json_rows(results: &[RecordBatch]) -> Vec<serde_json::Value> {
record_batches_to_json_rows_with_mode(results, JsonIntegerMode::Native)
}
fn record_batches_to_json_rows_with_mode(
results: &[RecordBatch],
integer_mode: JsonIntegerMode,
) -> Vec<serde_json::Value> {
let total_rows = results.iter().map(RecordBatch::num_rows).sum();
let mut out = Vec::with_capacity(total_rows);
for batch in results {
let schema = batch.schema();
for row in 0..batch.num_rows() {
let mut map = serde_json::Map::new();
for (col_idx, field) in schema.fields().iter().enumerate() {
let col_arr = batch.column(col_idx);
map.insert(
field.name().clone(),
array_value_to_json_with_mode(col_arr, row, integer_mode),
);
}
out.push(serde_json::Value::Object(map));
}
}
out
}
/// Convert a single cell from an Arrow array to a serde_json::Value.
pub fn array_value_to_json(array: &ArrayRef, row: usize) -> serde_json::Value {
array_value_to_json_with_mode(array, row, JsonIntegerMode::JavaScript)
}
fn array_value_to_json_with_mode(
array: &ArrayRef,
row: usize,
integer_mode: JsonIntegerMode,
) -> serde_json::Value {
if array.is_null(row) {
return serde_json::Value::Null;
}
match array.data_type() {
DataType::Utf8 => array
.as_any()
.downcast_ref::<StringArray>()
.map(|a| serde_json::Value::String(a.value(row).to_string()))
.unwrap_or(serde_json::Value::Null),
DataType::Boolean => array
.as_any()
.downcast_ref::<BooleanArray>()
.map(|a| serde_json::Value::Bool(a.value(row)))
.unwrap_or(serde_json::Value::Null),
DataType::Int32 => array
.as_any()
.downcast_ref::<Int32Array>()
.map(|a| serde_json::Value::Number((a.value(row) as i64).into()))
.unwrap_or(serde_json::Value::Null),
DataType::Int64 => array
.as_any()
.downcast_ref::<Int64Array>()
.map(|a| {
let value = a.value(row);
match integer_mode {
JsonIntegerMode::JavaScript if !is_js_safe_integer_i64(value) => {
serde_json::Value::String(value.to_string())
}
JsonIntegerMode::JavaScript | JsonIntegerMode::Native => {
serde_json::Value::Number(value.into())
}
}
})
.unwrap_or(serde_json::Value::Null),
DataType::UInt32 => array
.as_any()
.downcast_ref::<UInt32Array>()
.map(|a| serde_json::Value::Number((a.value(row) as u64).into()))
.unwrap_or(serde_json::Value::Null),
DataType::UInt64 => array
.as_any()
.downcast_ref::<UInt64Array>()
.map(|a| {
let value = a.value(row);
match integer_mode {
JsonIntegerMode::JavaScript if value > JS_MAX_SAFE_INTEGER_U64 => {
serde_json::Value::String(value.to_string())
}
JsonIntegerMode::JavaScript | JsonIntegerMode::Native => {
serde_json::Value::Number(value.into())
}
}
})
.unwrap_or(serde_json::Value::Null),
DataType::Float32 => array
.as_any()
.downcast_ref::<Float32Array>()
.map(|a| json_float_value(a.value(row) as f64))
.unwrap_or(serde_json::Value::Null),
DataType::Float64 => array
.as_any()
.downcast_ref::<Float64Array>()
.map(|a| json_float_value(a.value(row)))
.unwrap_or(serde_json::Value::Null),
DataType::Date32 => array
.as_any()
.downcast_ref::<Date32Array>()
.map(|a| {
let days = a.value(row);
arrow_array::temporal_conversions::date32_to_datetime(days)
.map(|dt| serde_json::Value::String(dt.format("%Y-%m-%d").to_string()))
.unwrap_or_else(|| serde_json::Value::Number((days as i64).into()))
})
.unwrap_or(serde_json::Value::Null),
DataType::Date64 => array
.as_any()
.downcast_ref::<Date64Array>()
.map(|a| {
let ms = a.value(row);
arrow_array::temporal_conversions::date64_to_datetime(ms)
.map(|dt| {
serde_json::Value::String(dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string())
})
.unwrap_or_else(|| serde_json::Value::Number(ms.into()))
})
.unwrap_or(serde_json::Value::Null),
DataType::List(_) => array
.as_any()
.downcast_ref::<ListArray>()
.map(|a| {
let values = a.value(row);
serde_json::Value::Array(
(0..values.len())
.map(|idx| array_value_to_json_with_mode(&values, idx, integer_mode))
.collect(),
)
})
.unwrap_or(serde_json::Value::Null),
DataType::FixedSizeList(_, _) => array
.as_any()
.downcast_ref::<FixedSizeListArray>()
.map(|a| fixed_size_list_value_to_json(a, row, integer_mode))
.unwrap_or(serde_json::Value::Null),
DataType::Struct(_) => array
.as_any()
.downcast_ref::<StructArray>()
.map(|struct_arr| {
let mut obj = serde_json::Map::new();
for (i, field) in struct_arr.fields().iter().enumerate() {
let col = struct_arr.column(i);
obj.insert(
field.name().clone(),
array_value_to_json_with_mode(col, row, integer_mode),
);
}
serde_json::Value::Object(obj)
})
.unwrap_or(serde_json::Value::Null),
_ => {
let display =
arrow_cast::display::array_value_to_string(array, row).unwrap_or_default();
serde_json::Value::String(display)
}
}
}
fn json_float_value(value: f64) -> serde_json::Value {
if value.is_nan() {
return serde_json::Value::String("NaN".to_string());
}
if value == f64::INFINITY {
return serde_json::Value::String("Infinity".to_string());
}
if value == f64::NEG_INFINITY {
return serde_json::Value::String("-Infinity".to_string());
}
serde_json::Number::from_f64(value)
.map(serde_json::Value::Number)
.unwrap_or(serde_json::Value::Null)
}
fn fixed_size_list_value_to_json(
array: &FixedSizeListArray,
row: usize,
integer_mode: JsonIntegerMode,
) -> serde_json::Value {
let value_len = array.value_length() as usize;
let values = array.values();
if let Some(float_values) = values.as_any().downcast_ref::<Float32Array>() {
let start = row.saturating_mul(value_len);
return float32_json_array(float_values, start, value_len);
}
let values = array.value(row);
serde_json::Value::Array(
(0..values.len())
.map(|idx| array_value_to_json_with_mode(&values, idx, integer_mode))
.collect(),
)
}
fn float32_json_array(values: &Float32Array, start: usize, len: usize) -> serde_json::Value {
let mut out = Vec::with_capacity(len);
let end = start.saturating_add(len).min(values.len());
for idx in start..end {
if values.is_null(idx) {
out.push(serde_json::Value::Null);
continue;
}
let value = values.value(idx) as f64;
out.push(
serde_json::Number::from_f64(value)
.map(serde_json::Value::Number)
.unwrap_or(serde_json::Value::Null),
);
}
serde_json::Value::Array(out)
}
#[cfg(test)]
mod tests {
use super::{array_value_to_json, record_batches_to_rust_json_rows};
use std::sync::Arc;
use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
use arrow_array::{ArrayRef, Float64Array, Int64Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
#[test]
fn int64_outside_js_safe_range_is_stringified() {
let values: ArrayRef = Arc::new(Int64Array::from(vec![Some(9_007_199_254_740_992)]));
assert_eq!(
array_value_to_json(&values, 0),
serde_json::Value::String("9007199254740992".to_string())
);
}
#[test]
fn uint64_outside_js_safe_range_is_stringified() {
let values: ArrayRef = Arc::new(UInt64Array::from(vec![Some(9_007_199_254_740_992)]));
assert_eq!(
array_value_to_json(&values, 0),
serde_json::Value::String("9007199254740992".to_string())
);
}
#[test]
fn uint64_within_js_safe_range_stays_numeric() {
let values: ArrayRef = Arc::new(UInt64Array::from(vec![Some(9_007_199_254_740_991)]));
assert_eq!(
array_value_to_json(&values, 0),
serde_json::json!(9_007_199_254_740_991u64)
);
}
#[test]
fn rust_json_rows_preserve_full_width_integers() {
let schema = Arc::new(Schema::new(vec![
Field::new("signed", DataType::Int64, false),
Field::new("unsigned", DataType::UInt64, false),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int64Array::from(vec![i64::MIN])),
Arc::new(UInt64Array::from(vec![u64::MAX])),
],
)
.expect("batch");
assert_eq!(
record_batches_to_rust_json_rows(&[batch]),
vec![serde_json::json!({
"signed": i64::MIN,
"unsigned": u64::MAX,
})]
);
}
#[test]
fn fixed_size_float32_vectors_serialize_without_recursive_dispatch() {
let mut builder = FixedSizeListBuilder::new(Float32Builder::new(), 3);
builder.values().append_value(0.25);
builder.values().append_value(0.5);
builder.values().append_value(0.75);
builder.append(true);
for _ in 0..3 {
builder.values().append_null();
}
builder.append(false);
builder.values().append_value(1.0);
builder.values().append_value(2.0);
builder.values().append_value(3.0);
builder.append(true);
let values: ArrayRef = Arc::new(builder.finish());
assert_eq!(
array_value_to_json(&values, 0),
serde_json::json!([0.25, 0.5, 0.75])
);
assert_eq!(array_value_to_json(&values, 1), serde_json::Value::Null);
assert_eq!(
array_value_to_json(&values, 2),
serde_json::json!([1.0, 2.0, 3.0])
);
}
#[test]
fn non_finite_floats_are_stringified() {
let values: ArrayRef = Arc::new(Float64Array::from(vec![
Some(f64::NAN),
Some(f64::INFINITY),
Some(f64::NEG_INFINITY),
]));
assert_eq!(array_value_to_json(&values, 0), serde_json::json!("NaN"));
assert_eq!(
array_value_to_json(&values, 1),
serde_json::json!("Infinity")
);
assert_eq!(
array_value_to_json(&values, 2),
serde_json::json!("-Infinity")
);
}
}

View file

@ -0,0 +1,28 @@
pub mod catalog;
pub mod embedding;
pub mod error;
pub mod ir;
pub mod json_output;
pub mod query;
pub mod query_input;
pub mod result;
pub mod schema;
pub mod types;
pub use catalog::build_catalog;
pub use catalog::schema_ir::{
SchemaIR, build_catalog_from_ir, build_schema_ir, schema_ir_hash, schema_ir_json,
schema_ir_pretty_json,
};
pub use catalog::schema_plan::{
SchemaMigrationPlan, SchemaMigrationStep, SchemaTypeKind, plan_schema_migration,
};
pub use ir::ParamMap;
pub use ir::lower::{lower_mutation_query, lower_query};
pub use query::ast::Literal;
pub use query_input::{
JsonParamMode, RunInputError, RunInputResult, ToParam, find_named_query,
json_params_to_param_map,
};
pub use result::{MutationExecResult, MutationResult, QueryResult, RunResult};
pub use types::{Direction, PropType, ScalarType};

View file

@ -0,0 +1,221 @@
pub const NOW_PARAM_NAME: &str = "__nanograph_now";
#[derive(Debug, Clone)]
pub struct QueryFile {
pub queries: Vec<QueryDecl>,
}
#[derive(Debug, Clone)]
pub struct QueryDecl {
pub name: String,
pub description: Option<String>,
pub instruction: Option<String>,
pub params: Vec<Param>,
pub match_clause: Vec<Clause>,
pub return_clause: Vec<Projection>,
pub order_clause: Vec<Ordering>,
pub limit: Option<u64>,
pub mutation: Option<Mutation>,
}
#[derive(Debug, Clone)]
pub struct Param {
pub name: String,
pub type_name: String,
pub nullable: bool,
}
#[derive(Debug, Clone)]
pub enum Clause {
Binding(Binding),
Traversal(Traversal),
Filter(Filter),
Negation(Vec<Clause>),
}
#[derive(Debug, Clone)]
pub struct Binding {
pub variable: String,
pub type_name: String,
pub prop_matches: Vec<PropMatch>,
}
#[derive(Debug, Clone)]
pub struct PropMatch {
pub prop_name: String,
pub value: MatchValue,
}
#[derive(Debug, Clone)]
pub enum MatchValue {
Literal(Literal),
Variable(String),
Now,
}
#[derive(Debug, Clone)]
pub struct Traversal {
pub src: String,
pub edge_name: String,
pub dst: String,
pub min_hops: u32,
pub max_hops: Option<u32>,
}
#[derive(Debug, Clone)]
pub struct Filter {
pub left: Expr,
pub op: CompOp,
pub right: Expr,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompOp {
Eq,
Ne,
Gt,
Lt,
Ge,
Le,
Contains,
}
impl std::fmt::Display for CompOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Eq => write!(f, "="),
Self::Ne => write!(f, "!="),
Self::Gt => write!(f, ">"),
Self::Lt => write!(f, "<"),
Self::Ge => write!(f, ">="),
Self::Le => write!(f, "<="),
Self::Contains => write!(f, "contains"),
}
}
}
#[derive(Debug, Clone)]
pub enum Expr {
Now,
PropAccess {
variable: String,
property: String,
},
Nearest {
variable: String,
property: String,
query: Box<Expr>,
},
Search {
field: Box<Expr>,
query: Box<Expr>,
},
Fuzzy {
field: Box<Expr>,
query: Box<Expr>,
max_edits: Option<Box<Expr>>,
},
MatchText {
field: Box<Expr>,
query: Box<Expr>,
},
Bm25 {
field: Box<Expr>,
query: Box<Expr>,
},
Rrf {
primary: Box<Expr>,
secondary: Box<Expr>,
k: Option<Box<Expr>>,
},
Variable(String),
Literal(Literal),
Aggregate {
func: AggFunc,
arg: Box<Expr>,
},
AliasRef(String),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFunc {
Count,
Sum,
Avg,
Min,
Max,
}
impl std::fmt::Display for AggFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Count => write!(f, "count"),
Self::Sum => write!(f, "sum"),
Self::Avg => write!(f, "avg"),
Self::Min => write!(f, "min"),
Self::Max => write!(f, "max"),
}
}
}
#[derive(Debug, Clone)]
pub enum Literal {
String(String),
Integer(i64),
Float(f64),
Bool(bool),
Date(String),
DateTime(String),
List(Vec<Literal>),
}
#[derive(Debug, Clone)]
pub struct Projection {
pub expr: Expr,
pub alias: Option<String>,
}
#[derive(Debug, Clone)]
pub struct Ordering {
pub expr: Expr,
pub descending: bool,
}
#[derive(Debug, Clone)]
pub enum Mutation {
Insert(InsertMutation),
Update(UpdateMutation),
Delete(DeleteMutation),
}
#[derive(Debug, Clone)]
pub struct InsertMutation {
pub type_name: String,
pub assignments: Vec<MutationAssignment>,
}
#[derive(Debug, Clone)]
pub struct UpdateMutation {
pub type_name: String,
pub assignments: Vec<MutationAssignment>,
pub predicate: MutationPredicate,
}
#[derive(Debug, Clone)]
pub struct DeleteMutation {
pub type_name: String,
pub predicate: MutationPredicate,
}
#[derive(Debug, Clone)]
pub struct MutationAssignment {
pub property: String,
pub value: MatchValue,
}
#[derive(Debug, Clone)]
pub struct MutationPredicate {
pub property: String,
pub op: CompOp,
pub value: MatchValue,
}

View file

@ -0,0 +1,3 @@
pub mod ast;
pub mod parser;
pub mod typecheck;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
// NanoGraph Query Grammar (.gq files)
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }
COMMENT = _{ LINE_COMMENT | BLOCK_COMMENT }
LINE_COMMENT = _{ "//" ~ (!"\n" ~ ANY)* }
BLOCK_COMMENT = _{ "/*" ~ (!"*/" ~ ANY)* ~ "*/" }
query_file = { SOI ~ query_decl* ~ EOI }
query_decl = {
"query" ~ ident ~ "(" ~ param_list? ~ ")" ~ query_annotation* ~ "{"
~ query_body
~ "}"
}
query_annotation = { description_annotation | instruction_annotation }
description_annotation = { "@description" ~ "(" ~ string_lit ~ ")" }
instruction_annotation = { "@instruction" ~ "(" ~ string_lit ~ ")" }
query_body = { read_query_body | mutation_stmt }
read_query_body = {
match_clause
~ return_clause
~ order_clause?
~ limit_clause?
}
mutation_stmt = { insert_stmt | update_stmt | delete_stmt }
insert_stmt = { "insert" ~ type_name ~ "{" ~ mutation_assignment+ ~ "}" }
update_stmt = { "update" ~ type_name ~ "set" ~ "{" ~ mutation_assignment+ ~ "}" ~ "where" ~ mutation_predicate }
delete_stmt = { "delete" ~ type_name ~ "where" ~ mutation_predicate }
mutation_assignment = { ident ~ ":" ~ match_value ~ ","? }
mutation_predicate = { ident ~ comp_op ~ match_value }
param_list = { param ~ ("," ~ param)* }
param = { variable ~ ":" ~ type_ref }
type_ref = { (list_type | base_type | vector_type) ~ "?"? }
list_type = { "[" ~ base_type ~ "]" }
vector_type = { "Vector" ~ "(" ~ integer ~ ")" }
base_type = { "String" | "Blob" | "Bool" | "I32" | "I64" | "U32" | "U64" | "F32" | "F64" | "DateTime" | "Date" }
match_clause = { "match" ~ "{" ~ clause+ ~ "}" }
clause = { negation | binding | traversal | filter | text_search_clause }
text_search_clause = { search_call | fuzzy_call | match_text_call }
// Binding: $p: Person { name: "Alice" }
binding = { variable ~ ":" ~ type_name ~ ("{" ~ prop_match_list ~ "}")? }
prop_match_list = { prop_match ~ ("," ~ prop_match)* ~ ","? }
prop_match = { ident ~ ":" ~ match_value }
match_value = { literal | variable | now_call }
// Traversal: $p knows $f
traversal = { variable ~ edge_ident ~ traversal_bounds? ~ variable }
traversal_bounds = { "{" ~ integer ~ "," ~ integer? ~ "}" }
// Filter: $f.age > 25
filter = { expr ~ filter_op ~ expr }
// Negation: not { ... }
negation = { "not" ~ "{" ~ clause+ ~ "}" }
// Return clause — projections separated by commas or newlines
return_clause = { "return" ~ "{" ~ projection+ ~ "}" }
projection = { expr ~ ("as" ~ ident)? ~ ","? }
// Order clause
order_clause = { "order" ~ "{" ~ ordering ~ ("," ~ ordering)* ~ "}" }
ordering = { nearest_ordering | (expr ~ order_dir?) }
nearest_ordering = { "nearest" ~ "(" ~ prop_access ~ "," ~ expr ~ ")" }
order_dir = { "asc" | "desc" }
// Limit clause
limit_clause = { "limit" ~ integer }
// Expressions
expr = { now_call | nearest_ordering | search_call | fuzzy_call | match_text_call | bm25_call | rrf_call | agg_call | prop_access | variable | literal | ident }
now_call = { "now" ~ "(" ~ ")" }
search_call = { "search" ~ "(" ~ expr ~ "," ~ expr ~ ")" }
fuzzy_call = { "fuzzy" ~ "(" ~ expr ~ "," ~ expr ~ ("," ~ expr)? ~ ")" }
match_text_call = { "match_text" ~ "(" ~ expr ~ "," ~ expr ~ ")" }
bm25_call = { "bm25" ~ "(" ~ expr ~ "," ~ expr ~ ")" }
rank_expr = { nearest_ordering | bm25_call }
rrf_call = { "rrf" ~ "(" ~ rank_expr ~ "," ~ rank_expr ~ ("," ~ expr)? ~ ")" }
prop_access = { variable ~ "." ~ ident }
agg_call = { agg_func ~ "(" ~ expr ~ ")" }
agg_func = { "count" | "sum" | "avg" | "min" | "max" }
comp_op = { ">=" | "<=" | "!=" | ">" | "<" | "=" }
filter_op = { "contains" | comp_op }
// Terminals
variable = @{ "$" ~ (ident_chars | "_") }
ident_chars = @{ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
// Edge identifier — lowercase start, same as ident but used in traversal context
// Must not match keywords
edge_ident = @{ !("not" ~ !ASCII_ALPHANUMERIC) ~ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
type_name = @{ ASCII_ALPHA_UPPER ~ (ASCII_ALPHANUMERIC | "_")* }
ident = @{ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
literal = { list_lit | datetime_lit | date_lit | string_lit | float_lit | integer | bool_lit }
date_lit = { "date" ~ "(" ~ string_lit ~ ")" }
datetime_lit = { "datetime" ~ "(" ~ string_lit ~ ")" }
list_lit = { "[" ~ (literal ~ ("," ~ literal)*)? ~ "]" }
string_lit = @{ "\"" ~ string_char* ~ "\"" }
string_char = @{ !("\"" | "\\") ~ ANY | "\\" ~ ANY }
float_lit = @{ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ }
integer = @{ ASCII_DIGIT+ }
bool_lit = { "true" | "false" }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,892 @@
use std::error::Error;
use std::fmt;
use serde_json::Value;
use crate::error::NanoError;
use crate::ir::ParamMap;
use crate::json_output::{JS_MAX_SAFE_INTEGER_U64, is_js_safe_integer_i64};
use crate::query::ast::{Literal, Param, QueryDecl};
use crate::query::parser::parse_query;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonParamMode {
Standard,
JavaScript,
}
#[derive(Debug)]
pub enum RunInputError {
Core(NanoError),
Message(String),
}
impl RunInputError {
fn message(message: impl Into<String>) -> Self {
Self::Message(message.into())
}
}
impl fmt::Display for RunInputError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Core(err) => err.fmt(f),
Self::Message(message) => f.write_str(message),
}
}
}
impl Error for RunInputError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::Core(err) => Some(err),
Self::Message(_) => None,
}
}
}
impl From<NanoError> for RunInputError {
fn from(value: NanoError) -> Self {
Self::Core(value)
}
}
pub type RunInputResult<T> = std::result::Result<T, RunInputError>;
pub trait ToParam {
fn to_param(self) -> crate::error::Result<Literal>;
}
impl ToParam for Literal {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(self)
}
}
impl ToParam for &Literal {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(self.clone())
}
}
impl ToParam for String {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::String(self))
}
}
impl ToParam for &String {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::String(self.clone()))
}
}
impl ToParam for &str {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::String(self.to_string()))
}
}
impl ToParam for bool {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Bool(self))
}
}
impl ToParam for i8 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(i64::from(self)))
}
}
impl ToParam for i16 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(i64::from(self)))
}
}
impl ToParam for i32 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(i64::from(self)))
}
}
impl ToParam for i64 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(self))
}
}
impl ToParam for isize {
fn to_param(self) -> crate::error::Result<Literal> {
let value = i64::try_from(self).map_err(|_| {
NanoError::Execution(format!(
"param value {} exceeds current engine range for numeric literals (max {})",
self,
i64::MAX
))
})?;
Ok(Literal::Integer(value))
}
}
impl ToParam for u8 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(i64::from(self)))
}
}
impl ToParam for u16 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(i64::from(self)))
}
}
impl ToParam for u32 {
fn to_param(self) -> crate::error::Result<Literal> {
Ok(Literal::Integer(i64::from(self)))
}
}
impl ToParam for u64 {
fn to_param(self) -> crate::error::Result<Literal> {
let value = i64::try_from(self).map_err(|_| {
NanoError::Execution(format!(
"param value {} exceeds current engine range for numeric literals (max {})",
self,
i64::MAX
))
})?;
Ok(Literal::Integer(value))
}
}
impl ToParam for usize {
fn to_param(self) -> crate::error::Result<Literal> {
let value = i64::try_from(self).map_err(|_| {
NanoError::Execution(format!(
"param value {} exceeds current engine range for numeric literals (max {})",
self,
i64::MAX
))
})?;
Ok(Literal::Integer(value))
}
}
impl ToParam for f32 {
fn to_param(self) -> crate::error::Result<Literal> {
if !self.is_finite() {
return Err(NanoError::Execution(format!(
"invalid float parameter {}",
self
)));
}
Ok(Literal::Float(f64::from(self)))
}
}
impl ToParam for f64 {
fn to_param(self) -> crate::error::Result<Literal> {
if !self.is_finite() {
return Err(NanoError::Execution(format!(
"invalid float parameter {}",
self
)));
}
Ok(Literal::Float(self))
}
}
impl<T> ToParam for Vec<T>
where
T: ToParam,
{
fn to_param(self) -> crate::error::Result<Literal> {
let mut out = Vec::with_capacity(self.len());
for value in self {
out.push(value.to_param()?);
}
Ok(Literal::List(out))
}
}
impl<T> ToParam for &[T]
where
T: Clone + ToParam,
{
fn to_param(self) -> crate::error::Result<Literal> {
let mut out = Vec::with_capacity(self.len());
for value in self {
out.push(value.clone().to_param()?);
}
Ok(Literal::List(out))
}
}
impl<T, const N: usize> ToParam for [T; N]
where
T: ToParam,
{
fn to_param(self) -> crate::error::Result<Literal> {
let mut out = Vec::with_capacity(N);
for value in self {
out.push(value.to_param()?);
}
Ok(Literal::List(out))
}
}
#[macro_export]
macro_rules! params {
() => {
::std::result::Result::Ok($crate::ParamMap::new())
};
($($key:expr => $value:expr),+ $(,)?) => {{
(|| -> $crate::error::Result<$crate::ParamMap> {
let mut map = $crate::ParamMap::new();
$(
map.insert(::std::convert::Into::<String>::into($key), $crate::ToParam::to_param($value)?);
)+
Ok(map)
})()
}};
}
pub fn find_named_query(query_source: &str, query_name: &str) -> RunInputResult<QueryDecl> {
let queries = parse_query(query_source)?;
queries
.queries
.into_iter()
.find(|query| query.name == query_name)
.ok_or_else(|| RunInputError::message(format!("query '{}' not found", query_name)))
}
pub fn json_params_to_param_map(
params: Option<&Value>,
query_params: &[Param],
mode: JsonParamMode,
) -> RunInputResult<ParamMap> {
let mut map = ParamMap::new();
let object = match params {
Some(Value::Object(object)) => object,
Some(Value::Null) | None => return Ok(map),
Some(other) => {
let message = match mode {
JsonParamMode::Standard => "params must be a JSON object".to_string(),
JsonParamMode::JavaScript => {
format!("params must be an object, got {}", json_type_name(other))
}
};
return Err(RunInputError::message(message));
}
};
for (key, value) in object {
let decl = query_params.iter().find(|param| param.name == *key);
let literal = if let Some(decl) = decl {
json_value_to_literal_typed(key, value, &decl.type_name, mode)?
} else {
json_value_to_literal_inferred(key, value, mode)?
};
map.insert(key.clone(), literal);
}
Ok(map)
}
fn json_value_to_literal_typed(
key: &str,
value: &Value,
type_name: &str,
mode: JsonParamMode,
) -> RunInputResult<Literal> {
match type_name {
"String" => match value {
Value::String(value) => Ok(Literal::String(value.clone())),
other => Err(RunInputError::message(format!(
"param '{}': expected string, got {}",
key,
json_type_name(other)
))),
},
"I32" => match mode {
JsonParamMode::Standard => {
let value = parse_i64_param(key, value, mode)?;
let value = i32::try_from(value).map_err(|_| {
RunInputError::message(format!("param '{}': value {} exceeds I32", key, value))
})?;
Ok(Literal::Integer(i64::from(value)))
}
JsonParamMode::JavaScript => {
let value = parse_i64_param(key, value, mode)?;
let value = i32::try_from(value).map_err(|_| {
RunInputError::message(format!(
"param '{}': value {} exceeds I32 range",
key, value
))
})?;
Ok(Literal::Integer(i64::from(value)))
}
},
"I64" => Ok(Literal::Integer(parse_i64_param(key, value, mode)?)),
"U32" => {
let value = parse_u64_param(key, value, mode)?;
let value = match mode {
JsonParamMode::Standard => u32::try_from(value).map_err(|_| {
RunInputError::message(format!("param '{}': value {} exceeds U32", key, value))
})?,
JsonParamMode::JavaScript => u32::try_from(value).map_err(|_| {
RunInputError::message(format!(
"param '{}': value {} exceeds U32 range",
key, value
))
})?,
};
Ok(Literal::Integer(i64::from(value)))
}
"U64" => {
let value = parse_u64_param(key, value, mode)?;
let value = match mode {
JsonParamMode::Standard => i64::try_from(value).map_err(|_| {
RunInputError::message(format!(
"param '{}': value {} exceeds current engine range for U64 (max {})",
key,
value,
i64::MAX
))
})?,
JsonParamMode::JavaScript => i64::try_from(value).map_err(|_| {
RunInputError::message(format!(
"param '{}': value {} exceeds current engine range for U64 parameters (max {})",
key,
value,
i64::MAX
))
})?,
};
Ok(Literal::Integer(value))
}
"F32" | "F64" => {
let value = value.as_f64().ok_or_else(|| match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': expected float", key))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': expected float, got {}",
key,
json_type_name(value)
)),
})?;
Ok(Literal::Float(value))
}
"Bool" => {
let value = value.as_bool().ok_or_else(|| match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': expected boolean", key))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': expected boolean, got {}",
key,
json_type_name(value)
)),
})?;
Ok(Literal::Bool(value))
}
"Date" => match value {
Value::String(value) => Ok(Literal::Date(value.clone())),
other => Err(match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': expected date string", key))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': expected date string, got {}",
key,
json_type_name(other)
)),
}),
},
"DateTime" => match value {
Value::String(value) => Ok(Literal::DateTime(value.clone())),
other => Err(match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': expected datetime string", key))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': expected datetime string, got {}",
key,
json_type_name(other)
)),
}),
},
"Blob" => match value {
Value::String(value) => Ok(Literal::String(value.clone())),
other => Err(RunInputError::message(format!(
"param '{}': expected blob URI string, got {}",
key,
json_type_name(other)
))),
},
other if parse_list_item_type(other).is_some() => {
let item_type = parse_list_item_type(other).unwrap();
let items = value.as_array().ok_or_else(|| match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': expected array for {}", key, other))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': expected array for {}, got {}",
key,
other,
json_type_name(value)
)),
})?;
let mut out = Vec::with_capacity(items.len());
for item in items {
out.push(json_value_to_literal_typed(key, item, item_type, mode)?);
}
Ok(Literal::List(out))
}
other if other.starts_with("Vector(") => {
let expected_dim = parse_vector_dim(other).ok_or_else(|| match mode {
JsonParamMode::Standard => RunInputError::message(format!(
"param '{}': invalid vector type '{}'",
key, other
)),
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': invalid vector type '{}' (expected Vector(N))",
key, other
)),
})?;
let items = value.as_array().ok_or_else(|| match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': expected array for {}", key, other))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': expected array for {}, got {}",
key,
other,
json_type_name(value)
)),
})?;
if items.len() != expected_dim {
return Err(RunInputError::message(format!(
"param '{}': expected {} values for {}, got {}",
key,
expected_dim,
other,
items.len()
)));
}
let mut out = Vec::with_capacity(items.len());
for item in items {
let value = item.as_f64().ok_or_else(|| match mode {
JsonParamMode::Standard => RunInputError::message(format!(
"param '{}': vector element is not numeric",
key
)),
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': vector element '{}' is not numeric",
key, item
)),
})?;
out.push(Literal::Float(value));
}
Ok(Literal::List(out))
}
_ => match value {
Value::String(value) => Ok(Literal::String(value.clone())),
other => Err(RunInputError::message(format!(
"param '{}': expected string for type '{}', got {}",
key,
type_name,
json_type_name(other)
))),
},
}
}
fn json_value_to_literal_inferred(
key: &str,
value: &Value,
mode: JsonParamMode,
) -> RunInputResult<Literal> {
match value {
Value::String(value) => Ok(Literal::String(value.clone())),
Value::Bool(value) => Ok(Literal::Bool(*value)),
Value::Number(number) => match mode {
JsonParamMode::Standard => {
if let Some(value) = number.as_i64() {
Ok(Literal::Integer(value))
} else if let Some(value) = number.as_f64() {
Ok(Literal::Float(value))
} else {
Err(RunInputError::message(format!(
"param '{}': unsupported numeric value",
key
)))
}
}
JsonParamMode::JavaScript => {
if let Some(value) = number.as_i64() {
if !is_js_safe_integer_i64(value) {
return Err(RunInputError::message(format!(
"param '{}': integer {} exceeds JS safe integer range; use a decimal string and a typed query parameter for exact values",
key, value
)));
}
Ok(Literal::Integer(value))
} else if let Some(value) = number.as_u64() {
if value > JS_MAX_SAFE_INTEGER_U64 {
return Err(RunInputError::message(format!(
"param '{}': integer {} exceeds JS safe integer range; use a decimal string and a typed query parameter for exact values",
key, value
)));
}
let value = i64::try_from(value).map_err(|_| {
RunInputError::message(format!(
"param '{}': integer {} exceeds supported range (max {})",
key,
value,
i64::MAX
))
})?;
Ok(Literal::Integer(value))
} else if let Some(value) = number.as_f64() {
Ok(Literal::Float(value))
} else {
Err(RunInputError::message(format!(
"param '{}': unsupported number value",
key
)))
}
}
},
Value::Array(values) => {
let mut out = Vec::with_capacity(values.len());
for value in values {
out.push(json_value_to_literal_inferred(key, value, mode)?);
}
Ok(Literal::List(out))
}
Value::Null => Err(match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': null is not supported", key))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': null values are not supported as query parameters",
key
)),
}),
Value::Object(_) => Err(match mode {
JsonParamMode::Standard => {
RunInputError::message(format!("param '{}': object is not supported", key))
}
JsonParamMode::JavaScript => RunInputError::message(format!(
"param '{}': object values are not supported as query parameters",
key
)),
}),
}
}
fn parse_i64_param(key: &str, value: &Value, mode: JsonParamMode) -> RunInputResult<i64> {
match mode {
JsonParamMode::Standard => match value {
Value::Number(number) => number.as_i64().ok_or_else(|| {
RunInputError::message(format!("param '{}': expected integer number", key))
}),
Value::String(value) => value.parse::<i64>().map_err(|_| {
RunInputError::message(format!(
"param '{}': expected integer string, got '{}'",
key, value
))
}),
_ => Err(RunInputError::message(format!(
"param '{}': expected integer",
key
))),
},
JsonParamMode::JavaScript => match value {
Value::Number(number) => {
let parsed = if let Some(parsed) = number.as_i64() {
parsed
} else if let Some(parsed) = number.as_f64() {
if !parsed.is_finite() || parsed.fract() != 0.0 {
return Err(RunInputError::message(format!(
"param '{}': expected integer, got number",
key
)));
}
if parsed < i64::MIN as f64 || parsed > i64::MAX as f64 {
return Err(RunInputError::message(format!(
"param '{}': integer {} is outside i64 range",
key, parsed
)));
}
parsed as i64
} else {
return Err(RunInputError::message(format!(
"param '{}': expected integer, got number",
key
)));
};
if !is_js_safe_integer_i64(parsed) {
return Err(RunInputError::message(format!(
"param '{}': integer {} exceeds JS safe integer range; pass a decimal string for exact values",
key, parsed
)));
}
Ok(parsed)
}
Value::String(value) => value.parse::<i64>().map_err(|_| {
RunInputError::message(format!(
"param '{}': expected integer string, got '{}'",
key, value
))
}),
other => Err(RunInputError::message(format!(
"param '{}': expected integer, got {}",
key,
json_type_name(other)
))),
},
}
}
fn parse_u64_param(key: &str, value: &Value, mode: JsonParamMode) -> RunInputResult<u64> {
match mode {
JsonParamMode::Standard => match value {
Value::Number(number) => number.as_u64().ok_or_else(|| {
RunInputError::message(format!("param '{}': expected unsigned integer number", key))
}),
Value::String(value) => value.parse::<u64>().map_err(|_| {
RunInputError::message(format!(
"param '{}': expected unsigned integer string, got '{}'",
key, value
))
}),
_ => Err(RunInputError::message(format!(
"param '{}': expected unsigned integer",
key
))),
},
JsonParamMode::JavaScript => match value {
Value::Number(number) => {
let parsed = if let Some(parsed) = number.as_u64() {
parsed
} else if let Some(parsed) = number.as_f64() {
if !parsed.is_finite() || parsed.fract() != 0.0 || parsed < 0.0 {
return Err(RunInputError::message(format!(
"param '{}': expected unsigned integer, got number",
key
)));
}
if parsed > u64::MAX as f64 {
return Err(RunInputError::message(format!(
"param '{}': integer {} is outside u64 range",
key, parsed
)));
}
parsed as u64
} else {
return Err(RunInputError::message(format!(
"param '{}': expected unsigned integer, got number",
key
)));
};
if parsed > JS_MAX_SAFE_INTEGER_U64 {
return Err(RunInputError::message(format!(
"param '{}': integer {} exceeds JS safe integer range; pass a decimal string for exact values",
key, parsed
)));
}
Ok(parsed)
}
Value::String(value) => value.parse::<u64>().map_err(|_| {
RunInputError::message(format!(
"param '{}': expected unsigned integer string, got '{}'",
key, value
))
}),
other => Err(RunInputError::message(format!(
"param '{}': expected unsigned integer, got {}",
key,
json_type_name(other)
))),
},
}
}
fn parse_vector_dim(type_name: &str) -> Option<usize> {
let dim = type_name
.strip_prefix("Vector(")?
.strip_suffix(')')?
.parse::<usize>()
.ok()?;
if dim == 0 { None } else { Some(dim) }
}
fn parse_list_item_type(type_name: &str) -> Option<&str> {
Some(type_name.strip_prefix('[')?.strip_suffix(']')?.trim())
}
fn json_type_name(value: &Value) -> &'static str {
match value {
Value::Null => "null",
Value::Bool(_) => "boolean",
Value::Number(_) => "number",
Value::String(_) => "string",
Value::Array(_) => "array",
Value::Object(_) => "object",
}
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::{JsonParamMode, ToParam, find_named_query, json_params_to_param_map};
use crate::query::ast::Literal;
#[test]
fn js_mode_rejects_unsafe_integer_numbers() {
let query = find_named_query(
"query find($id: U64) { match { $u: User } return { $u } }",
"find",
)
.expect("query should parse");
let error = json_params_to_param_map(
Some(&json!({ "id": 9_007_199_254_740_992u64 })),
&query.params,
JsonParamMode::JavaScript,
)
.expect_err("unsafe integer should fail");
assert_eq!(
error.to_string(),
"param 'id': integer 9007199254740992 exceeds JS safe integer range; pass a decimal string for exact values"
);
}
#[test]
fn standard_mode_preserves_ffi_param_object_error() {
let error = json_params_to_param_map(Some(&json!(["nope"])), &[], JsonParamMode::Standard)
.expect_err("non-object params should fail");
assert_eq!(error.to_string(), "params must be a JSON object");
}
#[test]
fn to_param_supports_lists_and_explicit_date_literals() {
let vector = vec![1_i32, 2_i32, 3_i32].to_param().expect("vector param");
match vector {
Literal::List(values) => {
assert!(matches!(values.first(), Some(Literal::Integer(1))));
assert!(matches!(values.get(1), Some(Literal::Integer(2))));
assert!(matches!(values.get(2), Some(Literal::Integer(3))));
}
other => panic!("expected list param, got {:?}", other),
}
let date = Literal::Date("2026-03-06".to_string())
.to_param()
.expect("date param");
assert!(matches!(date, Literal::Date(ref value) if value == "2026-03-06"));
}
#[test]
fn to_param_rejects_unsigned_values_outside_engine_range() {
let error = u64::MAX.to_param().expect_err("oversized u64 should fail");
assert_eq!(
error.to_string(),
format!(
"execution error: param value {} exceeds current engine range for numeric literals (max {})",
u64::MAX,
i64::MAX
)
);
}
#[test]
fn params_macro_builds_param_map() {
let params = params! {
"name" => "Alice",
"age" => 41_i32,
"scores" => [1_u8, 2_u8, 3_u8],
"published_at" => Literal::DateTime("2026-03-06T12:00:00Z".to_string()),
}
.expect("params");
assert!(matches!(
params.get("name"),
Some(Literal::String(value)) if value == "Alice"
));
assert!(matches!(params.get("age"), Some(Literal::Integer(41))));
match params.get("scores") {
Some(Literal::List(values)) => {
assert!(matches!(values.first(), Some(Literal::Integer(1))));
assert!(matches!(values.get(1), Some(Literal::Integer(2))));
assert!(matches!(values.get(2), Some(Literal::Integer(3))));
}
other => panic!("expected list param, got {:?}", other),
}
assert!(matches!(
params.get("published_at"),
Some(Literal::DateTime(value)) if value == "2026-03-06T12:00:00Z"
));
}
#[test]
fn typed_json_params_support_list_and_datetime_types() {
let query = find_named_query(
r#"
query q($tags: [String], $days: [Date]?, $due_at: DateTime) {
match { $t: Task }
return { $t.slug }
}
"#,
"q",
)
.expect("query");
let params = json_params_to_param_map(
Some(&json!({
"tags": ["launch", "priority"],
"days": ["2026-04-01", "2026-04-02"],
"due_at": "2026-04-03T10:15:00Z"
})),
&query.params,
JsonParamMode::Standard,
)
.expect("typed params");
assert!(matches!(
params.get("due_at"),
Some(Literal::DateTime(value)) if value == "2026-04-03T10:15:00Z"
));
match params.get("tags") {
Some(Literal::List(values)) => {
assert!(
matches!(values.first(), Some(Literal::String(value)) if value == "launch")
);
assert!(
matches!(values.get(1), Some(Literal::String(value)) if value == "priority")
);
}
other => panic!("expected string list param, got {:?}", other),
}
match params.get("days") {
Some(Literal::List(values)) => {
assert!(
matches!(values.first(), Some(Literal::Date(value)) if value == "2026-04-01")
);
assert!(
matches!(values.get(1), Some(Literal::Date(value)) if value == "2026-04-02")
);
}
other => panic!("expected date list param, got {:?}", other),
}
}
}

View file

@ -0,0 +1,286 @@
use std::sync::Arc;
use arrow_array::{RecordBatch, UInt64Array};
use arrow_ipc::writer::StreamWriter;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use serde::de::DeserializeOwned;
use crate::error::{NanoError, Result};
use crate::json_output::{record_batches_to_json_rows, record_batches_to_rust_json_rows};
#[derive(Debug, Clone, Copy, Default)]
pub struct MutationExecResult {
pub affected_nodes: usize,
pub affected_edges: usize,
}
#[derive(Debug, Clone)]
pub struct QueryResult {
schema: SchemaRef,
batches: Vec<RecordBatch>,
}
impl QueryResult {
pub fn new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
Self { schema, batches }
}
pub fn schema(&self) -> &SchemaRef {
&self.schema
}
pub fn batches(&self) -> &[RecordBatch] {
&self.batches
}
pub fn into_batches(self) -> Vec<RecordBatch> {
self.batches
}
pub fn num_rows(&self) -> usize {
self.batches.iter().map(RecordBatch::num_rows).sum()
}
pub fn concat_batches(&self) -> Result<RecordBatch> {
if self.batches.is_empty() {
return Ok(RecordBatch::new_empty(self.schema.clone()));
}
arrow_select::concat::concat_batches(&self.schema, &self.batches)
.map_err(|err| NanoError::Execution(err.to_string()))
}
pub fn to_sdk_json(&self) -> serde_json::Value {
serde_json::Value::Array(record_batches_to_json_rows(&self.batches))
}
pub fn to_rust_json(&self) -> serde_json::Value {
serde_json::Value::Array(record_batches_to_rust_json_rows(&self.batches))
}
pub fn deserialize<T: DeserializeOwned>(&self) -> Result<T> {
serde_json::from_value(self.to_rust_json()).map_err(|err| {
NanoError::Execution(format!("failed to deserialize query result: {}", err))
})
}
pub fn to_arrow_ipc(&self) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
let mut writer = StreamWriter::try_new(&mut buffer, &self.schema)?;
for batch in &self.batches {
writer.write(batch)?;
}
writer.finish()?;
drop(writer);
Ok(buffer)
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct MutationResult {
pub affected_nodes: usize,
pub affected_edges: usize,
}
impl MutationResult {
pub fn to_sdk_json(&self) -> serde_json::Value {
serde_json::json!({
"affectedNodes": self.affected_nodes,
"affectedEdges": self.affected_edges,
})
}
pub fn to_record_batch(&self) -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![
Field::new("affected_nodes", DataType::UInt64, false),
Field::new("affected_edges", DataType::UInt64, false),
]));
Ok(RecordBatch::try_new(
schema,
vec![
Arc::new(UInt64Array::from(vec![self.affected_nodes as u64])),
Arc::new(UInt64Array::from(vec![self.affected_edges as u64])),
],
)?)
}
}
impl From<MutationExecResult> for MutationResult {
fn from(value: MutationExecResult) -> Self {
Self {
affected_nodes: value.affected_nodes,
affected_edges: value.affected_edges,
}
}
}
#[derive(Debug, Clone)]
pub enum RunResult {
Query(QueryResult),
Mutation(MutationResult),
}
impl RunResult {
pub fn to_sdk_json(&self) -> serde_json::Value {
match self {
Self::Query(result) => result.to_sdk_json(),
Self::Mutation(result) => result.to_sdk_json(),
}
}
pub fn into_record_batches(self) -> Result<Vec<RecordBatch>> {
match self {
Self::Query(result) => Ok(result.into_batches()),
Self::Mutation(result) => Ok(vec![result.to_record_batch()?]),
}
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use arrow_array::Int64Array;
use arrow_ipc::reader::StreamReader;
use serde::Deserialize;
use super::*;
#[test]
fn query_result_arrow_ipc_round_trips_empty_schema() {
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
let result = QueryResult::new(schema.clone(), vec![]);
let encoded = result.to_arrow_ipc().expect("encode empty result");
let reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");
assert_eq!(reader.schema().as_ref(), schema.as_ref());
assert_eq!(reader.count(), 0);
}
#[test]
fn query_result_arrow_ipc_round_trips_batches() {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
)
.expect("batch");
let result = QueryResult::new(schema.clone(), vec![batch]);
let encoded = result.to_arrow_ipc().expect("encode result");
let mut reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");
let decoded = reader.next().expect("first batch").expect("decode batch");
assert_eq!(reader.schema().as_ref(), schema.as_ref());
assert_eq!(decoded.num_rows(), 2);
assert_eq!(decoded.schema().as_ref(), schema.as_ref());
}
#[test]
fn query_result_num_rows_and_concat_cover_multiple_batches() {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
)
.expect("batch1");
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(UInt64Array::from(vec![3_u64]))],
)
.expect("batch2");
let result = QueryResult::new(schema.clone(), vec![batch1, batch2]);
assert_eq!(result.num_rows(), 3);
let concatenated = result.concat_batches().expect("concat batches");
let ids = concatenated
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.expect("u64 ids");
assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
assert_eq!(ids.values(), &[1, 2, 3]);
}
#[test]
fn query_result_concat_empty_batches_returns_empty_batch() {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
let result = QueryResult::new(schema.clone(), vec![]);
let concatenated = result.concat_batches().expect("concat empty");
assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
assert_eq!(concatenated.num_rows(), 0);
}
#[test]
fn query_result_to_rust_json_preserves_wide_integers() {
let schema = Arc::new(Schema::new(vec![
Field::new("signed", DataType::Int64, false),
Field::new("unsigned", DataType::UInt64, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(vec![i64::MIN])),
Arc::new(UInt64Array::from(vec![u64::MAX])),
],
)
.expect("batch");
let result = QueryResult::new(schema, vec![batch]);
assert_eq!(
result.to_rust_json(),
serde_json::json!([{
"signed": i64::MIN,
"unsigned": u64::MAX,
}])
);
}
#[derive(Debug, Deserialize, PartialEq)]
struct PersonRow {
id: u64,
age: i64,
}
#[test]
fn query_result_deserialize_decodes_rust_rows() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt64, false),
Field::new("age", DataType::Int64, false),
]));
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt64Array::from(vec![1_u64])),
Arc::new(Int64Array::from(vec![40_i64])),
],
)
.expect("batch1");
let batch2 = RecordBatch::try_new(
schema,
vec![
Arc::new(UInt64Array::from(vec![u64::MAX])),
Arc::new(Int64Array::from(vec![-5_i64])),
],
)
.expect("batch2");
let result = QueryResult::new(batch1.schema(), vec![batch1, batch2]);
let rows: Vec<PersonRow> = result.deserialize().expect("deserialize rows");
assert_eq!(
rows,
vec![
PersonRow { id: 1, age: 40 },
PersonRow {
id: u64::MAX,
age: -5,
},
]
);
}
}

View file

@ -0,0 +1,111 @@
use crate::types::PropType;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchemaFile {
pub declarations: Vec<SchemaDecl>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SchemaDecl {
Interface(InterfaceDecl),
Node(NodeDecl),
Edge(EdgeDecl),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InterfaceDecl {
pub name: String,
pub properties: Vec<PropDecl>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NodeDecl {
pub name: String,
pub annotations: Vec<Annotation>,
pub implements: Vec<String>,
pub properties: Vec<PropDecl>,
pub constraints: Vec<Constraint>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EdgeDecl {
pub name: String,
pub from_type: String,
pub to_type: String,
pub cardinality: Cardinality,
pub annotations: Vec<Annotation>,
pub properties: Vec<PropDecl>,
pub constraints: Vec<Constraint>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PropDecl {
pub name: String,
pub prop_type: PropType,
pub annotations: Vec<Annotation>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Annotation {
pub name: String,
pub value: Option<String>,
}
/// A typed constraint declared in a node or edge body.
///
/// Property-level annotations (`@key`, `@unique`, `@index`) are desugared
/// into these during parsing, so both syntactic positions produce the same
/// representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Constraint {
Key(Vec<String>),
Unique(Vec<String>),
Index(Vec<String>),
Range {
property: String,
min: Option<ConstraintBound>,
max: Option<ConstraintBound>,
},
Check {
property: String,
pattern: String,
},
}
/// A numeric bound used in `@range` constraints.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConstraintBound {
Integer(i64),
Float(f64),
}
/// Edge cardinality: `@card(min..max)`. Default is `0..*`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Cardinality {
pub min: u32,
pub max: Option<u32>,
}
impl Default for Cardinality {
fn default() -> Self {
Self { min: 0, max: None }
}
}
impl Cardinality {
pub fn is_default(&self) -> bool {
self.min == 0 && self.max.is_none()
}
}
pub fn has_annotation(annotations: &[Annotation], name: &str) -> bool {
annotations.iter().any(|ann| ann.name == name)
}
pub fn annotation_value<'a>(annotations: &'a [Annotation], name: &str) -> Option<&'a str> {
annotations
.iter()
.find(|ann| ann.name == name)
.and_then(|ann| ann.value.as_deref())
}

View file

@ -0,0 +1,2 @@
pub mod ast;
pub mod parser;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,60 @@
// Omnigraph Schema Grammar (.pg files)
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }
COMMENT = _{ LINE_COMMENT | BLOCK_COMMENT }
LINE_COMMENT = _{ "//" ~ (!"\n" ~ ANY)* }
BLOCK_COMMENT = _{ "/*" ~ (!"*/" ~ ANY)* ~ "*/" }
schema_file = { SOI ~ schema_decl* ~ EOI }
schema_decl = { interface_decl | node_decl | edge_decl }
// interface Named { name: String @key }
interface_decl = { "interface" ~ type_name ~ "{" ~ prop_decl* ~ "}" }
// node Person implements Named, Described { ... }
node_decl = { "node" ~ type_name ~ annotation* ~ implements_clause? ~ "{" ~ (prop_decl | body_constraint)* ~ "}" }
implements_clause = { "implements" ~ type_name ~ ("," ~ type_name)* }
// edge Knows: Person -> Person @card(0..1) { ... }
// edge Knows: Person -> Person
edge_decl = { "edge" ~ type_name ~ ":" ~ type_name ~ "->" ~ type_name ~ cardinality? ~ annotation* ~ ("{" ~ (prop_decl | body_constraint)* ~ "}")? }
// @card(0..1), @card(1..), @card(0..)
cardinality = { "@card" ~ "(" ~ integer ~ ".." ~ integer? ~ ")" }
prop_decl = { ident ~ ":" ~ type_ref ~ annotation* }
// Body-level constraints: @key(name), @unique(a, b), @index(a, b), @range(age, 0..200), @check(code, "regex")
body_constraint = { "@" ~ constraint_name ~ "(" ~ constraint_args ~ ")" }
constraint_name = { "key" | "unique" | "index" | "range" | "check" }
constraint_args = { constraint_arg ~ ("," ~ constraint_arg)* }
constraint_arg = { range_bound | literal | ident }
range_bound = { (signed_float | signed_integer) ~ ".." ~ (signed_float | signed_integer)? | ".." ~ (signed_float | signed_integer) }
type_ref = { core_type ~ "?"? }
core_type = { list_type | enum_type | vector_type | base_type }
list_type = { "[" ~ base_type ~ "]" }
enum_type = { "enum" ~ "(" ~ enum_value ~ ("," ~ enum_value)* ~ ")" }
vector_type = { "Vector" ~ "(" ~ integer ~ ")" }
enum_value = @{ (ASCII_ALPHANUMERIC | "_" | "-")+ }
base_type = { "String" | "Blob" | "Bool" | "I32" | "I64" | "U32" | "U64" | "F32" | "F64" | "DateTime" | "Date" }
// Annotation rule excludes constraint keywords followed by "(" — those are body_constraints
annotation = { "@" ~ !(constraint_name ~ "(") ~ ident ~ ("(" ~ annotation_arg ~ ")")? }
annotation_arg = { literal | ident }
literal = { string_lit | float_lit | integer | bool_lit }
string_lit = @{ "\"" ~ string_char* ~ "\"" }
string_char = @{ !("\"" | "\\") ~ ANY | "\\" ~ ANY }
float_lit = @{ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ }
integer = @{ ASCII_DIGIT+ }
signed_float = @{ "-"? ~ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ }
signed_integer = @{ "-"? ~ ASCII_DIGIT+ }
bool_lit = { "true" | "false" }
type_name = @{ ASCII_ALPHA_UPPER ~ (ASCII_ALPHANUMERIC | "_")* }
ident = @{ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }

View file

@ -0,0 +1,227 @@
use arrow_schema::DataType;
use serde::{Deserialize, Serialize};
const MAX_VECTOR_DIM: u32 = i32::MAX as u32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ScalarType {
String,
Bool,
I32,
I64,
U32,
U64,
F32,
F64,
Date,
DateTime,
Vector(u32),
Blob,
}
impl ScalarType {
pub fn from_str_name(s: &str) -> Option<Self> {
if let Some(inner) = s.strip_prefix("Vector(").and_then(|t| t.strip_suffix(')')) {
let dim = inner.parse::<u32>().ok()?;
if dim == 0 || dim > MAX_VECTOR_DIM {
return None;
}
return Some(Self::Vector(dim));
}
match s {
"String" => Some(Self::String),
"Bool" => Some(Self::Bool),
"I32" => Some(Self::I32),
"I64" => Some(Self::I64),
"U32" => Some(Self::U32),
"U64" => Some(Self::U64),
"F32" => Some(Self::F32),
"F64" => Some(Self::F64),
"Date" => Some(Self::Date),
"DateTime" => Some(Self::DateTime),
"Blob" => Some(Self::Blob),
_ => None,
}
}
pub fn to_arrow(&self) -> DataType {
match self {
Self::String => DataType::Utf8,
Self::Bool => DataType::Boolean,
Self::I32 => DataType::Int32,
Self::I64 => DataType::Int64,
Self::U32 => DataType::UInt32,
Self::U64 => DataType::UInt64,
Self::F32 => DataType::Float32,
Self::F64 => DataType::Float64,
Self::Date => DataType::Date32,
Self::DateTime => DataType::Date64,
Self::Blob => DataType::LargeBinary,
Self::Vector(dim) => {
let dim = i32::try_from(*dim)
.expect("vector dimension exceeds Arrow FixedSizeList i32 bound");
DataType::FixedSizeList(
std::sync::Arc::new(arrow_schema::Field::new("item", DataType::Float32, true)),
dim,
)
}
}
}
pub fn is_numeric(&self) -> bool {
matches!(
self,
Self::I32 | Self::I64 | Self::U32 | Self::U64 | Self::F32 | Self::F64
)
}
}
impl std::fmt::Display for ScalarType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
Self::String => "String",
Self::Bool => "Bool",
Self::I32 => "I32",
Self::I64 => "I64",
Self::U32 => "U32",
Self::U64 => "U64",
Self::F32 => "F32",
Self::F64 => "F64",
Self::Date => "Date",
Self::DateTime => "DateTime",
Self::Blob => "Blob",
Self::Vector(dim) => return write!(f, "Vector({})", dim),
};
write!(f, "{}", s)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PropType {
pub scalar: ScalarType,
pub nullable: bool,
pub list: bool,
pub enum_values: Option<Vec<String>>,
}
impl PropType {
pub fn from_param_type_name(s: &str, nullable: bool) -> Option<Self> {
if let Some(inner) = s
.strip_prefix('[')
.and_then(|value| value.strip_suffix(']'))
{
let scalar = ScalarType::from_str_name(inner)?;
return Some(Self::list_of(scalar, nullable));
}
let scalar = ScalarType::from_str_name(s)?;
Some(Self::scalar(scalar, nullable))
}
pub fn scalar(scalar: ScalarType, nullable: bool) -> Self {
Self {
scalar,
nullable,
list: false,
enum_values: None,
}
}
pub fn list_of(scalar: ScalarType, nullable: bool) -> Self {
Self {
scalar,
nullable,
list: true,
enum_values: None,
}
}
pub fn enum_type(mut values: Vec<String>, nullable: bool) -> Self {
values.sort();
values.dedup();
Self {
scalar: ScalarType::String,
nullable,
list: false,
enum_values: Some(values),
}
}
pub fn is_enum(&self) -> bool {
self.enum_values.is_some()
}
pub fn to_arrow(&self) -> DataType {
let scalar_dt = self.scalar.to_arrow();
if self.list {
DataType::List(std::sync::Arc::new(arrow_schema::Field::new(
"item", scalar_dt, true,
)))
} else {
scalar_dt
}
}
pub fn display_name(&self) -> String {
let base = if let Some(values) = &self.enum_values {
format!("enum({})", values.join(", "))
} else {
self.scalar.to_string()
};
let wrapped = if self.list {
format!("[{}]", base)
} else {
base
};
if self.nullable {
format!("{}?", wrapped)
} else {
wrapped
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
Out,
In,
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_schema::{DataType, Field};
use std::sync::Arc;
#[test]
fn vector_to_arrow_uses_nullable_float32_child() {
let dt = ScalarType::Vector(4).to_arrow();
assert_eq!(
dt,
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4)
);
}
#[test]
fn scalar_type_from_str_name_rejects_vector_dimensions_outside_arrow_bounds() {
let too_large = format!("Vector({})", (i32::MAX as u64) + 1);
assert!(ScalarType::from_str_name(&too_large).is_none());
assert_eq!(
ScalarType::from_str_name("Vector(2147483647)"),
Some(ScalarType::Vector(2147483647))
);
}
#[test]
fn prop_type_from_param_type_name_supports_lists_and_nullable_scalars() {
assert_eq!(
PropType::from_param_type_name("[DateTime]", false),
Some(PropType::list_of(ScalarType::DateTime, false))
);
assert_eq!(
PropType::from_param_type_name("DateTime", true),
Some(PropType::scalar(ScalarType::DateTime, true))
);
}
}

View file

@ -0,0 +1,30 @@
[package]
name = "omnigraph-server"
version = "0.4.0"
edition = "2024"
description = "HTTP server for the Omnigraph graph database."
license = "MIT"
[[bin]]
name = "omnigraph-server"
path = "src/main.rs"
[dependencies]
omnigraph = { path = "../omnigraph", version = "0.4.0" }
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
axum = { workspace = true }
clap = { workspace = true }
color-eyre = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
serde_yaml = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
tower-http = { workspace = true }
cedar-policy = { workspace = true }
[dev-dependencies]
tempfile = { workspace = true }
tower = { workspace = true }
serial_test = "3"

View file

@ -0,0 +1,395 @@
use omnigraph::db::{GraphCommit, MergeOutcome, ReadTarget, RunRecord, Snapshot};
use omnigraph::error::{MergeConflict, MergeConflictKind};
use omnigraph::loader::{IngestResult, LoadMode};
use omnigraph_compiler::result::QueryResult;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotTableOutput {
pub table_key: String,
pub table_path: String,
pub table_version: u64,
pub table_branch: Option<String>,
pub row_count: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotOutput {
pub branch: String,
pub manifest_version: u64,
pub tables: Vec<SnapshotTableOutput>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunOutput {
pub run_id: String,
pub target_branch: String,
pub run_branch: String,
pub base_snapshot_id: String,
pub base_manifest_version: u64,
pub operation_hash: Option<String>,
pub actor_id: Option<String>,
pub status: String,
pub published_snapshot_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunListOutput {
pub runs: Vec<RunOutput>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchCreateRequest {
pub from: Option<String>,
pub name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchCreateOutput {
pub uri: String,
pub from: String,
pub name: String,
pub actor_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchListOutput {
pub branches: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchDeleteOutput {
pub uri: String,
pub name: String,
pub actor_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchMergeRequest {
pub source: String,
pub target: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BranchMergeOutcome {
AlreadyUpToDate,
FastForward,
Merged,
}
impl From<MergeOutcome> for BranchMergeOutcome {
fn from(value: MergeOutcome) -> Self {
match value {
MergeOutcome::AlreadyUpToDate => Self::AlreadyUpToDate,
MergeOutcome::FastForward => Self::FastForward,
MergeOutcome::Merged => Self::Merged,
}
}
}
impl BranchMergeOutcome {
pub fn as_str(self) -> &'static str {
match self {
Self::AlreadyUpToDate => "already_up_to_date",
Self::FastForward => "fast_forward",
Self::Merged => "merged",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchMergeOutput {
pub source: String,
pub target: String,
pub outcome: BranchMergeOutcome,
pub actor_id: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MergeConflictKindOutput {
DivergentInsert,
DivergentUpdate,
DeleteVsUpdate,
OrphanEdge,
UniqueViolation,
CardinalityViolation,
ValueConstraintViolation,
}
impl MergeConflictKindOutput {
pub fn as_str(self) -> &'static str {
match self {
Self::DivergentInsert => "divergent_insert",
Self::DivergentUpdate => "divergent_update",
Self::DeleteVsUpdate => "delete_vs_update",
Self::OrphanEdge => "orphan_edge",
Self::UniqueViolation => "unique_violation",
Self::CardinalityViolation => "cardinality_violation",
Self::ValueConstraintViolation => "value_constraint_violation",
}
}
}
impl From<MergeConflictKind> for MergeConflictKindOutput {
fn from(value: MergeConflictKind) -> Self {
match value {
MergeConflictKind::DivergentInsert => Self::DivergentInsert,
MergeConflictKind::DivergentUpdate => Self::DivergentUpdate,
MergeConflictKind::DeleteVsUpdate => Self::DeleteVsUpdate,
MergeConflictKind::OrphanEdge => Self::OrphanEdge,
MergeConflictKind::UniqueViolation => Self::UniqueViolation,
MergeConflictKind::CardinalityViolation => Self::CardinalityViolation,
MergeConflictKind::ValueConstraintViolation => Self::ValueConstraintViolation,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeConflictOutput {
pub table_key: String,
pub row_id: Option<String>,
pub kind: MergeConflictKindOutput,
pub message: String,
}
impl From<&MergeConflict> for MergeConflictOutput {
fn from(value: &MergeConflict) -> Self {
Self {
table_key: value.table_key.clone(),
row_id: value.row_id.clone(),
kind: value.kind.into(),
message: value.message.clone(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadTargetOutput {
pub branch: Option<String>,
pub snapshot: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadOutput {
pub query_name: String,
pub target: ReadTargetOutput,
pub row_count: usize,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub columns: Vec<String>,
pub rows: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChangeOutput {
pub branch: String,
pub query_name: String,
pub affected_nodes: usize,
pub affected_edges: usize,
pub actor_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestTableOutput {
pub table_key: String,
pub rows_loaded: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestOutput {
pub uri: String,
pub branch: String,
pub base_branch: String,
pub branch_created: bool,
pub mode: LoadMode,
pub tables: Vec<IngestTableOutput>,
pub actor_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitOutput {
pub graph_commit_id: String,
pub manifest_branch: Option<String>,
pub manifest_version: u64,
pub parent_commit_id: Option<String>,
pub merged_parent_commit_id: Option<String>,
pub actor_id: Option<String>,
pub created_at: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitListOutput {
pub commits: Vec<CommitOutput>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadRequest {
pub query_source: String,
pub query_name: Option<String>,
pub params: Option<Value>,
pub branch: Option<String>,
pub snapshot: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChangeRequest {
pub query_source: String,
pub query_name: Option<String>,
pub params: Option<Value>,
pub branch: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestRequest {
pub branch: Option<String>,
pub from: Option<String>,
pub mode: Option<LoadMode>,
pub data: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportRequest {
pub branch: Option<String>,
#[serde(default)]
pub type_names: Vec<String>,
#[serde(default)]
pub table_keys: Vec<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct SnapshotQuery {
pub branch: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct CommitListQuery {
pub branch: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthOutput {
pub status: String,
pub version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub source_version: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorCode {
Unauthorized,
Forbidden,
BadRequest,
NotFound,
Conflict,
Internal,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorOutput {
pub error: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<ErrorCode>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub merge_conflicts: Vec<MergeConflictOutput>,
}
pub fn snapshot_payload(branch: &str, snapshot: &Snapshot) -> SnapshotOutput {
let mut entries: Vec<_> = snapshot.entries().cloned().collect();
entries.sort_by(|a, b| a.table_key.cmp(&b.table_key));
let tables = entries
.iter()
.map(|entry| SnapshotTableOutput {
table_key: entry.table_key.clone(),
table_path: entry.table_path.clone(),
table_version: entry.table_version,
table_branch: entry.table_branch.clone(),
row_count: entry.row_count,
})
.collect::<Vec<_>>();
SnapshotOutput {
branch: branch.to_string(),
manifest_version: snapshot.version(),
tables,
}
}
pub fn run_output(run: &RunRecord) -> RunOutput {
RunOutput {
run_id: run.run_id.as_str().to_string(),
target_branch: run.target_branch.clone(),
run_branch: run.run_branch.clone(),
base_snapshot_id: run.base_snapshot_id.as_str().to_string(),
base_manifest_version: run.base_manifest_version,
operation_hash: run.operation_hash.clone(),
actor_id: run.actor_id.clone(),
status: run.status.as_str().to_string(),
published_snapshot_id: run.published_snapshot_id.clone(),
created_at: run.created_at,
updated_at: run.updated_at,
}
}
pub fn commit_output(commit: &GraphCommit) -> CommitOutput {
CommitOutput {
graph_commit_id: commit.graph_commit_id.clone(),
manifest_branch: commit.manifest_branch.clone(),
manifest_version: commit.manifest_version,
parent_commit_id: commit.parent_commit_id.clone(),
merged_parent_commit_id: commit.merged_parent_commit_id.clone(),
actor_id: commit.actor_id.clone(),
created_at: commit.created_at,
}
}
pub fn read_output(query_name: String, target: &ReadTarget, result: QueryResult) -> ReadOutput {
let columns = result
.schema()
.fields()
.iter()
.map(|field| field.name().clone())
.collect();
ReadOutput {
query_name,
target: read_target_output(target),
row_count: result.num_rows(),
columns,
rows: result.to_rust_json(),
}
}
pub fn ingest_output(uri: &str, result: &IngestResult, actor_id: Option<String>) -> IngestOutput {
IngestOutput {
uri: uri.to_string(),
branch: result.branch.clone(),
base_branch: result.base_branch.clone(),
branch_created: result.branch_created,
mode: result.mode,
tables: result
.tables
.iter()
.map(|table| IngestTableOutput {
table_key: table.table_key.clone(),
rows_loaded: table.rows_loaded,
})
.collect(),
actor_id,
}
}
pub fn read_target_output(target: &ReadTarget) -> ReadTargetOutput {
match target {
ReadTarget::Branch(branch) => ReadTargetOutput {
branch: Some(branch.clone()),
snapshot: None,
},
ReadTarget::Snapshot(snapshot) => ReadTargetOutput {
branch: None,
snapshot: Some(snapshot.as_str().to_string()),
},
}
}

View file

@ -0,0 +1,479 @@
use std::collections::BTreeMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use clap::ValueEnum;
use color_eyre::eyre::{Result, bail};
use serde::{Deserialize, Serialize};
pub const DEFAULT_CONFIG_FILE: &str = "omnigraph.yaml";
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProjectConfig {
pub name: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetConfig {
pub uri: String,
pub bearer_token_env: Option<String>,
}
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "snake_case")]
pub enum ReadOutputFormat {
#[default]
Table,
Kv,
Csv,
Jsonl,
Json,
}
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "snake_case")]
pub enum TableCellLayout {
#[default]
Truncate,
Wrap,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CliDefaults {
pub target: Option<String>,
pub branch: Option<String>,
pub output_format: Option<ReadOutputFormat>,
pub table_max_column_width: Option<usize>,
pub table_cell_layout: Option<TableCellLayout>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ServerDefaults {
pub target: Option<String>,
pub bind: Option<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthDefaults {
pub env_file: Option<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct QueryDefaults {
#[serde(default)]
pub roots: Vec<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PolicySettings {
pub file: Option<String>,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AliasCommand {
Read,
Change,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliasConfig {
pub command: AliasCommand,
pub query: String,
pub name: Option<String>,
#[serde(default)]
pub args: Vec<String>,
pub target: Option<String>,
pub branch: Option<String>,
pub format: Option<ReadOutputFormat>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OmnigraphConfig {
#[serde(default)]
pub project: ProjectConfig,
#[serde(default)]
pub targets: BTreeMap<String, TargetConfig>,
#[serde(default)]
pub server: ServerDefaults,
#[serde(default)]
pub auth: AuthDefaults,
#[serde(default)]
pub cli: CliDefaults,
#[serde(default)]
pub query: QueryDefaults,
#[serde(default)]
pub aliases: BTreeMap<String, AliasConfig>,
#[serde(default)]
pub policy: PolicySettings,
#[serde(skip)]
base_dir: PathBuf,
}
impl Default for OmnigraphConfig {
fn default() -> Self {
Self {
project: ProjectConfig::default(),
targets: BTreeMap::new(),
server: ServerDefaults::default(),
auth: AuthDefaults::default(),
cli: CliDefaults::default(),
query: QueryDefaults::default(),
aliases: BTreeMap::new(),
policy: PolicySettings::default(),
base_dir: PathBuf::new(),
}
}
}
impl OmnigraphConfig {
pub fn base_dir(&self) -> &Path {
&self.base_dir
}
pub fn cli_branch(&self) -> &str {
self.cli.branch.as_deref().unwrap_or("main")
}
pub fn cli_output_format(&self) -> ReadOutputFormat {
self.cli.output_format.unwrap_or_default()
}
pub fn table_max_column_width(&self) -> usize {
self.cli.table_max_column_width.unwrap_or(80)
}
pub fn table_cell_layout(&self) -> TableCellLayout {
self.cli.table_cell_layout.unwrap_or_default()
}
pub fn cli_target_name(&self) -> Option<&str> {
self.cli.target.as_deref()
}
pub fn server_target_name(&self) -> Option<&str> {
self.server.target.as_deref()
}
pub fn server_bind(&self) -> &str {
self.server.bind.as_deref().unwrap_or("127.0.0.1:8080")
}
pub fn resolve_target_name<'a>(
&self,
explicit_uri: Option<&str>,
explicit_target: Option<&'a str>,
default_target: Option<&'a str>,
) -> Option<&'a str> {
explicit_target.or_else(|| {
if explicit_uri.is_some() {
None
} else {
default_target
}
})
}
pub fn target_bearer_token_env(
&self,
explicit_uri: Option<&str>,
explicit_target: Option<&str>,
default_target: Option<&str>,
) -> Option<&str> {
let target_name =
self.resolve_target_name(explicit_uri, explicit_target, default_target)?;
self.targets
.get(target_name)
.and_then(|target| target.bearer_token_env.as_deref())
}
pub fn resolve_auth_env_file(&self) -> Option<PathBuf> {
let path = self.auth.env_file.as_deref()?;
let path = Path::new(path);
Some(if path.is_absolute() {
path.to_path_buf()
} else {
self.base_dir.join(path)
})
}
pub fn resolve_policy_file(&self) -> Option<PathBuf> {
let path = self.policy.file.as_deref()?;
let path = Path::new(path);
Some(if path.is_absolute() {
path.to_path_buf()
} else {
self.base_dir.join(path)
})
}
pub fn resolve_policy_tests_file(&self) -> Option<PathBuf> {
let policy_file = self.resolve_policy_file()?;
Some(policy_file.with_file_name("policy.tests.yaml"))
}
pub fn alias(&self, name: &str) -> Result<&AliasConfig> {
self.aliases
.get(name)
.ok_or_else(|| color_eyre::eyre::eyre!("alias '{}' not found", name))
}
pub fn resolve_target_uri(
&self,
explicit_uri: Option<String>,
explicit_target: Option<&str>,
default_target: Option<&str>,
) -> Result<String> {
if let Some(uri) = explicit_uri {
return Ok(uri);
}
let target_name = explicit_target.or(default_target).ok_or_else(|| {
color_eyre::eyre::eyre!("URI must be provided via <URI>, --target, or config")
})?;
let target = self.targets.get(target_name).ok_or_else(|| {
color_eyre::eyre::eyre!(
"target '{}' not found in {}",
target_name,
DEFAULT_CONFIG_FILE
)
})?;
Ok(self.resolve_config_uri(&target.uri))
}
pub fn resolve_query_path(&self, query: &Path) -> Result<PathBuf> {
if query.is_absolute() {
return Ok(query.to_path_buf());
}
let direct = self.base_dir.join(query);
if direct.exists() {
return Ok(direct);
}
for root in &self.query.roots {
let candidate = self.base_dir.join(root).join(query);
if candidate.exists() {
return Ok(candidate);
}
}
bail!("query file '{}' not found", query.display());
}
fn resolve_config_uri(&self, value: &str) -> String {
if value.contains("://") {
return value.to_string();
}
let path = Path::new(value);
if path.is_absolute() {
value.to_string()
} else {
self.base_dir.join(path).to_string_lossy().to_string()
}
}
}
pub fn default_config_path() -> PathBuf {
PathBuf::from(DEFAULT_CONFIG_FILE)
}
pub fn load_config(config_path: Option<&PathBuf>) -> Result<OmnigraphConfig> {
load_config_in(&env::current_dir()?, config_path)
}
fn load_config_in(cwd: &Path, config_path: Option<&PathBuf>) -> Result<OmnigraphConfig> {
let explicit_path = config_path.cloned();
let config_path = explicit_path.or_else(|| {
let default_path = cwd.join(DEFAULT_CONFIG_FILE);
default_path.exists().then_some(default_path)
});
let mut config = if let Some(path) = &config_path {
serde_yaml::from_str::<OmnigraphConfig>(&fs::read_to_string(path)?)?
} else {
OmnigraphConfig::default()
};
config.base_dir = if let Some(path) = config_path {
absolute_base_dir(cwd, &path)?
} else {
cwd.to_path_buf()
};
Ok(config)
}
fn absolute_base_dir(cwd: &Path, path: &Path) -> Result<PathBuf> {
let path = if path.is_absolute() {
path.to_path_buf()
} else {
cwd.join(path)
};
Ok(path
.parent()
.map(Path::to_path_buf)
.unwrap_or_else(|| cwd.to_path_buf()))
}
#[cfg(test)]
mod tests {
use std::fs;
use std::path::{Path, PathBuf};
use tempfile::tempdir;
use super::{ReadOutputFormat, TableCellLayout, load_config_in};
#[test]
fn load_config_reads_yaml_defaults_from_current_dir() {
let temp = tempdir().unwrap();
fs::write(
temp.path().join("omnigraph.yaml"),
r#"
targets:
local:
uri: ./demo.omni
bearer_token_env: DEMO_TOKEN
auth:
env_file: .env.omni
cli:
target: local
branch: main
output_format: kv
table_max_column_width: 40
table_cell_layout: wrap
policy: {}
"#,
)
.unwrap();
let config = load_config_in(temp.path(), None).unwrap();
assert_eq!(config.cli_target_name(), Some("local"));
assert_eq!(config.cli_branch(), "main");
assert_eq!(config.cli_output_format(), ReadOutputFormat::Kv);
assert_eq!(config.table_max_column_width(), 40);
assert_eq!(config.table_cell_layout(), TableCellLayout::Wrap);
assert_eq!(
config.target_bearer_token_env(None, None, config.cli_target_name()),
Some("DEMO_TOKEN")
);
assert_eq!(
config.resolve_auth_env_file().unwrap(),
temp.path().join(".env.omni")
);
assert_eq!(
PathBuf::from(
config
.resolve_target_uri(None, None, config.cli_target_name())
.unwrap()
),
temp.path().join("./demo.omni")
);
}
#[test]
fn load_config_does_not_walk_parent_directories() {
let temp = tempdir().unwrap();
let child = temp.path().join("child");
fs::create_dir_all(&child).unwrap();
fs::write(
temp.path().join("omnigraph.yaml"),
"targets:\n local:\n uri: ./demo.omni\n",
)
.unwrap();
let config = load_config_in(&child, None).unwrap();
assert!(config.targets.is_empty());
}
#[test]
fn resolve_query_path_searches_config_roots() {
let temp = tempdir().unwrap();
fs::create_dir_all(temp.path().join("queries")).unwrap();
fs::write(
temp.path().join("omnigraph.yaml"),
"query:\n roots:\n - queries\npolicy: {}\n",
)
.unwrap();
fs::write(
temp.path().join("queries").join("test.gq"),
"query q { return {} }",
)
.unwrap();
let config = load_config_in(temp.path(), None).unwrap();
let resolved = config.resolve_query_path(Path::new("test.gq")).unwrap();
assert_eq!(resolved, temp.path().join("queries").join("test.gq"));
}
#[test]
fn resolve_query_path_prefers_config_base_dir_over_ambient_cwd() {
let workspace = tempdir().unwrap();
let config_dir = workspace.path().join("config");
let ambient_dir = workspace.path().join("ambient");
fs::create_dir_all(&config_dir).unwrap();
fs::create_dir_all(&ambient_dir).unwrap();
fs::write(config_dir.join("omnigraph.yaml"), "policy: {}\n").unwrap();
fs::write(config_dir.join("local.gq"), "query local { return {} }").unwrap();
fs::write(ambient_dir.join("local.gq"), "query ambient { return {} }").unwrap();
let config =
load_config_in(&ambient_dir, Some(&config_dir.join("omnigraph.yaml"))).unwrap();
let resolved = config.resolve_query_path(Path::new("local.gq")).unwrap();
assert_eq!(resolved, config_dir.join("local.gq"));
}
#[test]
fn policy_block_accepts_non_empty_mapping() {
let temp = tempdir().unwrap();
fs::write(
temp.path().join("omnigraph.yaml"),
"policy:\n file: ./policy.yaml\n",
)
.unwrap();
let config = load_config_in(temp.path(), None).unwrap();
assert_eq!(
config.resolve_policy_file().unwrap(),
temp.path().join("policy.yaml")
);
}
#[test]
fn scoped_auth_env_ignores_default_target_when_uri_is_explicit() {
let temp = tempdir().unwrap();
fs::write(
temp.path().join("omnigraph.yaml"),
r#"
targets:
demo:
uri: https://example.com
bearer_token_env: DEMO_TOKEN
cli:
target: demo
"#,
)
.unwrap();
let config = load_config_in(temp.path(), None).unwrap();
assert_eq!(
config.target_bearer_token_env(
Some("https://override.example.com"),
None,
config.cli_target_name()
),
None
);
assert_eq!(
config.target_bearer_token_env(
Some("https://override.example.com"),
Some("demo"),
config.cli_target_name()
),
Some("DEMO_TOKEN")
);
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,30 @@
use std::path::PathBuf;
use clap::Parser;
use color_eyre::eyre::Result;
use omnigraph_server::{ServerConfig, init_tracing, load_server_settings, serve};
#[derive(Debug, Parser)]
#[command(name = "omnigraph-server")]
#[command(about = "HTTP server for the Omnigraph graph database")]
struct Cli {
/// Repo URI
uri: Option<String>,
#[arg(long)]
target: Option<String>,
#[arg(long)]
config: Option<PathBuf>,
#[arg(long)]
bind: Option<String>,
}
#[tokio::main]
async fn main() -> Result<()> {
color_eyre::install()?;
init_tracing();
let cli = Cli::parse();
let settings: ServerConfig =
load_server_settings(cli.config.as_ref(), cli.uri, cli.target, cli.bind)?;
serve(settings).await
}

View file

@ -0,0 +1,812 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt;
use std::fs;
use std::path::Path;
use std::str::FromStr;
use cedar_policy::{
Authorizer, Context, Decision, Entities, Entity, EntityId, EntityTypeName, EntityUid, Policy,
PolicyId, PolicySet, Request, Schema, ValidationMode, Validator,
};
use clap::ValueEnum;
use color_eyre::eyre::{Result, bail, eyre};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "snake_case")]
pub enum PolicyAction {
Read,
Export,
Change,
BranchCreate,
BranchDelete,
BranchMerge,
RunPublish,
RunAbort,
Admin,
}
impl PolicyAction {
pub fn as_str(self) -> &'static str {
match self {
Self::Read => "read",
Self::Export => "export",
Self::Change => "change",
Self::BranchCreate => "branch_create",
Self::BranchDelete => "branch_delete",
Self::BranchMerge => "branch_merge",
Self::RunPublish => "run_publish",
Self::RunAbort => "run_abort",
Self::Admin => "admin",
}
}
fn uses_branch_scope(self) -> bool {
matches!(self, Self::Read | Self::Export | Self::Change)
}
fn uses_target_branch_scope(self) -> bool {
matches!(
self,
Self::BranchCreate
| Self::BranchDelete
| Self::BranchMerge
| Self::RunPublish
| Self::RunAbort
)
}
}
impl fmt::Display for PolicyAction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl FromStr for PolicyAction {
type Err = color_eyre::eyre::Error;
fn from_str(value: &str) -> Result<Self> {
match value.trim() {
"read" => Ok(Self::Read),
"export" => Ok(Self::Export),
"change" => Ok(Self::Change),
"branch_create" => Ok(Self::BranchCreate),
"branch_delete" => Ok(Self::BranchDelete),
"branch_merge" => Ok(Self::BranchMerge),
"run_publish" => Ok(Self::RunPublish),
"run_abort" => Ok(Self::RunAbort),
"admin" => Ok(Self::Admin),
other => bail!("unknown policy action '{other}'"),
}
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PolicyBranchScope {
Any,
Protected,
Unprotected,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyActorSelector {
pub group: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyAllowRule {
pub actors: PolicyActorSelector,
pub actions: Vec<PolicyAction>,
pub branch_scope: Option<PolicyBranchScope>,
pub target_branch_scope: Option<PolicyBranchScope>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyRule {
pub id: String,
pub allow: PolicyAllowRule,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyConfig {
pub version: u32,
#[serde(default)]
pub groups: BTreeMap<String, Vec<String>>,
#[serde(default)]
pub protected_branches: Vec<String>,
#[serde(default)]
pub rules: Vec<PolicyRule>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyTestConfig {
pub version: u32,
#[serde(default)]
pub cases: Vec<PolicyTestCase>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyTestCase {
pub id: String,
pub actor: String,
pub action: PolicyAction,
pub branch: Option<String>,
pub target_branch: Option<String>,
pub expect: PolicyExpectation,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PolicyExpectation {
Allow,
Deny,
}
#[derive(Debug, Clone)]
pub struct PolicyRequest {
pub actor_id: String,
pub action: PolicyAction,
pub branch: Option<String>,
pub target_branch: Option<String>,
}
#[derive(Debug, Clone)]
pub struct PolicyDecision {
pub allowed: bool,
pub matched_rule_id: Option<String>,
pub message: String,
}
pub struct PolicyCompiler;
#[derive(Clone)]
pub struct PolicyEngine {
repo_id: String,
protected_branches: BTreeSet<String>,
known_actors: BTreeSet<String>,
schema: Schema,
entities: Entities,
policies: PolicySet,
policy_to_rule: HashMap<String, String>,
}
impl PolicyConfig {
pub fn load(path: &Path) -> Result<Self> {
let config: Self = serde_yaml::from_str(&fs::read_to_string(path)?)?;
config.validate()?;
Ok(config)
}
pub fn validate(&self) -> Result<()> {
if self.version != 1 {
bail!("policy version must be 1");
}
for (group, members) in &self.groups {
if group.trim().is_empty() {
bail!("policy group names must not be blank");
}
if members.is_empty() {
bail!("policy group '{group}' must not be empty");
}
for actor in members {
if actor.trim().is_empty() {
bail!("policy group '{group}' contains a blank actor id");
}
}
}
for branch in &self.protected_branches {
if branch.trim().is_empty() {
bail!("protected branch names must not be blank");
}
}
let mut seen_rule_ids = HashSet::new();
for rule in &self.rules {
if rule.id.trim().is_empty() {
bail!("policy rule ids must not be blank");
}
if !seen_rule_ids.insert(rule.id.clone()) {
bail!("duplicate policy rule id '{}'", rule.id);
}
if rule.allow.actors.group.trim().is_empty() {
bail!("policy rule '{}' must reference a non-blank group", rule.id);
}
if !self.groups.contains_key(rule.allow.actors.group.as_str()) {
bail!(
"policy rule '{}' references unknown group '{}'",
rule.id,
rule.allow.actors.group
);
}
if rule.allow.actions.is_empty() {
bail!("policy rule '{}' must include at least one action", rule.id);
}
if rule.allow.branch_scope.is_some() && rule.allow.target_branch_scope.is_some() {
bail!(
"policy rule '{}' may specify branch_scope or target_branch_scope, not both",
rule.id
);
}
if let Some(_) = rule.allow.branch_scope {
for action in &rule.allow.actions {
if !action.uses_branch_scope() {
bail!(
"policy rule '{}' uses branch_scope with unsupported action '{}'",
rule.id,
action
);
}
}
}
if let Some(_) = rule.allow.target_branch_scope {
for action in &rule.allow.actions {
if !action.uses_target_branch_scope() {
bail!(
"policy rule '{}' uses target_branch_scope with unsupported action '{}'",
rule.id,
action
);
}
}
}
}
Ok(())
}
}
impl PolicyTestConfig {
pub fn load(path: &Path) -> Result<Self> {
let config: Self = serde_yaml::from_str(&fs::read_to_string(path)?)?;
if config.version != 1 {
bail!("policy test version must be 1");
}
let mut seen = HashSet::new();
for case in &config.cases {
if case.id.trim().is_empty() {
bail!("policy test case ids must not be blank");
}
if !seen.insert(case.id.clone()) {
bail!("duplicate policy test case id '{}'", case.id);
}
if case.actor.trim().is_empty() {
bail!("policy test case '{}' must not use a blank actor", case.id);
}
}
Ok(config)
}
}
impl PolicyCompiler {
pub fn compile(config: &PolicyConfig, repo_id: &str) -> Result<PolicyEngine> {
config.validate()?;
let (schema, schema_warnings) = Schema::from_cedarschema_str(policy_schema_source())?;
let schema_warnings = schema_warnings
.map(|warning| warning.to_string())
.collect::<Vec<_>>();
if !schema_warnings.is_empty() {
bail!("policy schema warnings:\n{}", schema_warnings.join("\n"));
}
let entities = compile_entities(config, repo_id, &schema)?;
let (policies, policy_to_rule) = compile_policies(config, repo_id)?;
let validator = Validator::new(schema.clone());
let validation = validator.validate(&policies, ValidationMode::Strict);
let errors = validation
.validation_errors()
.map(|err| err.to_string())
.collect::<Vec<_>>();
if !errors.is_empty() {
bail!("policy validation failed:\n{}", errors.join("\n"));
}
let known_actors = config
.groups
.values()
.flat_map(|members| members.iter().cloned())
.collect();
Ok(PolicyEngine {
repo_id: repo_id.to_string(),
protected_branches: config.protected_branches.iter().cloned().collect(),
known_actors,
schema,
entities,
policies,
policy_to_rule,
})
}
}
impl PolicyEngine {
pub fn load(path: &Path, repo_id: &str) -> Result<Self> {
let config = PolicyConfig::load(path)?;
PolicyCompiler::compile(&config, repo_id)
}
pub fn authorize(&self, request: &PolicyRequest) -> Result<PolicyDecision> {
if !self.known_actors.contains(request.actor_id.as_str()) {
return Ok(self.deny(
request,
None,
format!(
"policy denied action '{}' for unknown actor '{}'",
request.action, request.actor_id
),
));
}
let principal = entity_uid("Actor", &request.actor_id)?;
let action = entity_uid("Action", request.action.as_str())?;
let resource = entity_uid("Repo", &self.repo_id)?;
let context_value = json!({
"has_branch": request.branch.is_some(),
"branch": request.branch.clone().unwrap_or_default(),
"has_target_branch": request.target_branch.is_some(),
"target_branch": request.target_branch.clone().unwrap_or_default(),
"branch_is_protected": request.branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)),
"target_branch_is_protected": request.target_branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)),
});
let context = Context::from_json_value(context_value, Some((&self.schema, &action)))?;
let cedar_request = Request::new(principal, action, resource, context, Some(&self.schema))?;
let response =
Authorizer::new().is_authorized(&cedar_request, &self.policies, &self.entities);
let errors = response
.diagnostics()
.errors()
.map(|err| err.to_string())
.collect::<Vec<_>>();
if !errors.is_empty() {
bail!("policy evaluation failed:\n{}", errors.join("\n"));
}
let matched_rule_id = response
.diagnostics()
.reason()
.filter_map(|policy_id| {
let key: &str = policy_id.as_ref();
self.policy_to_rule.get(key).cloned()
})
.min();
Ok(match response.decision() {
Decision::Allow => PolicyDecision {
allowed: true,
matched_rule_id: matched_rule_id.clone(),
message: format!(
"policy allowed action '{}' for actor '{}'",
request.action, request.actor_id
),
},
Decision::Deny => {
let message = format!(
"policy denied action '{}'{}{} for actor '{}'",
request.action,
request
.branch
.as_deref()
.map(|branch| format!(" on branch '{}'", branch))
.unwrap_or_default(),
request
.target_branch
.as_deref()
.map(|branch| format!(" targeting branch '{}'", branch))
.unwrap_or_default(),
request.actor_id
);
self.deny(request, matched_rule_id, message)
}
})
}
pub fn validate_request(&self, request: &PolicyRequest) -> Result<()> {
let _ = self.authorize(request)?;
Ok(())
}
pub fn run_tests(&self, tests: &PolicyTestConfig) -> Result<()> {
if tests.version != 1 {
bail!("policy test version must be 1");
}
let mut failures = Vec::new();
for case in &tests.cases {
let decision = self.authorize(&PolicyRequest {
actor_id: case.actor.clone(),
action: case.action,
branch: case.branch.clone(),
target_branch: case.target_branch.clone(),
})?;
let expected_allowed = matches!(case.expect, PolicyExpectation::Allow);
if decision.allowed != expected_allowed {
failures.push(format!(
"{}: expected {:?} but got {}",
case.id,
case.expect,
if decision.allowed { "allow" } else { "deny" }
));
}
}
if failures.is_empty() {
Ok(())
} else {
bail!("policy tests failed:\n{}", failures.join("\n"))
}
}
pub fn known_actor_count(&self) -> usize {
self.known_actors.len()
}
fn deny(
&self,
_request: &PolicyRequest,
matched_rule_id: Option<String>,
message: String,
) -> PolicyDecision {
PolicyDecision {
allowed: false,
matched_rule_id,
message,
}
}
}
fn compile_entities(config: &PolicyConfig, repo_id: &str, schema: &Schema) -> Result<Entities> {
let mut group_entities = Vec::new();
for group in config.groups.keys() {
group_entities.push(Entity::new(
entity_uid("Group", group)?,
HashMap::new(),
HashSet::<EntityUid>::new(),
)?);
}
let mut actor_groups: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
for (group, members) in &config.groups {
for actor in members {
actor_groups
.entry(actor.clone())
.or_default()
.insert(group.clone());
}
}
let mut actor_entities = Vec::new();
for (actor, groups) in actor_groups {
let parents = groups
.iter()
.map(|group| entity_uid("Group", group))
.collect::<Result<HashSet<_>>>()?;
actor_entities.push(Entity::new(
entity_uid("Actor", &actor)?,
HashMap::new(),
parents,
)?);
}
let repo_entity = Entity::new(
entity_uid("Repo", repo_id)?,
HashMap::new(),
HashSet::<EntityUid>::new(),
)?;
let mut entities = Vec::new();
entities.extend(group_entities);
entities.extend(actor_entities);
entities.push(repo_entity);
Ok(Entities::from_entities(entities, Some(schema))?)
}
fn compile_policies(
config: &PolicyConfig,
repo_id: &str,
) -> Result<(PolicySet, HashMap<String, String>)> {
let mut policies = Vec::new();
let mut policy_to_rule = HashMap::new();
for rule in &config.rules {
for action in &rule.allow.actions {
let policy_id = PolicyId::new(format!("{}:{}", rule.id, action.as_str()));
let source = compile_policy_source(rule, action, repo_id);
let policy = Policy::parse(Some(policy_id.clone()), source.as_str())?;
policy_to_rule.insert(policy_id.to_string(), rule.id.clone());
policies.push(policy);
}
}
Ok((PolicySet::from_policies(policies)?, policy_to_rule))
}
fn compile_policy_source(rule: &PolicyRule, action: &PolicyAction, repo_id: &str) -> String {
let mut conditions = Vec::new();
if let Some(scope) = rule.allow.branch_scope {
conditions.push(branch_scope_condition(scope));
}
if let Some(scope) = rule.allow.target_branch_scope {
conditions.push(target_branch_scope_condition(scope));
}
let when = if conditions.is_empty() {
String::new()
} else {
format!("\nwhen {{ {} }}", conditions.join(" && "))
};
format!(
r#"permit (
principal in Omnigraph::Group::{group},
action == Omnigraph::Action::{action},
resource == Omnigraph::Repo::{repo}
){when};"#,
group = cedar_literal(&rule.allow.actors.group),
action = cedar_literal(action.as_str()),
repo = cedar_literal(repo_id),
when = when,
)
}
fn branch_scope_condition(scope: PolicyBranchScope) -> String {
match scope {
PolicyBranchScope::Any => "true".to_string(),
PolicyBranchScope::Protected => {
"context.has_branch && context.branch_is_protected".to_string()
}
PolicyBranchScope::Unprotected => {
"context.has_branch && context.branch_is_protected == false".to_string()
}
}
}
fn target_branch_scope_condition(scope: PolicyBranchScope) -> String {
match scope {
PolicyBranchScope::Any => "true".to_string(),
PolicyBranchScope::Protected => {
"context.has_target_branch && context.target_branch_is_protected".to_string()
}
PolicyBranchScope::Unprotected => {
"context.has_target_branch && context.target_branch_is_protected == false".to_string()
}
}
}
fn policy_schema_source() -> &'static str {
r#"
namespace Omnigraph {
type RequestContext = {
has_branch: Bool,
branch: String,
has_target_branch: Bool,
target_branch: String,
branch_is_protected: Bool,
target_branch_is_protected: Bool,
};
entity Actor in [Group];
entity Group;
entity Repo;
action "read" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "export" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "change" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "branch_create" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "branch_delete" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "branch_merge" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "run_publish" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "run_abort" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "admin" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
}
"#
}
fn entity_uid(entity_type: &str, id: &str) -> Result<EntityUid> {
let typename = EntityTypeName::from_str(&format!("Omnigraph::{entity_type}"))?;
let entity_id = EntityId::from_str(id).map_err(|err| eyre!(err.to_string()))?;
Ok(EntityUid::from_type_name_and_id(typename, entity_id))
}
fn cedar_literal(value: &str) -> String {
serde_json::to_string(value).expect("string literal should serialize")
}
impl PolicyRequest {
pub fn actor_id(&self) -> &str {
&self.actor_id
}
pub fn action(&self) -> PolicyAction {
self.action
}
pub fn branch(&self) -> Option<&str> {
self.branch.as_deref()
}
pub fn target_branch(&self) -> Option<&str> {
self.target_branch.as_deref()
}
}
#[cfg(test)]
mod tests {
use super::{
PolicyAction, PolicyCompiler, PolicyConfig, PolicyExpectation, PolicyRequest,
PolicyTestCase, PolicyTestConfig,
};
#[test]
fn rejects_duplicate_rule_ids() {
let policy: PolicyConfig = serde_yaml::from_str(
r#"
version: 1
groups:
team: [act-andrew]
rules:
- id: same
allow:
actors: { group: team }
actions: [read]
branch_scope: any
- id: same
allow:
actors: { group: team }
actions: [export]
branch_scope: any
"#,
)
.unwrap();
let err = policy.validate().unwrap_err();
assert!(err.to_string().contains("duplicate policy rule id"));
}
#[test]
fn rejects_unknown_group_references() {
let policy: PolicyConfig = serde_yaml::from_str(
r#"
version: 1
groups:
team: [act-andrew]
rules:
- id: bad
allow:
actors: { group: admins }
actions: [read]
branch_scope: any
"#,
)
.unwrap();
let err = policy.validate().unwrap_err();
assert!(err.to_string().contains("references unknown group"));
}
#[test]
fn rejects_invalid_scope_action_combinations() {
let policy: PolicyConfig = serde_yaml::from_str(
r#"
version: 1
groups:
team: [act-andrew]
rules:
- id: bad
allow:
actors: { group: team }
actions: [branch_merge]
branch_scope: protected
"#,
)
.unwrap();
let err = policy.validate().unwrap_err();
assert!(err.to_string().contains("unsupported action"));
}
#[test]
fn compiles_and_authorizes_branch_and_target_rules() {
let policy: PolicyConfig = serde_yaml::from_str(
r#"
version: 1
groups:
team: [act-andrew, act-bruno]
admins: [act-andrew]
protected_branches: [main]
rules:
- id: team-read
allow:
actors: { group: team }
actions: [read, export]
branch_scope: any
- id: team-write
allow:
actors: { group: team }
actions: [change]
branch_scope: unprotected
- id: admins-promote
allow:
actors: { group: admins }
actions: [branch_delete, branch_merge, run_publish]
target_branch_scope: protected
"#,
)
.unwrap();
let engine = PolicyCompiler::compile(&policy, "repo").unwrap();
let allow = engine
.authorize(&PolicyRequest {
actor_id: "act-bruno".to_string(),
action: PolicyAction::Change,
branch: Some("feature".to_string()),
target_branch: None,
})
.unwrap();
assert!(allow.allowed);
assert_eq!(allow.matched_rule_id.as_deref(), Some("team-write"));
let deny = engine
.authorize(&PolicyRequest {
actor_id: "act-bruno".to_string(),
action: PolicyAction::BranchDelete,
branch: None,
target_branch: Some("main".to_string()),
})
.unwrap();
assert!(!deny.allowed);
let admin = engine
.authorize(&PolicyRequest {
actor_id: "act-andrew".to_string(),
action: PolicyAction::BranchDelete,
branch: None,
target_branch: Some("main".to_string()),
})
.unwrap();
assert!(admin.allowed);
assert_eq!(admin.matched_rule_id.as_deref(), Some("admins-promote"));
}
#[test]
fn policy_tests_enforce_expected_outcomes() {
let policy: PolicyConfig = serde_yaml::from_str(
r#"
version: 1
groups:
team: [act-andrew]
protected_branches: [main]
rules:
- id: team-read
allow:
actors: { group: team }
actions: [read]
branch_scope: any
"#,
)
.unwrap();
let engine = PolicyCompiler::compile(&policy, "repo").unwrap();
let tests = PolicyTestConfig {
version: 1,
cases: vec![
PolicyTestCase {
id: "allow-read".to_string(),
actor: "act-andrew".to_string(),
action: PolicyAction::Read,
branch: Some("main".to_string()),
target_branch: None,
expect: PolicyExpectation::Allow,
},
PolicyTestCase {
id: "deny-change".to_string(),
actor: "act-andrew".to_string(),
action: PolicyAction::Change,
branch: Some("main".to_string()),
target_branch: None,
expect: PolicyExpectation::Deny,
},
],
};
engine.run_tests(&tests).unwrap();
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,47 @@
[package]
name = "omnigraph"
version = "0.4.0"
edition = "2024"
description = "Lance-native graph database with git-style branching."
license = "MIT"
[features]
default = []
failpoints = ["dep:fail", "fail/failpoints"]
[dependencies]
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
lance = { workspace = true }
lance-datafusion = { workspace = true }
lance-file = { workspace = true }
lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-namespace = { workspace = true }
lance-table = { workspace = true }
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-ord = { workspace = true }
arrow-select = { workspace = true }
arrow-cast = { workspace = true }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
reqwest = { workspace = true }
object_store = { workspace = true }
ulid = { workspace = true }
base64 = { workspace = true }
futures = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
regex = { workspace = true }
tempfile = { workspace = true }
fail = { workspace = true, optional = true }
time = { workspace = true }
async-trait = { workspace = true }
url = { workspace = true }
[dev-dependencies]
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
tokio = { workspace = true }
lance-namespace-impls = { workspace = true }
serial_test = "3"

View file

@ -0,0 +1,598 @@
use std::collections::HashSet;
use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
use arrow_cast::display::array_value_to_string;
use lance::dataset::scanner::ColumnOrdering;
use crate::db::SubTableEntry;
use crate::db::manifest::Snapshot;
use crate::error::Result;
use crate::table_store::TableStore;
// ─── Types ──────────────────────────────────────────────────────────────────
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EntityKind {
Node,
Edge,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeOp {
Insert,
Update,
Delete,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Endpoints {
pub src: String,
pub dst: String,
}
#[derive(Debug, Clone)]
pub struct EntityChange {
pub table_key: String,
pub kind: EntityKind,
pub type_name: String,
pub id: String,
pub op: ChangeOp,
pub manifest_version: u64,
pub endpoints: Option<Endpoints>,
}
#[derive(Debug, Clone, Default)]
pub struct ChangeFilter {
pub kinds: Option<Vec<EntityKind>>,
pub type_names: Option<Vec<String>>,
pub ops: Option<Vec<ChangeOp>>,
}
#[derive(Debug, Clone, Default)]
pub struct ChangeStats {
pub inserts: usize,
pub updates: usize,
pub deletes: usize,
pub types_affected: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct ChangeSet {
pub from_version: u64,
pub to_version: u64,
pub branch: Option<String>,
pub changes: Vec<EntityChange>,
pub stats: ChangeStats,
}
// ─── Filter helpers ─────────────────────────────────────────────────────────
fn parse_table_key(table_key: &str) -> (EntityKind, &str) {
if let Some(name) = table_key.strip_prefix("node:") {
(EntityKind::Node, name)
} else if let Some(name) = table_key.strip_prefix("edge:") {
(EntityKind::Edge, name)
} else {
(EntityKind::Node, table_key)
}
}
impl ChangeFilter {
fn matches_table(&self, table_key: &str) -> bool {
let (kind, type_name) = parse_table_key(table_key);
if let Some(ref kinds) = self.kinds {
if !kinds.contains(&kind) {
return false;
}
}
if let Some(ref names) = self.type_names {
if !names.iter().any(|n| n == type_name) {
return false;
}
}
true
}
fn wants_op(&self, op: ChangeOp) -> bool {
match &self.ops {
Some(ops) => ops.contains(&op),
None => true,
}
}
}
// ─── Core diff ──────────────────────────────────────────────────────────────
/// Net-current diff between two snapshots.
///
/// Uses a three-level algorithm:
/// 1. Manifest diff — skip unchanged sub-tables
/// 2. Lineage check — same branch → version-column diff; different → ID-based diff
/// 3. Row-level diff
pub async fn diff_snapshots(
root_uri: &str,
from: &Snapshot,
to: &Snapshot,
filter: &ChangeFilter,
branch: Option<String>,
) -> Result<ChangeSet> {
let table_store = TableStore::new(root_uri);
let mut all_keys: HashSet<String> = HashSet::new();
for entry in from.entries() {
all_keys.insert(entry.table_key.clone());
}
for entry in to.entries() {
all_keys.insert(entry.table_key.clone());
}
let mut changes = Vec::new();
for table_key in &all_keys {
if !filter.matches_table(table_key) {
continue;
}
let from_entry = from.entry(table_key);
let to_entry = to.entry(table_key);
// Skip if both snapshots have identical state for this table
if same_state(from_entry, to_entry) {
continue;
}
let (kind, type_name) = parse_table_key(table_key);
let is_edge = kind == EntityKind::Edge;
let table_changes = if from_entry.is_none() {
// Table added — all rows are inserts
diff_table_added(&table_store, to, table_key, is_edge, filter).await?
} else if to_entry.is_none() {
// Table removed — all rows are deletes
diff_table_removed(&table_store, from, table_key, is_edge, filter).await?
} else if same_lineage(from_entry, to_entry) {
// Fast path: version-column diff
diff_table_same_lineage(
&table_store,
from_entry.unwrap(),
to_entry.unwrap(),
is_edge,
filter,
)
.await?
} else {
// Cross-branch path: streaming ID-based diff
diff_table_cross_branch(&table_store, from, to, table_key, is_edge, filter).await?
};
for mut c in table_changes {
c.table_key = table_key.clone();
c.kind = kind;
c.type_name = type_name.to_string();
if c.manifest_version == 0 {
c.manifest_version = to.version();
}
changes.push(c);
}
}
let stats = compute_stats(&changes);
Ok(ChangeSet {
from_version: from.version(),
to_version: to.version(),
branch,
changes,
stats,
})
}
fn same_state(a: Option<&SubTableEntry>, b: Option<&SubTableEntry>) -> bool {
match (a, b) {
(None, None) => true,
(Some(a), Some(b)) => {
a.table_version == b.table_version && a.table_branch == b.table_branch
}
_ => false,
}
}
fn same_lineage(from: Option<&SubTableEntry>, to: Option<&SubTableEntry>) -> bool {
match (from, to) {
(Some(f), Some(t)) => f.table_branch == t.table_branch,
_ => false,
}
}
fn compute_stats(changes: &[EntityChange]) -> ChangeStats {
let mut stats = ChangeStats::default();
let mut types = HashSet::new();
for c in changes {
match c.op {
ChangeOp::Insert => stats.inserts += 1,
ChangeOp::Update => stats.updates += 1,
ChangeOp::Delete => stats.deletes += 1,
}
types.insert(c.type_name.clone());
}
stats.types_affected = types.into_iter().collect();
stats.types_affected.sort();
stats
}
// ─── Fast path: version-column diff ─────────────────────────────────────────
async fn diff_table_same_lineage(
table_store: &TableStore,
from_entry: &SubTableEntry,
to_entry: &SubTableEntry,
is_edge: bool,
filter: &ChangeFilter,
) -> Result<Vec<EntityChange>> {
let vf = from_entry.table_version;
let vt = to_entry.table_version;
let to_ds = table_store.open_at_entry(to_entry).await?;
let cols: Vec<&str> = if is_edge {
vec!["id", "src", "dst", "_row_last_updated_at_version"]
} else {
vec!["id", "_row_last_updated_at_version"]
};
let wants_inserts = filter.wants_op(ChangeOp::Insert);
let wants_updates = filter.wants_op(ChangeOp::Update);
let wants_deletes = filter.wants_op(ChangeOp::Delete);
let mut changes = Vec::new();
// Inserts + Updates: use _row_last_updated_at_version to find all rows
// touched since Vf, then classify by checking whether the ID existed at Vf.
//
// Why not _row_created_at_version for inserts: Lance's merge_insert stamps
// new rows with _row_created_at_version = dataset_creation_version (v1),
// not the merge_insert commit version. This makes _row_created_at_version
// unreliable for detecting inserts from merge_insert writes. Using
// _row_last_updated_at_version catches all touched rows regardless of
// write mode, and ID-set membership distinguishes inserts from updates.
if wants_inserts || wants_updates {
let filter_sql = format!(
"_row_last_updated_at_version > {} AND _row_last_updated_at_version <= {}",
vf, vt
);
let changed_rows = scan_with_filter(table_store, &to_ds, &cols, &filter_sql).await?;
if !changed_rows.is_empty() {
// Build the set of IDs that existed at the from version
let from_ds = table_store.open_at_entry(from_entry).await?;
let from_ids: HashSet<String> = scan_id_set(table_store, &from_ds, &["id"])
.await?
.into_iter()
.map(|r| r.id)
.collect();
for row in changed_rows {
if from_ids.contains(&row.id) {
if wants_updates {
changes.push(entity_change_from_row(&row, ChangeOp::Update, is_edge));
}
} else if wants_inserts {
changes.push(entity_change_from_row(&row, ChangeOp::Insert, is_edge));
}
}
}
}
// Deletes: ID set-difference
if wants_deletes {
let from_ds = table_store.open_at_entry(from_entry).await?;
let deleted = deleted_ids_by_set_diff(table_store, &from_ds, &to_ds, is_edge).await?;
changes.extend(deleted);
}
Ok(changes)
}
// ─── Cross-branch path: streaming ID-based diff ────────────────────────────
async fn diff_table_cross_branch(
table_store: &TableStore,
from_snap: &Snapshot,
to_snap: &Snapshot,
table_key: &str,
is_edge: bool,
filter: &ChangeFilter,
) -> Result<Vec<EntityChange>> {
let from_ds = table_store
.open_snapshot_table(from_snap, table_key)
.await?;
let to_ds = table_store.open_snapshot_table(to_snap, table_key).await?;
let from_rows = scan_all_rows_ordered(table_store, &from_ds, is_edge).await?;
let to_rows = scan_all_rows_ordered(table_store, &to_ds, is_edge).await?;
let mut changes = Vec::new();
let mut fi = 0;
let mut ti = 0;
while fi < from_rows.len() || ti < to_rows.len() {
let from_id = from_rows.get(fi).map(|r| r.id.as_str());
let to_id = to_rows.get(ti).map(|r| r.id.as_str());
match (from_id, to_id) {
(Some(fid), Some(tid)) if fid < tid => {
// ID only in from → Delete
if filter.wants_op(ChangeOp::Delete) {
changes.push(entity_change_from_row(
&from_rows[fi],
ChangeOp::Delete,
is_edge,
));
}
fi += 1;
}
(Some(fid), Some(tid)) if fid > tid => {
// ID only in to → Insert
if filter.wants_op(ChangeOp::Insert) {
changes.push(entity_change_from_row(
&to_rows[ti],
ChangeOp::Insert,
is_edge,
));
}
ti += 1;
}
(Some(_), Some(_)) => {
// Same ID — check signature
if from_rows[fi].signature != to_rows[ti].signature
&& filter.wants_op(ChangeOp::Update)
{
changes.push(entity_change_from_row(
&to_rows[ti],
ChangeOp::Update,
is_edge,
));
}
fi += 1;
ti += 1;
}
(Some(_), None) => {
if filter.wants_op(ChangeOp::Delete) {
changes.push(entity_change_from_row(
&from_rows[fi],
ChangeOp::Delete,
is_edge,
));
}
fi += 1;
}
(None, Some(_)) => {
if filter.wants_op(ChangeOp::Insert) {
changes.push(entity_change_from_row(
&to_rows[ti],
ChangeOp::Insert,
is_edge,
));
}
ti += 1;
}
(None, None) => break,
}
}
Ok(changes)
}
// ─── Table added/removed ────────────────────────────────────────────────────
async fn diff_table_added(
table_store: &TableStore,
to_snap: &Snapshot,
table_key: &str,
is_edge: bool,
filter: &ChangeFilter,
) -> Result<Vec<EntityChange>> {
if !filter.wants_op(ChangeOp::Insert) {
return Ok(Vec::new());
}
let ds = table_store.open_snapshot_table(to_snap, table_key).await?;
let rows = scan_all_rows_ordered(table_store, &ds, is_edge).await?;
Ok(rows
.into_iter()
.map(|r| entity_change_from_row(&r, ChangeOp::Insert, is_edge))
.collect())
}
async fn diff_table_removed(
table_store: &TableStore,
from_snap: &Snapshot,
table_key: &str,
is_edge: bool,
filter: &ChangeFilter,
) -> Result<Vec<EntityChange>> {
if !filter.wants_op(ChangeOp::Delete) {
return Ok(Vec::new());
}
let ds = table_store
.open_snapshot_table(from_snap, table_key)
.await?;
let rows = scan_all_rows_ordered(table_store, &ds, is_edge).await?;
Ok(rows
.into_iter()
.map(|r| entity_change_from_row(&r, ChangeOp::Delete, is_edge))
.collect())
}
// ─── Helpers ────────────────────────────────────────────────────────────────
/// Scan with a SQL filter, projecting specific columns.
async fn scan_with_filter(
table_store: &TableStore,
ds: &lance::Dataset,
cols: &[&str],
filter_sql: &str,
) -> Result<Vec<ScannedRow>> {
let batches = table_store
.scan(ds, Some(cols), Some(filter_sql), None)
.await?;
Ok(extract_rows(&batches))
}
/// Scan all rows ordered by id, projecting id (+ src/dst for edges) + all columns for signature.
async fn scan_all_rows_ordered(
table_store: &TableStore,
ds: &lance::Dataset,
is_edge: bool,
) -> Result<Vec<ScannedRow>> {
let batches = table_store
.scan(
ds,
None,
None,
Some(vec![ColumnOrdering::asc_nulls_last("id".to_string())]),
)
.await?;
Ok(extract_rows_with_signature(&batches, is_edge))
}
/// Compute deleted IDs: scan id at from and to, set-difference.
async fn deleted_ids_by_set_diff(
table_store: &TableStore,
from_ds: &lance::Dataset,
to_ds: &lance::Dataset,
is_edge: bool,
) -> Result<Vec<EntityChange>> {
let cols: Vec<&str> = if is_edge {
vec!["id", "src", "dst"]
} else {
vec!["id"]
};
let from_rows = scan_id_set(table_store, from_ds, &cols).await?;
let to_ids: HashSet<String> = scan_id_set(table_store, to_ds, &["id"])
.await?
.into_iter()
.map(|r| r.id)
.collect();
Ok(from_rows
.into_iter()
.filter(|r| !to_ids.contains(&r.id))
.map(|r| entity_change_from_row(&r, ChangeOp::Delete, is_edge))
.collect())
}
async fn scan_id_set(
table_store: &TableStore,
ds: &lance::Dataset,
cols: &[&str],
) -> Result<Vec<ScannedRow>> {
let batches = table_store.scan(ds, Some(cols), None, None).await?;
Ok(extract_rows(&batches))
}
// ─── Row extraction ─────────────────────────────────────────────────────────
#[derive(Debug, Clone)]
struct ScannedRow {
id: String,
src: Option<String>,
dst: Option<String>,
signature: String,
change_version: Option<u64>,
}
fn extract_rows(batches: &[RecordBatch]) -> Vec<ScannedRow> {
let mut rows = Vec::new();
for batch in batches {
let ids = batch
.column_by_name("id")
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
let Some(ids) = ids else { continue };
let srcs = batch
.column_by_name("src")
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
let dsts = batch
.column_by_name("dst")
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
for i in 0..ids.len() {
rows.push(ScannedRow {
id: ids.value(i).to_string(),
src: srcs.map(|a| a.value(i).to_string()),
dst: dsts.map(|a| a.value(i).to_string()),
signature: String::new(),
change_version: batch
.column_by_name("_row_last_updated_at_version")
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
.map(|versions| versions.value(i)),
});
}
}
rows
}
fn extract_rows_with_signature(batches: &[RecordBatch], is_edge: bool) -> Vec<ScannedRow> {
let mut rows = Vec::new();
for batch in batches {
let ids = batch
.column_by_name("id")
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
let Some(ids) = ids else { continue };
let srcs = if is_edge {
batch
.column_by_name("src")
.and_then(|c| c.as_any().downcast_ref::<StringArray>())
} else {
None
};
let dsts = if is_edge {
batch
.column_by_name("dst")
.and_then(|c| c.as_any().downcast_ref::<StringArray>())
} else {
None
};
for i in 0..ids.len() {
let mut values = Vec::with_capacity(batch.num_columns());
for (field, col) in batch.schema().fields().iter().zip(batch.columns()) {
if field.name().starts_with("_row_") {
continue;
}
if let Ok(v) = array_value_to_string(col.as_ref(), i) {
values.push(v);
}
}
rows.push(ScannedRow {
id: ids.value(i).to_string(),
src: srcs.map(|a| a.value(i).to_string()),
dst: dsts.map(|a| a.value(i).to_string()),
signature: values.join("\x1f"),
change_version: batch
.column_by_name("_row_last_updated_at_version")
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
.map(|versions| versions.value(i)),
});
}
}
rows
}
fn entity_change_from_row(row: &ScannedRow, op: ChangeOp, is_edge: bool) -> EntityChange {
EntityChange {
table_key: String::new(),
kind: if is_edge {
EntityKind::Edge
} else {
EntityKind::Node
},
type_name: String::new(),
id: row.id.clone(),
op,
manifest_version: row.change_version.unwrap_or(0),
endpoints: if is_edge {
Some(Endpoints {
src: row.src.clone().unwrap_or_default(),
dst: row.dst.clone().unwrap_or_default(),
})
} else {
None
},
}
}

View file

@ -0,0 +1,692 @@
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use arrow_array::{
Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::{WriteMode, WriteParams};
use lance_file::version::LanceFileVersion;
use crate::error::{OmniError, Result};
const GRAPH_COMMITS_DIR: &str = "_graph_commits.lance";
const GRAPH_COMMIT_ACTORS_DIR: &str = "_graph_commit_actors.lance";
#[derive(Debug, Clone)]
pub struct GraphCommit {
pub graph_commit_id: String,
pub manifest_branch: Option<String>,
pub manifest_version: u64,
pub parent_commit_id: Option<String>,
pub merged_parent_commit_id: Option<String>,
pub actor_id: Option<String>,
pub created_at: i64,
}
pub struct CommitGraph {
root_uri: String,
dataset: Dataset,
actor_dataset: Option<Dataset>,
active_branch: Option<String>,
actor_by_commit_id: HashMap<String, String>,
commit_by_id: HashMap<String, GraphCommit>,
head_commit: Option<GraphCommit>,
}
impl CommitGraph {
pub async fn init(root_uri: &str, manifest_version: u64) -> Result<Self> {
let root = root_uri.trim_end_matches('/');
let uri = graph_commits_uri(root);
let genesis = GraphCommit {
graph_commit_id: ulid::Ulid::new().to_string(),
manifest_branch: None,
manifest_version,
parent_commit_id: None,
merged_parent_commit_id: None,
actor_id: None,
created_at: now_micros()?,
};
let batch = commits_to_batch(&[genesis.clone()])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_graph_schema());
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
..Default::default()
};
let dataset = Dataset::write(reader, &uri as &str, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let actor_dataset = create_commit_actor_dataset(root).await?;
Ok(Self {
root_uri: root.to_string(),
dataset,
actor_dataset: Some(actor_dataset),
active_branch: None,
actor_by_commit_id: HashMap::new(),
commit_by_id: HashMap::from([(genesis.graph_commit_id.clone(), genesis.clone())]),
head_commit: Some(genesis),
})
}
pub async fn open(root_uri: &str) -> Result<Self> {
let root = root_uri.trim_end_matches('/');
let dataset = Dataset::open(&graph_commits_uri(root))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let actor_dataset = Dataset::open(&graph_commit_actors_uri(root)).await.ok();
let actor_by_commit_id = match &actor_dataset {
Some(dataset) => load_commit_actor_cache(dataset).await?,
None => HashMap::new(),
};
let (commit_by_id, head_commit) = load_commit_cache(&dataset, &actor_by_commit_id).await?;
Ok(Self {
root_uri: root.to_string(),
dataset,
actor_dataset,
active_branch: None,
actor_by_commit_id,
commit_by_id,
head_commit,
})
}
pub async fn open_at_branch(root_uri: &str, branch: &str) -> Result<Self> {
let root = root_uri.trim_end_matches('/');
let dataset = Dataset::open(&graph_commits_uri(root))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let dataset = dataset
.checkout_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let actor_dataset = Dataset::open(&graph_commit_actors_uri(root)).await.ok();
let actor_by_commit_id = match &actor_dataset {
Some(dataset) => load_commit_actor_cache(dataset).await?,
None => HashMap::new(),
};
let (commit_by_id, head_commit) = load_commit_cache(&dataset, &actor_by_commit_id).await?;
Ok(Self {
root_uri: root.to_string(),
dataset,
actor_dataset,
active_branch: Some(branch.to_string()),
actor_by_commit_id,
commit_by_id,
head_commit,
})
}
pub async fn refresh(&mut self) -> Result<()> {
let root = self.root_uri.clone();
self.dataset = Dataset::open(&graph_commits_uri(&root))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
if let Some(branch) = &self.active_branch {
self.dataset = self
.dataset
.checkout_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
self.actor_dataset = Dataset::open(&graph_commit_actors_uri(&root)).await.ok();
self.actor_by_commit_id = match &self.actor_dataset {
Some(dataset) => load_commit_actor_cache(dataset).await?,
None => HashMap::new(),
};
let (commit_by_id, head_commit) =
load_commit_cache(&self.dataset, &self.actor_by_commit_id).await?;
self.commit_by_id = commit_by_id;
self.head_commit = head_commit;
Ok(())
}
pub fn version(&self) -> u64 {
self.dataset.version().version
}
pub async fn create_branch(&mut self, name: &str) -> Result<()> {
let mut ds = self.dataset.clone();
ds.create_branch(name, self.version(), None)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(())
}
pub async fn delete_branch(&mut self, name: &str) -> Result<()> {
let mut ds = Dataset::open(&graph_commits_uri(&self.root_uri))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
ds.delete_branch(name)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.refresh().await
}
pub async fn append_commit(
&mut self,
manifest_branch: Option<&str>,
manifest_version: u64,
actor_id: Option<&str>,
) -> Result<String> {
let parent_commit_id = self.head_commit_id().await?;
self.append_commit_with_parents(
manifest_branch,
manifest_version,
parent_commit_id.as_deref(),
None,
actor_id,
)
.await
}
pub async fn append_merge_commit(
&mut self,
manifest_branch: Option<&str>,
manifest_version: u64,
parent_commit_id: &str,
merged_parent_commit_id: &str,
actor_id: Option<&str>,
) -> Result<String> {
self.append_commit_with_parents(
manifest_branch,
manifest_version,
Some(parent_commit_id),
Some(merged_parent_commit_id),
actor_id,
)
.await
}
async fn append_commit_with_parents(
&mut self,
manifest_branch: Option<&str>,
manifest_version: u64,
parent_commit_id: Option<&str>,
merged_parent_commit_id: Option<&str>,
actor_id: Option<&str>,
) -> Result<String> {
let graph_commit_id = ulid::Ulid::new().to_string();
let commit = GraphCommit {
graph_commit_id: graph_commit_id.clone(),
manifest_branch: manifest_branch.map(|s| s.to_string()),
manifest_version,
parent_commit_id: parent_commit_id.map(|s| s.to_string()),
merged_parent_commit_id: merged_parent_commit_id.map(|s| s.to_string()),
actor_id: actor_id.map(str::to_string),
created_at: now_micros()?,
};
let batch = commits_to_batch(&[commit.clone()])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_graph_schema());
let mut ds = self.dataset.clone();
ds.append(reader, None)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.dataset = ds;
if let Some(actor_id) = actor_id {
self.append_actor(&graph_commit_id, actor_id).await?;
}
self.commit_by_id
.insert(graph_commit_id.clone(), commit.clone());
if should_replace_head(self.head_commit.as_ref(), &commit) {
self.head_commit = Some(commit);
}
Ok(graph_commit_id)
}
async fn append_actor(&mut self, graph_commit_id: &str, actor_id: &str) -> Result<()> {
if self
.actor_by_commit_id
.get(graph_commit_id)
.is_some_and(|existing| existing == actor_id)
{
return Ok(());
}
let record = CommitActorRecord {
graph_commit_id: graph_commit_id.to_string(),
actor_id: actor_id.to_string(),
created_at: now_micros()?,
};
let batch = commit_actors_to_batch(&[record])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_actor_schema());
let mut dataset = match self.actor_dataset.take() {
Some(dataset) => dataset,
None => create_commit_actor_dataset(&self.root_uri).await?,
};
dataset
.append(reader, None)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.actor_by_commit_id
.insert(graph_commit_id.to_string(), actor_id.to_string());
self.actor_dataset = Some(dataset);
Ok(())
}
pub async fn head_commit(&self) -> Result<Option<GraphCommit>> {
Ok(self.head_commit.clone())
}
pub async fn head_commit_id(&self) -> Result<Option<String>> {
Ok(self.head_commit().await?.map(|c| c.graph_commit_id))
}
pub async fn load_commits(&self) -> Result<Vec<GraphCommit>> {
let mut commits = self.commit_by_id.values().cloned().collect::<Vec<_>>();
commits.sort_by(|a, b| {
a.manifest_version
.cmp(&b.manifest_version)
.then_with(|| a.created_at.cmp(&b.created_at))
.then_with(|| a.graph_commit_id.cmp(&b.graph_commit_id))
});
Ok(commits)
}
pub fn get_commit(&self, commit_id: &str) -> Option<GraphCommit> {
self.commit_by_id.get(commit_id).cloned()
}
pub async fn merge_base(
root_uri: &str,
source_branch: Option<&str>,
target_branch: Option<&str>,
) -> Result<Option<GraphCommit>> {
let source = open_for_branch(root_uri, source_branch).await?;
let target = open_for_branch(root_uri, target_branch).await?;
let source_head = match source.head_commit().await? {
Some(commit) => commit,
None => return Ok(None),
};
let target_head = match target.head_commit().await? {
Some(commit) => commit,
None => return Ok(None),
};
let mut commits = HashMap::new();
for commit in source.load_commits().await? {
commits.insert(commit.graph_commit_id.clone(), commit);
}
for commit in target.load_commits().await? {
commits.insert(commit.graph_commit_id.clone(), commit);
}
let source_distances = ancestor_distances(&source_head.graph_commit_id, &commits);
let target_distances = ancestor_distances(&target_head.graph_commit_id, &commits);
let best = source_distances
.iter()
.filter_map(|(id, source_distance)| {
target_distances.get(id).and_then(|target_distance| {
commits.get(id).map(|commit| {
(
(
*source_distance + *target_distance,
u64::MAX - commit.manifest_version,
),
commit.clone(),
)
})
})
})
.min_by_key(|(score, _)| *score)
.map(|(_, commit)| commit);
Ok(best)
}
}
fn graph_commits_uri(root_uri: &str) -> String {
format!("{}/{}", root_uri.trim_end_matches('/'), GRAPH_COMMITS_DIR)
}
fn graph_commit_actors_uri(root_uri: &str) -> String {
format!(
"{}/{}",
root_uri.trim_end_matches('/'),
GRAPH_COMMIT_ACTORS_DIR
)
}
fn commit_graph_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("graph_commit_id", DataType::Utf8, false),
Field::new("manifest_branch", DataType::Utf8, true),
Field::new("manifest_version", DataType::UInt64, false),
Field::new("parent_commit_id", DataType::Utf8, true),
Field::new("merged_parent_commit_id", DataType::Utf8, true),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
]))
}
fn commit_actor_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("graph_commit_id", DataType::Utf8, false),
Field::new("actor_id", DataType::Utf8, false),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
]))
}
#[derive(Debug, Clone)]
struct CommitActorRecord {
graph_commit_id: String,
actor_id: String,
created_at: i64,
}
async fn create_commit_actor_dataset(root_uri: &str) -> Result<Dataset> {
let uri = graph_commit_actors_uri(root_uri);
let batch = RecordBatch::new_empty(commit_actor_schema());
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_actor_schema());
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
..Default::default()
};
match Dataset::write(reader, &uri as &str, Some(params)).await {
Ok(dataset) => Ok(dataset),
Err(err) if err.to_string().contains("Dataset already exists") => Dataset::open(&uri)
.await
.map_err(|open_err| OmniError::Lance(open_err.to_string())),
Err(err) => Err(OmniError::Lance(err.to_string())),
}
}
fn commits_to_batch(commits: &[GraphCommit]) -> Result<RecordBatch> {
let ids: Vec<&str> = commits.iter().map(|c| c.graph_commit_id.as_str()).collect();
let branches: Vec<Option<&str>> = commits
.iter()
.map(|c| c.manifest_branch.as_deref())
.collect();
let versions: Vec<u64> = commits.iter().map(|c| c.manifest_version).collect();
let parents: Vec<Option<&str>> = commits
.iter()
.map(|c| c.parent_commit_id.as_deref())
.collect();
let merged_parents: Vec<Option<&str>> = commits
.iter()
.map(|c| c.merged_parent_commit_id.as_deref())
.collect();
let created_at: Vec<i64> = commits.iter().map(|c| c.created_at).collect();
RecordBatch::try_new(
commit_graph_schema(),
vec![
Arc::new(StringArray::from(ids)),
Arc::new(StringArray::from(branches)),
Arc::new(UInt64Array::from(versions)),
Arc::new(StringArray::from(parents)),
Arc::new(StringArray::from(merged_parents)),
Arc::new(TimestampMicrosecondArray::from(created_at)),
],
)
.map_err(|e| OmniError::Lance(e.to_string()))
}
async fn load_commit_cache(
dataset: &Dataset,
actor_by_commit_id: &HashMap<String, String>,
) -> Result<(HashMap<String, GraphCommit>, Option<GraphCommit>)> {
let batches: Vec<RecordBatch> = dataset
.scan()
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut commits = load_commits_from_batches(&batches)?;
for commit in &mut commits {
commit.actor_id = actor_by_commit_id
.get(commit.graph_commit_id.as_str())
.cloned();
}
let mut commit_by_id = HashMap::with_capacity(commits.len());
let mut head_commit = None;
for commit in commits {
if should_replace_head(head_commit.as_ref(), &commit) {
head_commit = Some(commit.clone());
}
commit_by_id.insert(commit.graph_commit_id.clone(), commit);
}
Ok((commit_by_id, head_commit))
}
async fn load_commit_actor_cache(dataset: &Dataset) -> Result<HashMap<String, String>> {
let batches: Vec<RecordBatch> = dataset
.scan()
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut actors = HashMap::new();
for batch in batches {
let commit_ids = string_column(&batch, "graph_commit_id", "commit actor registry")?;
let actor_ids = string_column(&batch, "actor_id", "commit actor registry")?;
for row in 0..batch.num_rows() {
actors.insert(
commit_ids.value(row).to_string(),
actor_ids.value(row).to_string(),
);
}
}
Ok(actors)
}
fn load_commits_from_batches(batches: &[RecordBatch]) -> Result<Vec<GraphCommit>> {
let mut commits = Vec::new();
for batch in batches {
let ids = string_column(batch, "graph_commit_id", "commit graph")?;
let branches = string_column(batch, "manifest_branch", "commit graph")?;
let versions = u64_column(batch, "manifest_version", "commit graph")?;
let parents = string_column(batch, "parent_commit_id", "commit graph")?;
let merged_parents = string_column(batch, "merged_parent_commit_id", "commit graph")?;
let created = timestamp_micros_column(batch, "created_at", "commit graph")?;
for row in 0..batch.num_rows() {
commits.push(GraphCommit {
graph_commit_id: ids.value(row).to_string(),
manifest_branch: if branches.is_null(row) {
None
} else {
Some(branches.value(row).to_string())
},
manifest_version: versions.value(row),
parent_commit_id: if parents.is_null(row) {
None
} else {
Some(parents.value(row).to_string())
},
merged_parent_commit_id: if merged_parents.is_null(row) {
None
} else {
Some(merged_parents.value(row).to_string())
},
actor_id: None,
created_at: created.value(row),
});
}
}
Ok(commits)
}
fn commit_actors_to_batch(records: &[CommitActorRecord]) -> Result<RecordBatch> {
let commit_ids: Vec<&str> = records
.iter()
.map(|record| record.graph_commit_id.as_str())
.collect();
let actor_ids: Vec<&str> = records
.iter()
.map(|record| record.actor_id.as_str())
.collect();
let created_at: Vec<i64> = records.iter().map(|record| record.created_at).collect();
RecordBatch::try_new(
commit_actor_schema(),
vec![
Arc::new(StringArray::from(commit_ids)),
Arc::new(StringArray::from(actor_ids)),
Arc::new(TimestampMicrosecondArray::from(created_at)),
],
)
.map_err(|e| OmniError::Lance(e.to_string()))
}
fn should_replace_head(current: Option<&GraphCommit>, candidate: &GraphCommit) -> bool {
current.is_none_or(|existing| {
candidate
.manifest_version
.cmp(&existing.manifest_version)
.then_with(|| candidate.created_at.cmp(&existing.created_at))
.then_with(|| candidate.graph_commit_id.cmp(&existing.graph_commit_id))
.is_gt()
})
}
fn string_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a StringArray> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} column '{name}' is not Utf8"))
})
}
fn u64_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a UInt64Array> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} column '{name}' is not UInt64"))
})
}
fn timestamp_micros_column<'a>(
batch: &'a RecordBatch,
name: &str,
context: &str,
) -> Result<&'a TimestampMicrosecondArray> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"{context} column '{name}' is not Timestamp(Microsecond)"
))
})
}
fn ancestor_distances(
start_id: &str,
commits: &HashMap<String, GraphCommit>,
) -> HashMap<String, u64> {
let mut distances = HashMap::new();
let mut queue = VecDeque::from([(start_id.to_string(), 0u64)]);
while let Some((id, distance)) = queue.pop_front() {
if let Some(existing) = distances.get(&id) {
if *existing <= distance {
continue;
}
}
distances.insert(id.clone(), distance);
if let Some(commit) = commits.get(&id) {
if let Some(parent) = &commit.parent_commit_id {
queue.push_back((parent.clone(), distance + 1));
}
if let Some(parent) = &commit.merged_parent_commit_id {
queue.push_back((parent.clone(), distance + 1));
}
}
}
distances
}
async fn open_for_branch(root_uri: &str, branch: Option<&str>) -> Result<CommitGraph> {
match branch {
Some(branch) if branch != "main" => CommitGraph::open_at_branch(root_uri, branch).await,
_ => CommitGraph::open(root_uri).await,
}
}
fn now_micros() -> Result<i64> {
let duration = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| OmniError::manifest(format!("system clock before UNIX_EPOCH: {}", e)))?;
Ok(duration.as_micros() as i64)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use super::*;
#[test]
fn load_commits_from_batches_returns_error_for_bad_schema() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("graph_commit_id", DataType::UInt64, false),
Field::new("manifest_branch", DataType::Utf8, true),
Field::new("manifest_version", DataType::UInt64, false),
Field::new("parent_commit_id", DataType::Utf8, true),
Field::new("merged_parent_commit_id", DataType::Utf8, true),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
])),
vec![
Arc::new(UInt64Array::from(vec![1_u64])),
Arc::new(StringArray::from(vec![None::<&str>])),
Arc::new(UInt64Array::from(vec![1_u64])),
Arc::new(StringArray::from(vec![None::<&str>])),
Arc::new(StringArray::from(vec![None::<&str>])),
Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
],
)
.unwrap();
let err = load_commits_from_batches(&[batch]).unwrap_err();
assert!(err.to_string().contains("graph_commit_id"));
}
}

View file

@ -0,0 +1,562 @@
use std::fmt;
use std::sync::Arc;
use omnigraph_compiler::catalog::Catalog;
use crate::error::{OmniError, Result};
use crate::failpoints;
use crate::storage::{StorageAdapter, join_uri, normalize_root_uri};
use super::commit_graph::{CommitGraph, GraphCommit};
use super::manifest::{ManifestCoordinator, Snapshot, SubTableUpdate};
use super::run_registry::{RunId, RunRecord, RunRegistry, graph_runs_uri, is_internal_run_branch};
const GRAPH_COMMITS_DIR: &str = "_graph_commits.lance";
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SnapshotId(String);
impl SnapshotId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
pub(crate) fn synthetic(branch: Option<&str>, version: u64) -> Self {
match branch {
Some(branch) => Self(format!("manifest:{}:v{}", branch, version)),
None => Self(format!("manifest:main:v{}", version)),
}
}
}
impl fmt::Display for SnapshotId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadTarget {
Branch(String),
Snapshot(SnapshotId),
}
impl ReadTarget {
pub fn branch(name: impl Into<String>) -> Self {
Self::Branch(name.into())
}
pub fn snapshot(id: impl Into<SnapshotId>) -> Self {
Self::Snapshot(id.into())
}
}
impl From<&str> for ReadTarget {
fn from(value: &str) -> Self {
Self::branch(value)
}
}
impl From<String> for ReadTarget {
fn from(value: String) -> Self {
Self::Branch(value)
}
}
impl From<SnapshotId> for ReadTarget {
fn from(value: SnapshotId) -> Self {
Self::Snapshot(value)
}
}
#[derive(Debug, Clone)]
pub struct ResolvedTarget {
pub requested: ReadTarget,
pub branch: Option<String>,
pub snapshot_id: SnapshotId,
pub snapshot: Snapshot,
}
#[derive(Debug, Clone)]
pub(crate) struct PublishedSnapshot {
pub manifest_version: u64,
pub _snapshot_id: SnapshotId,
}
pub struct GraphCoordinator {
root_uri: String,
storage: Arc<dyn StorageAdapter>,
manifest: ManifestCoordinator,
commit_graph: Option<CommitGraph>,
run_registry: Option<RunRegistry>,
bound_branch: Option<String>,
}
impl GraphCoordinator {
pub async fn init(
root_uri: &str,
catalog: &Catalog,
storage: Arc<dyn StorageAdapter>,
) -> Result<Self> {
let root = normalize_root_uri(root_uri)?;
let manifest = ManifestCoordinator::init(&root, catalog).await?;
let commit_graph = Some(CommitGraph::init(&root, manifest.version()).await?);
Ok(Self {
root_uri: root,
storage,
manifest,
commit_graph,
run_registry: None,
bound_branch: None,
})
}
pub async fn open(root_uri: &str, storage: Arc<dyn StorageAdapter>) -> Result<Self> {
let root = normalize_root_uri(root_uri)?;
let manifest = ManifestCoordinator::open(&root).await?;
let commit_graph = if storage.exists(&graph_commits_uri(&root)).await? {
Some(CommitGraph::open(&root).await?)
} else {
None
};
let run_registry = if storage.exists(&graph_runs_uri(&root)).await? {
Some(RunRegistry::open(&root).await?)
} else {
None
};
Ok(Self {
root_uri: root,
storage,
manifest,
commit_graph,
run_registry,
bound_branch: None,
})
}
pub async fn open_branch(
root_uri: &str,
branch: &str,
storage: Arc<dyn StorageAdapter>,
) -> Result<Self> {
let branch = normalize_branch_name(branch)?;
let Some(branch_name) = branch else {
return Self::open(root_uri, storage).await;
};
let root = normalize_root_uri(root_uri)?;
let manifest = ManifestCoordinator::open_at_branch(&root, &branch_name).await?;
let commit_graph = if storage.exists(&graph_commits_uri(&root)).await? {
Some(CommitGraph::open_at_branch(&root, &branch_name).await?)
} else {
None
};
let run_registry = if storage.exists(&graph_runs_uri(&root)).await? {
Some(RunRegistry::open(&root).await?)
} else {
None
};
Ok(Self {
root_uri: root,
storage,
manifest,
commit_graph,
run_registry,
bound_branch: Some(branch_name),
})
}
pub fn root_uri(&self) -> &str {
&self.root_uri
}
pub fn version(&self) -> u64 {
self.manifest.version()
}
pub fn snapshot(&self) -> Snapshot {
self.manifest.snapshot()
}
pub fn current_branch(&self) -> Option<&str> {
self.bound_branch.as_deref()
}
pub async fn refresh(&mut self) -> Result<()> {
self.manifest.refresh().await?;
if let Some(commit_graph) = &mut self.commit_graph {
commit_graph.refresh().await?;
}
if let Some(run_registry) = &mut self.run_registry {
let root_uri = self.root_uri.clone();
run_registry.refresh(&root_uri).await?;
}
Ok(())
}
pub async fn branch_list(&self) -> Result<Vec<String>> {
self.manifest.list_branches().await.map(|branches| {
branches
.into_iter()
.filter(|branch| !is_internal_run_branch(branch))
.collect()
})
}
pub async fn branch_descendants(&self, name: &str) -> Result<Vec<String>> {
self.manifest
.descendant_branches(name)
.await
.map(|branches| {
branches
.into_iter()
.filter(|branch| !is_internal_run_branch(branch))
.collect()
})
}
pub async fn branch_create(&mut self, name: &str) -> Result<()> {
let branch = normalize_branch_name(name)?
.ok_or_else(|| OmniError::manifest("cannot create branch 'main'".to_string()))?;
self.ensure_commit_graph_initialized().await?;
self.manifest.create_branch(&branch).await?;
failpoints::maybe_fail("branch_create.after_manifest_branch_create")?;
if let Some(commit_graph) = &mut self.commit_graph {
commit_graph.create_branch(&branch).await?;
}
Ok(())
}
pub async fn branch_delete(&mut self, name: &str) -> Result<()> {
let branch = normalize_branch_name(name)?
.ok_or_else(|| OmniError::manifest("cannot delete branch 'main'".to_string()))?;
if self.current_branch() == Some(branch.as_str()) {
return Err(OmniError::manifest_conflict(format!(
"cannot delete currently active branch '{}'",
branch
)));
}
self.manifest.delete_branch(&branch).await?;
if let Some(commit_graph) = &mut self.commit_graph {
commit_graph.delete_branch(&branch).await?;
} else if self
.storage
.exists(&graph_commits_uri(self.root_uri()))
.await?
{
let mut commit_graph = CommitGraph::open(self.root_uri()).await?;
commit_graph.delete_branch(&branch).await?;
}
Ok(())
}
pub async fn snapshot_at_version(&self, version: u64) -> Result<Snapshot> {
ManifestCoordinator::snapshot_at(self.root_uri(), self.current_branch(), version).await
}
pub async fn resolve_snapshot_id(&self, branch: &str) -> Result<SnapshotId> {
let normalized = normalize_branch_name(branch)?;
let other = match normalized.as_deref() {
Some(branch) => {
GraphCoordinator::open_branch(self.root_uri(), branch, Arc::clone(&self.storage))
.await?
}
None => GraphCoordinator::open(self.root_uri(), Arc::clone(&self.storage)).await?,
};
Ok(other
.head_commit_id()
.await?
.unwrap_or_else(|| SnapshotId::synthetic(other.current_branch(), other.version())))
}
pub async fn resolve_target(&self, target: &ReadTarget) -> Result<ResolvedTarget> {
match target {
ReadTarget::Branch(branch) => {
let normalized = normalize_branch_name(branch)?;
let other = match normalized.as_deref() {
Some(branch) => {
GraphCoordinator::open_branch(
self.root_uri(),
branch,
Arc::clone(&self.storage),
)
.await?
}
None => {
GraphCoordinator::open(self.root_uri(), Arc::clone(&self.storage)).await?
}
};
let snapshot_id = other.head_commit_id().await?.unwrap_or_else(|| {
SnapshotId::synthetic(other.current_branch(), other.version())
});
Ok(ResolvedTarget {
requested: target.clone(),
branch: other.bound_branch.clone(),
snapshot_id,
snapshot: other.snapshot(),
})
}
ReadTarget::Snapshot(snapshot_id) => {
let commit = self.resolve_commit(snapshot_id).await?;
let snapshot = ManifestCoordinator::snapshot_at(
self.root_uri(),
commit.manifest_branch.as_deref(),
commit.manifest_version,
)
.await?;
Ok(ResolvedTarget {
requested: target.clone(),
branch: commit.manifest_branch.clone(),
snapshot_id: snapshot_id.clone(),
snapshot,
})
}
}
}
pub async fn resolve_commit(&self, snapshot_id: &SnapshotId) -> Result<GraphCommit> {
if let Some(commit_graph) = &self.commit_graph {
if let Some(commit) = commit_graph.get_commit(snapshot_id.as_str()) {
return Ok(commit);
}
}
for branch in self.manifest.list_branches().await? {
let normalized = normalize_branch_name(&branch)?;
let Some(commit_graph) = self
.open_commit_graph_for_branch(normalized.as_deref())
.await?
else {
break;
};
if let Some(commit) = commit_graph.get_commit(snapshot_id.as_str()) {
return Ok(commit);
}
}
Err(OmniError::manifest_not_found(format!(
"commit '{}' not found",
snapshot_id
)))
}
pub(crate) async fn head_commit_id(&self) -> Result<Option<SnapshotId>> {
match &self.commit_graph {
Some(commit_graph) => commit_graph
.head_commit_id()
.await
.map(|id| id.map(SnapshotId::new)),
None => Ok(None),
}
}
pub(crate) async fn ensure_commit_graph_initialized(&mut self) -> Result<()> {
if self.commit_graph.is_some() {
return Ok(());
}
if !self
.storage
.exists(&graph_commits_uri(self.root_uri()))
.await?
{
let _ = CommitGraph::init(self.root_uri(), self.manifest.version()).await?;
}
self.commit_graph = match self.current_branch() {
Some(branch) => Some(CommitGraph::open_at_branch(self.root_uri(), branch).await?),
None => Some(CommitGraph::open(self.root_uri()).await?),
};
Ok(())
}
pub(crate) async fn ensure_run_registry_initialized(&mut self) -> Result<()> {
if self.run_registry.is_some() {
return Ok(());
}
if !self
.storage
.exists(&graph_runs_uri(self.root_uri()))
.await?
{
let _ = RunRegistry::init(self.root_uri()).await?;
}
self.run_registry = Some(RunRegistry::open(self.root_uri()).await?);
Ok(())
}
pub(crate) async fn commit_updates_with_actor(
&mut self,
updates: &[SubTableUpdate],
actor_id: Option<&str>,
) -> Result<PublishedSnapshot> {
let manifest_version = self.commit_manifest_updates(updates).await?;
let snapshot_id = self.record_graph_commit(manifest_version, actor_id).await?;
Ok(PublishedSnapshot {
manifest_version,
_snapshot_id: snapshot_id,
})
}
pub(crate) async fn commit_manifest_updates(
&mut self,
updates: &[SubTableUpdate],
) -> Result<u64> {
let manifest_version = self.manifest.commit(updates).await?;
failpoints::maybe_fail("graph_publish.after_manifest_commit")?;
Ok(manifest_version)
}
pub(crate) async fn record_graph_commit(
&mut self,
manifest_version: u64,
actor_id: Option<&str>,
) -> Result<SnapshotId> {
self.ensure_commit_graph_initialized().await?;
let current_branch = self.current_branch().map(str::to_string);
let Some(commit_graph) = &mut self.commit_graph else {
return Ok(SnapshotId::synthetic(
current_branch.as_deref(),
manifest_version,
));
};
failpoints::maybe_fail("graph_publish.before_commit_append")?;
let graph_commit_id = commit_graph
.append_commit(current_branch.as_deref(), manifest_version, actor_id)
.await?;
Ok(SnapshotId::new(graph_commit_id))
}
pub(crate) async fn record_merge_commit(
&mut self,
manifest_version: u64,
parent_commit_id: &str,
merged_parent_commit_id: &str,
actor_id: Option<&str>,
) -> Result<SnapshotId> {
self.ensure_commit_graph_initialized().await?;
let current_branch = self.current_branch().map(str::to_string);
let commit_graph = self.commit_graph.as_mut().ok_or_else(|| {
OmniError::manifest("branch merge requires _graph_commits.lance".to_string())
})?;
failpoints::maybe_fail("graph_publish.before_commit_append")?;
let graph_commit_id = commit_graph
.append_merge_commit(
current_branch.as_deref(),
manifest_version,
parent_commit_id,
merged_parent_commit_id,
actor_id,
)
.await?;
Ok(SnapshotId::new(graph_commit_id))
}
async fn open_commit_graph_for_branch(
&self,
branch: Option<&str>,
) -> Result<Option<CommitGraph>> {
if !self
.storage
.exists(&graph_commits_uri(self.root_uri()))
.await?
{
return Ok(None);
}
let graph = match branch {
Some(branch) => CommitGraph::open_at_branch(self.root_uri(), branch).await?,
None => CommitGraph::open(self.root_uri()).await?,
};
Ok(Some(graph))
}
pub(crate) async fn append_run_record(&mut self, record: &RunRecord) -> Result<()> {
self.ensure_run_registry_initialized().await?;
let Some(run_registry) = &mut self.run_registry else {
return Err(OmniError::manifest(
"run registry not initialized".to_string(),
));
};
run_registry.append_record(record).await
}
pub(crate) async fn get_run(&self, run_id: &RunId) -> Result<RunRecord> {
if let Some(run_registry) = &self.run_registry {
if let Some(run) = run_registry.get_run(run_id).await? {
return Ok(run);
}
}
if !self
.storage
.exists(&graph_runs_uri(self.root_uri()))
.await?
{
return Err(OmniError::manifest_not_found(format!(
"run '{}' not found",
run_id
)));
}
let run_registry = RunRegistry::open(self.root_uri()).await?;
run_registry
.get_run(run_id)
.await?
.ok_or_else(|| OmniError::manifest_not_found(format!("run '{}' not found", run_id)))
}
pub(crate) async fn list_runs(&self) -> Result<Vec<RunRecord>> {
if let Some(run_registry) = &self.run_registry {
return run_registry.list_runs().await;
}
if !self
.storage
.exists(&graph_runs_uri(self.root_uri()))
.await?
{
return Ok(Vec::new());
}
let run_registry = RunRegistry::open(self.root_uri()).await?;
run_registry.list_runs().await
}
pub(crate) async fn list_commits(&self) -> Result<Vec<GraphCommit>> {
if let Some(commit_graph) = &self.commit_graph {
return commit_graph.load_commits().await;
}
if !self
.storage
.exists(&graph_commits_uri(self.root_uri()))
.await?
{
return Ok(Vec::new());
}
let commit_graph = match self.current_branch() {
Some(branch) => CommitGraph::open_at_branch(self.root_uri(), branch).await?,
None => CommitGraph::open(self.root_uri()).await?,
};
commit_graph.load_commits().await
}
}
fn graph_commits_uri(root_uri: &str) -> String {
join_uri(root_uri, GRAPH_COMMITS_DIR)
}
fn normalize_branch_name(branch: &str) -> Result<Option<String>> {
let branch = branch.trim();
if branch.is_empty() {
return Err(OmniError::manifest(
"branch name cannot be empty".to_string(),
));
}
if branch == "main" {
return Ok(None);
}
Ok(Some(branch.to_string()))
}

View file

@ -0,0 +1,339 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::error::{OmniError, Result};
use lance::Dataset;
use lance_namespace::models::CreateTableVersionRequest;
use omnigraph_compiler::catalog::Catalog;
#[path = "manifest/layout.rs"]
mod layout;
#[path = "manifest/metadata.rs"]
mod metadata;
#[path = "manifest/namespace.rs"]
mod namespace;
#[path = "manifest/publisher.rs"]
mod publisher;
#[path = "manifest/repo.rs"]
mod repo;
#[path = "manifest/state.rs"]
mod state;
use layout::{manifest_uri, open_manifest_dataset};
pub(crate) use metadata::TableVersionMetadata;
#[cfg(test)]
use metadata::{OMNIGRAPH_ROW_COUNT_KEY, table_version_metadata_for_state};
use namespace::open_table_at_version_from_manifest;
pub(crate) use namespace::open_table_head_for_write;
#[cfg(test)]
use namespace::{branch_manifest_namespace, staged_table_namespace};
use publisher::{GraphNamespacePublisher, ManifestBatchPublisher};
use repo::{init_manifest_repo, open_manifest_repo, snapshot_state_at};
pub use state::SubTableEntry;
#[cfg(test)]
use state::string_column;
use state::{ManifestState, read_manifest_state};
const OBJECT_TYPE_TABLE: &str = "table";
const OBJECT_TYPE_TABLE_VERSION: &str = "table_version";
const TABLE_VERSION_MANAGEMENT_KEY: &str = "table_version_management";
/// Immutable point-in-time view of the database.
///
/// Cheap to create (no storage I/O). All reads within a query go through one
/// Snapshot to guarantee cross-type consistency.
#[derive(Debug, Clone)]
pub struct Snapshot {
root_uri: String,
version: u64,
entries: HashMap<String, SubTableEntry>,
}
impl Snapshot {
/// Open a sub-table dataset at its pinned version.
pub async fn open(&self, table_key: &str) -> Result<Dataset> {
let entry = self
.entries
.get(table_key)
.ok_or_else(|| OmniError::manifest(format!("no manifest entry for {}", table_key)))?;
entry.open(&self.root_uri).await
}
/// Manifest version this snapshot was taken from.
pub fn version(&self) -> u64 {
self.version
}
/// Look up a sub-table entry by key.
pub fn entry(&self, table_key: &str) -> Option<&SubTableEntry> {
self.entries.get(table_key)
}
pub fn entries(&self) -> impl Iterator<Item = &SubTableEntry> {
self.entries.values()
}
}
impl SubTableUpdate {
pub(crate) fn to_create_table_version_request(&self) -> CreateTableVersionRequest {
self.version_metadata.to_create_table_version_request(
&self.table_key,
self.table_version,
self.row_count,
self.table_branch.as_deref(),
)
}
}
impl SubTableEntry {
pub(crate) async fn open(&self, root_uri: &str) -> Result<Dataset> {
open_table_at_version_from_manifest(
root_uri,
&self.table_key,
self.table_branch.as_deref(),
self.table_version,
)
.await
}
}
/// An update to apply to the manifest via `commit`.
#[derive(Debug, Clone)]
pub struct SubTableUpdate {
pub table_key: String,
pub table_version: u64,
pub table_branch: Option<String>,
pub row_count: u64,
pub(crate) version_metadata: TableVersionMetadata,
}
/// Coordinates cross-dataset state through the namespace `__manifest` table.
///
/// Table rows register stable metadata such as location. Append-only
/// `table_version` rows are the graph publish boundary and reconstruct the
/// current graph snapshot by selecting the latest visible version row per
/// sub-table.
pub struct ManifestCoordinator {
root_uri: String,
dataset: Dataset,
known_state: ManifestState,
active_branch: Option<String>,
publisher: Arc<dyn ManifestBatchPublisher>,
}
impl ManifestCoordinator {
fn default_batch_publisher(
root_uri: &str,
active_branch: Option<&str>,
) -> Arc<dyn ManifestBatchPublisher> {
Arc::new(GraphNamespacePublisher::new(root_uri, active_branch))
}
fn from_parts(
root_uri: &str,
dataset: Dataset,
known_state: ManifestState,
active_branch: Option<String>,
publisher: Arc<dyn ManifestBatchPublisher>,
) -> Self {
Self {
root_uri: root_uri.trim_end_matches('/').to_string(),
dataset,
known_state,
active_branch,
publisher,
}
}
fn from_parts_with_default_publisher(
root_uri: &str,
dataset: Dataset,
known_state: ManifestState,
active_branch: Option<String>,
) -> Self {
let publisher = Self::default_batch_publisher(root_uri, active_branch.as_deref());
Self::from_parts(root_uri, dataset, known_state, active_branch, publisher)
}
fn snapshot_from_state(root_uri: &str, state: ManifestState) -> Snapshot {
Snapshot {
root_uri: root_uri.trim_end_matches('/').to_string(),
version: state.version,
entries: state
.entries
.into_iter()
.map(|entry| (entry.table_key.clone(), entry))
.collect(),
}
}
#[cfg(test)]
fn with_batch_publisher(mut self, publisher: Arc<dyn ManifestBatchPublisher>) -> Self {
self.publisher = publisher;
self
}
/// Create a new repo at `root_uri` from a catalog.
///
/// Creates per-type Lance datasets and the namespace `__manifest` table.
pub async fn init(root_uri: &str, catalog: &Catalog) -> Result<Self> {
let root = root_uri.trim_end_matches('/');
let (dataset, known_state) = init_manifest_repo(root, catalog).await?;
Ok(Self::from_parts_with_default_publisher(
root,
dataset,
known_state,
None,
))
}
/// Open an existing repo's manifest.
pub async fn open(root_uri: &str) -> Result<Self> {
let root = root_uri.trim_end_matches('/');
let (dataset, known_state) = open_manifest_repo(root, None).await?;
Ok(Self::from_parts_with_default_publisher(
root,
dataset,
known_state,
None,
))
}
/// Open an existing repo's manifest at a specific branch.
pub async fn open_at_branch(root_uri: &str, branch: &str) -> Result<Self> {
if branch == "main" {
return Self::open(root_uri).await;
}
let root = root_uri.trim_end_matches('/');
let (dataset, known_state) = open_manifest_repo(root, Some(branch)).await?;
Ok(Self::from_parts_with_default_publisher(
root,
dataset,
known_state,
Some(branch.to_string()),
))
}
pub async fn snapshot_at(
root_uri: &str,
branch: Option<&str>,
version: u64,
) -> Result<Snapshot> {
let root = root_uri.trim_end_matches('/');
Ok(Self::snapshot_from_state(
root,
snapshot_state_at(root, branch, version).await?,
))
}
/// Return a Snapshot from the known manifest state. No storage I/O.
pub fn snapshot(&self) -> Snapshot {
Self::snapshot_from_state(&self.root_uri, self.known_state.clone())
}
/// Re-read manifest from storage to see other writers' commits.
pub async fn refresh(&mut self) -> Result<()> {
self.dataset = open_manifest_dataset(&self.root_uri, self.active_branch.as_deref()).await?;
self.known_state = read_manifest_state(&self.dataset).await?;
Ok(())
}
/// Commit updated sub-table versions to the manifest.
///
/// Atomically inserts one immutable `table_version` row per updated table.
/// The merge-insert commit on `__manifest` is the graph-level publish point.
pub async fn commit(&mut self, updates: &[SubTableUpdate]) -> Result<u64> {
if updates.is_empty() {
return Ok(self.version());
}
self.dataset = self.publisher.publish(updates).await?;
self.known_state = read_manifest_state(&self.dataset).await?;
Ok(self.version())
}
/// Current manifest version.
pub fn version(&self) -> u64 {
self.dataset.version().version
}
pub fn active_branch(&self) -> Option<&str> {
self.active_branch.as_deref()
}
pub async fn create_branch(&mut self, name: &str) -> Result<()> {
let mut ds = self.dataset.clone();
ds.create_branch(name, self.version(), None)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(())
}
pub async fn delete_branch(&mut self, name: &str) -> Result<()> {
let uri = manifest_uri(&self.root_uri);
let mut ds = Dataset::open(&uri)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
ds.delete_branch(name)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.dataset = open_manifest_dataset(&self.root_uri, self.active_branch.as_deref()).await?;
self.known_state = read_manifest_state(&self.dataset).await?;
Ok(())
}
pub async fn list_branches(&self) -> Result<Vec<String>> {
let branches = self
.dataset
.list_branches()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut names: Vec<String> = branches.into_keys().filter(|name| name != "main").collect();
names.sort();
let mut all = vec!["main".to_string()];
all.extend(names);
Ok(all)
}
pub async fn descendant_branches(&self, name: &str) -> Result<Vec<String>> {
let branches = self
.dataset
.list_branches()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut frontier = vec![name.to_string()];
let mut descendants = Vec::new();
let mut seen = HashSet::new();
while let Some(parent) = frontier.pop() {
let mut children = branches
.iter()
.filter_map(|(branch, contents)| {
(contents.parent_branch.as_deref() == Some(parent.as_str()))
.then_some(branch.clone())
})
.collect::<Vec<_>>();
children.sort();
for child in children {
if seen.insert(child.clone()) {
frontier.push(child.clone());
descendants.push(child);
}
}
}
Ok(descendants)
}
/// Root URI of the repo.
pub fn root_uri(&self) -> &str {
&self.root_uri
}
}
#[cfg(test)]
#[path = "manifest/tests.rs"]
mod tests;

View file

@ -0,0 +1,74 @@
use lance::Dataset;
use lance_namespace::Error as LanceNamespaceError;
use crate::error::{OmniError, Result};
use crate::storage::{StorageKind, join_uri, storage_kind_for_uri};
const MANIFEST_DIR: &str = "__manifest";
pub(super) fn type_name_hash(name: &str) -> String {
let mut h: u64 = 0xcbf29ce484222325;
for byte in name.as_bytes() {
h ^= *byte as u64;
h = h.wrapping_mul(0x100000001b3);
}
format!("{:016x}", h)
}
pub(super) fn manifest_uri(root: &str) -> String {
format!("{}/{}", root.trim_end_matches('/'), MANIFEST_DIR)
}
pub(super) async fn open_manifest_dataset(root_uri: &str, branch: Option<&str>) -> Result<Dataset> {
let dataset = Dataset::open(&manifest_uri(root_uri.trim_end_matches('/')))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
match branch {
Some(branch) if branch != "main" => dataset
.checkout_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string())),
_ => Ok(dataset),
}
}
fn format_table_version(version: u64) -> String {
format!("{version:020}")
}
pub(super) fn version_object_id(table_key: &str, version: u64) -> String {
format!("{}${}", table_key, format_table_version(version))
}
pub(super) fn table_id_to_key(request_id: Option<&Vec<String>>) -> lance_namespace::Result<String> {
match request_id {
Some(request_id) if request_id.len() == 1 && !request_id[0].is_empty() => {
Ok(request_id[0].clone())
}
Some(request_id) => Err(LanceNamespaceError::invalid_input(format!(
"expected single table id component, got {:?}",
request_id
))),
None => Err(LanceNamespaceError::invalid_input("table id is required")),
}
}
pub(super) fn table_uri_for_path(root_uri: &str, table_path: &str, branch: Option<&str>) -> String {
let mut dataset_location = join_uri(root_uri, table_path);
if let Some(branch) = branch.filter(|branch| *branch != "main") {
dataset_location = join_uri(&dataset_location, "tree");
for segment in branch.split('/') {
dataset_location = join_uri(&dataset_location, segment);
}
}
match storage_kind_for_uri(root_uri) {
StorageKind::Local => url::Url::from_file_path(&dataset_location)
.map(|uri| uri.to_string())
.unwrap_or(dataset_location),
StorageKind::S3 => dataset_location,
}
}
pub(super) fn namespace_internal_error(message: impl Into<String>) -> LanceNamespaceError {
LanceNamespaceError::namespace_source(Box::new(std::io::Error::other(message.into())))
}

View file

@ -0,0 +1,244 @@
use std::collections::HashMap;
use lance::Dataset;
use lance_namespace::Error as LanceNamespaceError;
use lance_namespace::models::{CreateTableVersionRequest, TableVersion};
use serde::{Deserialize, Serialize};
use crate::error::{OmniError, Result};
use crate::storage::{StorageKind, join_uri, storage_kind_for_uri};
use super::layout::table_id_to_key;
pub(super) const OMNIGRAPH_ROW_COUNT_KEY: &str = "omnigraph.row_count";
const OMNIGRAPH_TABLE_BRANCH_KEY: &str = "omnigraph.table_branch";
pub(super) fn namespace_version_metadata(
row_count: u64,
table_branch: Option<&str>,
) -> HashMap<String, String> {
let mut metadata =
HashMap::from([(OMNIGRAPH_ROW_COUNT_KEY.to_string(), row_count.to_string())]);
if let Some(table_branch) = table_branch {
metadata.insert(
OMNIGRAPH_TABLE_BRANCH_KEY.to_string(),
table_branch.to_string(),
);
}
metadata
}
pub(super) fn parse_namespace_version_request(
request: &CreateTableVersionRequest,
) -> lance_namespace::Result<(String, u64, u64, Option<String>, TableVersionMetadata)> {
let table_key = table_id_to_key(request.id.as_ref())?;
let version = u64::try_from(request.version)
.map_err(|_| LanceNamespaceError::invalid_input("table version must be non-negative"))?;
let metadata = request.metadata.as_ref().ok_or_else(|| {
LanceNamespaceError::invalid_input("version metadata is required for Omnigraph rows")
})?;
let row_count = metadata
.get(OMNIGRAPH_ROW_COUNT_KEY)
.ok_or_else(|| {
LanceNamespaceError::invalid_input("missing omnigraph.row_count in metadata")
})?
.parse::<u64>()
.map_err(|e| {
LanceNamespaceError::invalid_input(format!("invalid omnigraph.row_count value: {}", e))
})?;
let table_branch = metadata.get(OMNIGRAPH_TABLE_BRANCH_KEY).cloned();
let version_metadata = TableVersionMetadata {
manifest_path: request.manifest_path.clone(),
manifest_size: request.manifest_size.map(|size| size as u64),
e_tag: request.e_tag.clone(),
naming_scheme: request.naming_scheme.clone(),
};
Ok((
table_key,
version,
row_count,
table_branch,
version_metadata,
))
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct TableVersionMetadata {
manifest_path: String,
manifest_size: Option<u64>,
e_tag: Option<String>,
naming_scheme: Option<String>,
}
impl TableVersionMetadata {
pub(crate) fn from_dataset(
root_uri: &str,
table_path: &str,
dataset: &Dataset,
) -> Result<Self> {
Ok(Self {
manifest_path: full_manifest_object_store_path(
root_uri,
table_path,
&dataset.manifest_location().path.to_string(),
)?,
manifest_size: dataset.manifest_location().size,
e_tag: dataset.manifest_location().e_tag.clone(),
naming_scheme: Some(format!("{:?}", dataset.manifest_location().naming_scheme)),
})
}
pub(super) fn from_json_str(value: &str) -> Result<Self> {
serde_json::from_str(value).map_err(|e| {
OmniError::manifest_internal(format!("failed to decode manifest metadata: {e}"))
})
}
pub(super) fn to_json_string(&self) -> Result<String> {
serde_json::to_string(self).map_err(|e| {
OmniError::manifest_internal(format!("failed to encode manifest metadata: {e}"))
})
}
#[cfg(test)]
pub(crate) fn manifest_path(&self) -> &str {
&self.manifest_path
}
#[cfg(test)]
pub(crate) fn manifest_size(&self) -> Option<u64> {
self.manifest_size
}
#[cfg(test)]
pub(crate) fn e_tag(&self) -> Option<&str> {
self.e_tag.as_deref()
}
#[cfg(test)]
pub(crate) fn naming_scheme(&self) -> Option<&str> {
self.naming_scheme.as_deref()
}
pub(crate) fn to_create_table_version_request(
&self,
table_key: &str,
table_version: u64,
row_count: u64,
table_branch: Option<&str>,
) -> CreateTableVersionRequest {
let mut request =
CreateTableVersionRequest::new(table_version as i64, self.manifest_path.clone());
request.id = Some(vec![table_key.to_string()]);
request.manifest_size = self.manifest_size.map(|size| size as i64);
request.e_tag = self.e_tag.clone();
request.naming_scheme = self.naming_scheme.clone();
request.metadata = Some(namespace_version_metadata(row_count, table_branch));
request
}
pub(super) fn to_namespace_version(&self, version: u64) -> TableVersion {
self.to_namespace_version_with_details(version, None, None)
}
pub(super) fn to_namespace_version_with_details(
&self,
version: u64,
timestamp_millis: Option<i64>,
metadata: Option<HashMap<String, String>>,
) -> TableVersion {
let mut metadata = metadata.unwrap_or_default();
if let Some(naming_scheme) = &self.naming_scheme {
metadata.insert("naming_scheme".to_string(), naming_scheme.clone());
}
TableVersion {
version: version as i64,
manifest_path: self.manifest_path.clone(),
manifest_size: self.manifest_size.map(|size| size as i64),
e_tag: self.e_tag.clone(),
timestamp_millis,
metadata: (!metadata.is_empty()).then_some(metadata),
}
}
}
fn object_store_path_from_uri(uri: &str) -> Result<String> {
match storage_kind_for_uri(uri) {
StorageKind::Local => {
if uri.strip_prefix("file://").is_some() {
let path = url::Url::parse(uri)
.map_err(|e| {
OmniError::manifest_internal(format!("invalid file uri '{}': {}", uri, e))
})?
.to_file_path()
.map_err(|_| {
OmniError::manifest_internal(format!("invalid file uri '{}'", uri))
})?;
Ok(path.to_string_lossy().to_string())
} else {
Ok(uri.to_string())
}
}
StorageKind::S3 => {
let url = url::Url::parse(uri).map_err(|e| {
OmniError::manifest_internal(format!("invalid s3 uri '{}': {}", uri, e))
})?;
Ok(url.path().trim_start_matches('/').to_string())
}
}
}
fn full_manifest_object_store_path(
root_uri: &str,
table_path: &str,
manifest_path: &str,
) -> Result<String> {
if manifest_path.contains("://") {
return object_store_path_from_uri(manifest_path);
}
if manifest_path.contains(table_path) {
return Ok(manifest_path.to_string());
}
let dataset_uri = join_uri(root_uri, table_path);
let dataset_path = object_store_path_from_uri(&dataset_uri)?;
let manifest_path = manifest_path.trim_start_matches('/');
if manifest_path.is_empty() {
return Ok(dataset_path);
}
Ok(format!(
"{}/{}",
dataset_path.trim_end_matches('/'),
manifest_path
))
}
#[cfg(test)]
pub(super) async fn table_version_metadata_for_state(
root_uri: &str,
table_path: &str,
branch: Option<&str>,
version: u64,
) -> Result<TableVersionMetadata> {
let full_path = format!("{}/{}", root_uri.trim_end_matches('/'), table_path);
let ds = Dataset::open(&full_path)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let ds = match branch {
Some(branch) => ds
.checkout_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?,
None => ds,
};
let ds = ds
.checkout_version(version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
TableVersionMetadata::from_dataset(root_uri, table_path, &ds)
}

View file

@ -0,0 +1,549 @@
use std::sync::Arc;
use async_trait::async_trait;
use lance::Dataset;
use lance::dataset::builder::DatasetBuilder;
use lance_namespace::models::{
CreateTableVersionRequest, CreateTableVersionResponse, DescribeTableRequest,
DescribeTableResponse, DescribeTableVersionRequest, DescribeTableVersionResponse,
ListTableVersionsRequest, ListTableVersionsResponse, TableExistsRequest, TableVersion,
};
use lance_namespace::{Error as LanceNamespaceError, LanceNamespace, NamespaceError};
use lance_table::io::commit::ManifestNamingScheme;
use object_store::{Error as ObjectStoreError, ObjectStore as _, PutMode, PutOptions, path::Path};
use crate::error::{OmniError, Result};
use super::layout::{
namespace_internal_error, open_manifest_dataset, table_id_to_key, table_uri_for_path,
};
use super::metadata::{
TableVersionMetadata, namespace_version_metadata, parse_namespace_version_request,
};
use super::publisher::GraphNamespacePublisher;
use super::state::{ManifestState, SubTableEntry, read_manifest_entries, read_manifest_state};
#[derive(Debug, Clone)]
struct BranchManifestNamespace {
root_uri: String,
branch: Option<String>,
}
impl BranchManifestNamespace {
fn new(root_uri: &str, branch: Option<&str>) -> Self {
Self {
root_uri: root_uri.trim_end_matches('/').to_string(),
branch: branch
.filter(|branch| *branch != "main")
.map(ToOwned::to_owned),
}
}
async fn dataset(&self) -> Result<Dataset> {
open_manifest_dataset(&self.root_uri, self.branch.as_deref()).await
}
async fn state(&self) -> Result<ManifestState> {
let dataset = self.dataset().await?;
read_manifest_state(&dataset).await
}
async fn version_entries(&self) -> Result<Vec<SubTableEntry>> {
let dataset = self.dataset().await?;
read_manifest_entries(&dataset).await
}
}
#[derive(Debug, Clone)]
struct StagedTableNamespace {
root_uri: String,
table_id: Vec<String>,
table_path: String,
branch: Option<String>,
}
impl StagedTableNamespace {
fn new(root_uri: &str, table_key: &str, table_path: &str, branch: Option<&str>) -> Self {
Self {
root_uri: root_uri.trim_end_matches('/').to_string(),
table_id: vec![table_key.to_string()],
table_path: table_path.to_string(),
branch: branch
.filter(|branch| *branch != "main")
.map(ToOwned::to_owned),
}
}
fn table_key(&self) -> &str {
&self.table_id[0]
}
fn table_uri(&self) -> String {
table_uri_for_path(&self.root_uri, &self.table_path, self.branch.as_deref())
}
fn ensure_request_table(
&self,
request_id: Option<&Vec<String>>,
) -> lance_namespace::Result<()> {
match request_id {
Some(request_id) if request_id == &self.table_id => Ok(()),
Some(request_id) => Err(LanceNamespaceError::namespace_source(Box::new(
NamespaceError::TableNotFound {
message: format!("table {:?} not found", request_id),
},
))),
None => Err(LanceNamespaceError::invalid_input("table id is required")),
}
}
async fn open_head(&self) -> Result<Dataset> {
Dataset::open(&self.table_uri())
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
async fn open_version(&self, version: u64) -> Result<Dataset> {
let ds = self.open_head().await?;
if ds.version().version == version {
Ok(ds)
} else {
ds.checkout_version(version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
}
fn to_table_version(
&self,
dataset: &Dataset,
version: &lance::dataset::Version,
) -> Result<TableVersion> {
let metadata =
TableVersionMetadata::from_dataset(&self.root_uri, &self.table_path, dataset)?;
Ok(metadata.to_namespace_version_with_details(
version.version,
Some(version.timestamp.timestamp_millis()),
Some(
version
.metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
),
))
}
}
pub(crate) fn branch_manifest_namespace(
root_uri: &str,
branch: Option<&str>,
) -> Arc<dyn LanceNamespace> {
Arc::new(BranchManifestNamespace::new(root_uri, branch))
}
pub(crate) fn staged_table_namespace(
root_uri: &str,
table_key: &str,
table_path: &str,
branch: Option<&str>,
) -> Arc<dyn LanceNamespace> {
Arc::new(StagedTableNamespace::new(
root_uri, table_key, table_path, branch,
))
}
async fn load_table_from_namespace(
namespace: Arc<dyn LanceNamespace>,
table_key: &str,
branch: Option<&str>,
version: Option<u64>,
) -> Result<Dataset> {
let builder = DatasetBuilder::from_namespace(namespace, vec![table_key.to_string()])
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let builder = match (branch, version) {
(Some(branch), version) => builder.with_branch(branch, version),
(None, Some(version)) => builder.with_version(version),
(None, None) => builder,
};
builder
.load()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub(crate) async fn open_table_at_version_from_manifest(
root_uri: &str,
table_key: &str,
branch: Option<&str>,
version: u64,
) -> Result<Dataset> {
load_table_from_namespace(
branch_manifest_namespace(root_uri, branch),
table_key,
branch,
Some(version),
)
.await
}
#[async_trait]
impl LanceNamespace for BranchManifestNamespace {
fn namespace_id(&self) -> String {
"__manifest".to_string()
}
async fn describe_table(
&self,
request: DescribeTableRequest,
) -> lance_namespace::Result<DescribeTableResponse> {
let table_key = table_id_to_key(request.id.as_ref())?;
let state = self
.state()
.await
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?;
let entry = state
.entries
.into_iter()
.find(|entry| entry.table_key == table_key);
let entry = entry.ok_or_else(|| {
LanceNamespaceError::namespace_source(Box::new(NamespaceError::TableNotFound {
message: format!("table {} not found", table_key),
}))
})?;
let table_uri = table_uri_for_path(
&self.root_uri,
&entry.table_path,
entry.table_branch.as_deref(),
);
Ok(DescribeTableResponse {
table: Some(entry.table_key.clone()),
namespace: Some(Vec::new()),
version: Some(entry.table_version as i64),
location: Some(table_uri.clone()),
table_uri: request.with_table_uri.unwrap_or(false).then_some(table_uri),
schema: None,
storage_options: None,
stats: None,
metadata: None,
properties: None,
managed_versioning: Some(true),
})
}
async fn table_exists(&self, request: TableExistsRequest) -> lance_namespace::Result<()> {
let table_key = table_id_to_key(request.id.as_ref())?;
let state = self
.state()
.await
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?;
if state
.entries
.iter()
.any(|entry| entry.table_key == table_key)
{
Ok(())
} else {
Err(LanceNamespaceError::namespace_source(Box::new(
NamespaceError::TableNotFound {
message: format!("table {} not found", table_key),
},
)))
}
}
async fn list_table_versions(
&self,
request: ListTableVersionsRequest,
) -> lance_namespace::Result<ListTableVersionsResponse> {
let table_key = table_id_to_key(request.id.as_ref())?;
let mut versions: Vec<TableVersion> = self
.version_entries()
.await
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?
.into_iter()
.filter(|entry| entry.table_key == table_key)
.map(|entry| {
entry
.version_metadata
.to_namespace_version(entry.table_version)
})
.collect();
if request.descending.unwrap_or(false) {
versions.sort_by(|a, b| b.version.cmp(&a.version));
} else {
versions.sort_by(|a, b| a.version.cmp(&b.version));
}
if let Some(limit) = request.limit {
versions.truncate(limit as usize);
}
Ok(ListTableVersionsResponse {
versions,
page_token: None,
})
}
async fn describe_table_version(
&self,
request: DescribeTableVersionRequest,
) -> lance_namespace::Result<DescribeTableVersionResponse> {
let table_key = table_id_to_key(request.id.as_ref())?;
let version = request
.version
.ok_or_else(|| LanceNamespaceError::invalid_input("table version is required"))?;
let version = u64::try_from(version).map_err(|_| {
LanceNamespaceError::invalid_input("table version must be non-negative")
})?;
let entry = self
.version_entries()
.await
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?
.into_iter()
.find(|entry| entry.table_key == table_key && entry.table_version == version)
.ok_or_else(|| {
LanceNamespaceError::namespace_source(Box::new(
NamespaceError::TableVersionNotFound {
message: format!("table version {} not found for {}", version, table_key),
},
))
})?;
Ok(DescribeTableVersionResponse::new(
entry
.version_metadata
.to_namespace_version(entry.table_version),
))
}
async fn create_table_version(
&self,
request: CreateTableVersionRequest,
) -> lance_namespace::Result<CreateTableVersionResponse> {
let (table_key, table_version, row_count, table_branch, version_metadata) =
parse_namespace_version_request(&request)?;
GraphNamespacePublisher::new(&self.root_uri, self.branch.as_deref())
.publish_requests(std::slice::from_ref(&request))
.await
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?;
let mut response = CreateTableVersionResponse::new();
response.version = Some(Box::new(
version_metadata.to_namespace_version_with_details(
table_version,
None,
Some(namespace_version_metadata(
row_count,
table_branch.as_deref(),
)),
),
));
let _ = table_key;
Ok(response)
}
}
#[async_trait]
impl LanceNamespace for StagedTableNamespace {
fn namespace_id(&self) -> String {
"__manifest".to_string()
}
async fn describe_table(
&self,
request: DescribeTableRequest,
) -> lance_namespace::Result<DescribeTableResponse> {
self.ensure_request_table(request.id.as_ref())?;
let ds = self
.open_head()
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
let table_uri = self.table_uri();
Ok(DescribeTableResponse {
table: Some(self.table_key().to_string()),
namespace: Some(Vec::new()),
version: Some(ds.version().version as i64),
location: Some(table_uri.clone()),
table_uri: request.with_table_uri.unwrap_or(false).then_some(table_uri),
schema: None,
storage_options: None,
stats: None,
metadata: None,
properties: None,
managed_versioning: Some(true),
})
}
async fn table_exists(&self, request: TableExistsRequest) -> lance_namespace::Result<()> {
self.ensure_request_table(request.id.as_ref())?;
self.open_head()
.await
.map(|_| ())
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))
}
async fn list_table_versions(
&self,
request: ListTableVersionsRequest,
) -> lance_namespace::Result<ListTableVersionsResponse> {
self.ensure_request_table(request.id.as_ref())?;
if request.limit == Some(0) {
return Ok(ListTableVersionsResponse {
versions: Vec::new(),
page_token: None,
});
}
let head = self
.open_head()
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
let dataset_versions = head
.versions()
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
let mut versions = Vec::with_capacity(dataset_versions.len());
for version in dataset_versions {
let dataset = if version.version == head.version().version {
head.clone()
} else {
head.checkout_version(version.version)
.await
.map_err(|e| namespace_internal_error(e.to_string()))?
};
versions.push(
self.to_table_version(&dataset, &version)
.map_err(|e| namespace_internal_error(e.to_string()))?,
);
}
if request.descending.unwrap_or(false) {
versions.sort_by(|a, b| b.version.cmp(&a.version));
} else {
versions.sort_by(|a, b| a.version.cmp(&b.version));
}
if let Some(limit) = request.limit {
versions.truncate(limit as usize);
}
Ok(ListTableVersionsResponse {
versions,
page_token: None,
})
}
async fn describe_table_version(
&self,
request: DescribeTableVersionRequest,
) -> lance_namespace::Result<DescribeTableVersionResponse> {
self.ensure_request_table(request.id.as_ref())?;
let version = request
.version
.ok_or_else(|| LanceNamespaceError::invalid_input("table version is required"))?;
let version = u64::try_from(version).map_err(|_| {
LanceNamespaceError::invalid_input("table version must be non-negative")
})?;
let ds = self
.open_version(version)
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
let version_info = self
.to_table_version(
&ds,
&lance::dataset::Version {
version: ds.version().version,
timestamp: ds.version().timestamp,
metadata: ds.version().metadata,
},
)
.map_err(|e| namespace_internal_error(e.to_string()))?;
Ok(DescribeTableVersionResponse::new(version_info))
}
async fn create_table_version(
&self,
request: CreateTableVersionRequest,
) -> lance_namespace::Result<CreateTableVersionResponse> {
self.ensure_request_table(request.id.as_ref())?;
let version = u64::try_from(request.version).map_err(|_| {
LanceNamespaceError::invalid_input("table version must be non-negative")
})?;
let naming_scheme = match request.naming_scheme.as_deref() {
Some("V1") => ManifestNamingScheme::V1,
_ => ManifestNamingScheme::V2,
};
let (object_store, base_path, _) = DatasetBuilder::from_uri(&self.table_uri())
.build_object_store()
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
let staging_path = Path::from(request.manifest_path.clone());
let manifest_data = object_store
.inner
.get(&staging_path)
.await
.map_err(|e| namespace_internal_error(e.to_string()))?
.bytes()
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
let final_path = naming_scheme.manifest_path(&base_path, version);
object_store
.inner
.put_opts(
&final_path,
manifest_data.into(),
PutOptions {
mode: PutMode::Create,
..Default::default()
},
)
.await
.map_err(|e| match e {
ObjectStoreError::AlreadyExists { .. } | ObjectStoreError::Precondition { .. } => {
LanceNamespaceError::namespace_source(Box::new(
NamespaceError::ConcurrentModification {
message: format!(
"table version {} already exists for {}",
version,
self.table_key()
),
},
))
}
other => namespace_internal_error(other.to_string()),
})?;
let meta = object_store
.inner
.head(&final_path)
.await
.map_err(|e| namespace_internal_error(e.to_string()))?;
match object_store.inner.delete(&staging_path).await {
Ok(_) | Err(ObjectStoreError::NotFound { .. }) => {}
Err(e) => return Err(namespace_internal_error(e.to_string())),
}
let mut response = CreateTableVersionResponse::new();
response.version = Some(Box::new(TableVersion {
version: version as i64,
manifest_path: final_path.to_string(),
manifest_size: Some(meta.size as i64),
e_tag: meta.e_tag,
timestamp_millis: None,
metadata: request.metadata,
}));
Ok(response)
}
}
pub(crate) async fn open_table_head_for_write(
root_uri: &str,
table_key: &str,
table_path: &str,
branch: Option<&str>,
) -> Result<Dataset> {
load_table_from_namespace(
staged_table_namespace(root_uri, table_key, table_path, branch),
table_key,
branch,
None,
)
.await
}

View file

@ -0,0 +1,236 @@
//! Graph-level batch publish over the namespace `__manifest` table.
//!
//! Lance now owns most of the table/version control plane for Omnigraph:
//! table storage, table-local versioning, namespace lookup, and native table
//! history. This module exists for the remaining graph-specific gap:
//! Omnigraph needs one atomic publish point across multiple tables and the
//! current Rust namespace surface does not expose a branch-aware
//! `BatchCreateTableVersions` path for `DirectoryNamespace`.
//!
//! Until Lance exposes that operation directly, this publisher owns only:
//! - validating batch publish invariants against the current `__manifest` state
//! - atomically inserting immutable `table_version` rows into `__manifest`
//! - returning the refreshed manifest dataset that defines the visible graph
//!
//! This module should disappear once Lance Rust can do branch-aware batch table
//! version publication against a managed namespace manifest.
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::RecordBatchIterator;
use lance::Dataset;
use lance::dataset::{MergeInsertBuilder, WhenMatched, WhenNotMatched};
use lance_namespace::NamespaceError;
use lance_namespace::models::CreateTableVersionRequest;
use crate::error::{OmniError, Result};
use super::layout::{open_manifest_dataset, version_object_id};
use super::metadata::parse_namespace_version_request;
use super::state::{
manifest_rows_batch, manifest_schema, read_manifest_entries, read_manifest_state,
};
use super::{OBJECT_TYPE_TABLE_VERSION, SubTableEntry, SubTableUpdate};
#[async_trait]
pub(super) trait ManifestBatchPublisher: Send + Sync {
async fn publish(&self, updates: &[SubTableUpdate]) -> Result<Dataset>;
}
pub(super) struct GraphNamespacePublisher {
root_uri: String,
branch: Option<String>,
}
#[derive(Debug)]
struct PendingVersionRow {
object_id: String,
metadata: Option<String>,
table_key: String,
table_version: Option<u64>,
table_branch: Option<String>,
row_count: Option<u64>,
}
impl GraphNamespacePublisher {
pub(super) fn new(root_uri: &str, branch: Option<&str>) -> Self {
Self {
root_uri: root_uri.trim_end_matches('/').to_string(),
branch: branch
.filter(|branch| *branch != "main")
.map(ToOwned::to_owned),
}
}
async fn dataset(&self) -> Result<Dataset> {
open_manifest_dataset(&self.root_uri, self.branch.as_deref()).await
}
async fn load_publish_state(
&self,
) -> Result<(
Dataset,
HashMap<String, ()>,
HashMap<(String, u64), SubTableEntry>,
)> {
let dataset = self.dataset().await?;
let current = read_manifest_state(&dataset).await?;
let existing_entries = read_manifest_entries(&dataset).await?;
let known_tables = current
.entries
.iter()
.map(|entry| (entry.table_key.clone(), ()))
.collect();
let existing_versions = existing_entries
.iter()
.map(|entry| {
(
(entry.table_key.clone(), entry.table_version),
entry.clone(),
)
})
.collect();
Ok((dataset, known_tables, existing_versions))
}
fn build_pending_rows(
requests: &[CreateTableVersionRequest],
known_tables: &HashMap<String, ()>,
existing_versions: &HashMap<(String, u64), SubTableEntry>,
) -> Result<Vec<PendingVersionRow>> {
let mut request_versions = HashMap::<(String, u64), ()>::new();
let mut rows = Vec::with_capacity(requests.len());
for request in requests {
let (table_key, table_version, row_count, table_branch, version_metadata) =
parse_namespace_version_request(request)
.map_err(|e| OmniError::Lance(e.to_string()))?;
if !known_tables.contains_key(table_key.as_str()) {
return Err(OmniError::Lance(
NamespaceError::TableNotFound {
message: format!("table {} not found", table_key),
}
.to_string(),
));
}
if request_versions
.insert((table_key.clone(), table_version), ())
.is_some()
{
return Err(OmniError::Lance(
NamespaceError::ConcurrentModification {
message: format!(
"table version {} already exists for {}",
table_version, table_key
),
}
.to_string(),
));
}
if let Some(existing) = existing_versions.get(&(table_key.clone(), table_version)) {
let is_owner_branch_handoff =
existing.row_count == row_count && existing.table_branch != table_branch;
if !is_owner_branch_handoff {
return Err(OmniError::Lance(
NamespaceError::ConcurrentModification {
message: format!(
"table version {} already exists for {}",
table_version, table_key
),
}
.to_string(),
));
}
}
rows.push(PendingVersionRow {
object_id: version_object_id(&table_key, table_version),
metadata: Some(version_metadata.to_json_string()?),
table_key,
table_version: Some(table_version),
table_branch,
row_count: Some(row_count),
});
}
Ok(rows)
}
fn pending_rows_to_batch(rows: Vec<PendingVersionRow>) -> Result<arrow_array::RecordBatch> {
let mut object_ids = Vec::with_capacity(rows.len());
let mut object_types = Vec::with_capacity(rows.len());
let mut locations: Vec<Option<String>> = Vec::with_capacity(rows.len());
let mut metadata = Vec::with_capacity(rows.len());
let mut table_keys = Vec::with_capacity(rows.len());
let mut table_versions: Vec<Option<u64>> = Vec::with_capacity(rows.len());
let mut table_branches = Vec::with_capacity(rows.len());
let mut row_counts: Vec<Option<u64>> = Vec::with_capacity(rows.len());
for row in rows {
object_ids.push(row.object_id);
object_types.push(OBJECT_TYPE_TABLE_VERSION.to_string());
locations.push(None);
metadata.push(row.metadata);
table_keys.push(row.table_key);
table_versions.push(row.table_version);
table_branches.push(row.table_branch);
row_counts.push(row.row_count);
}
manifest_rows_batch(
object_ids,
object_types,
locations,
metadata,
table_keys,
table_versions,
table_branches,
row_counts,
)
}
async fn merge_rows(&self, dataset: Dataset, rows: Vec<PendingVersionRow>) -> Result<Dataset> {
let batch = Self::pending_rows_to_batch(rows)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], manifest_schema());
let dataset = Arc::new(dataset);
let mut merge_builder = MergeInsertBuilder::try_new(dataset, vec!["object_id".to_string()])
.map_err(|e| OmniError::Lance(e.to_string()))?;
merge_builder.when_matched(WhenMatched::UpdateAll);
merge_builder.when_not_matched(WhenNotMatched::InsertAll);
merge_builder.conflict_retries(5);
merge_builder.use_index(false);
let (new_dataset, _stats) = merge_builder
.try_build()
.map_err(|e| OmniError::Lance(e.to_string()))?
.execute_reader(Box::new(reader))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(Arc::try_unwrap(new_dataset).unwrap_or_else(|arc| (*arc).clone()))
}
pub(super) async fn publish_requests(
&self,
requests: &[CreateTableVersionRequest],
) -> Result<Dataset> {
if requests.is_empty() {
return self.dataset().await;
}
let (dataset, known_tables, existing_versions) = self.load_publish_state().await?;
let rows = Self::build_pending_rows(requests, &known_tables, &existing_versions)?;
self.merge_rows(dataset, rows).await
}
}
#[async_trait]
impl ManifestBatchPublisher for GraphNamespacePublisher {
async fn publish(&self, updates: &[SubTableUpdate]) -> Result<Dataset> {
let requests: Vec<CreateTableVersionRequest> = updates
.iter()
.map(SubTableUpdate::to_create_table_version_request)
.collect();
self.publish_requests(&requests).await
}
}

View file

@ -0,0 +1,133 @@
use std::collections::HashMap;
use arrow_array::{RecordBatch, RecordBatchIterator};
use arrow_schema::SchemaRef;
use lance::Dataset;
use lance::dataset::{WriteMode, WriteParams};
use lance_file::version::LanceFileVersion;
use omnigraph_compiler::catalog::Catalog;
use crate::error::{OmniError, Result};
use super::TABLE_VERSION_MANAGEMENT_KEY;
use super::layout::{manifest_uri, open_manifest_dataset, type_name_hash};
use super::metadata::TableVersionMetadata;
use super::state::{
ManifestState, SubTableEntry, entries_to_batch, manifest_schema, read_manifest_state,
};
pub(super) async fn init_manifest_repo(
root_uri: &str,
catalog: &Catalog,
) -> Result<(Dataset, ManifestState)> {
let root = root_uri.trim_end_matches('/');
let (entries, version_metadata) = build_initial_entries(root, catalog).await?;
let manifest_batch = entries_to_batch(&entries, &version_metadata)?;
let schema = manifest_schema();
let reader = RecordBatchIterator::new(vec![Ok(manifest_batch)], schema);
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
..Default::default()
};
let manifest_path = manifest_uri(root);
let mut dataset = Dataset::write(reader, &manifest_path, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
dataset
.update_config([(TABLE_VERSION_MANAGEMENT_KEY, Some("true"))])
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let known_state = read_manifest_state(&dataset).await?;
Ok((dataset, known_state))
}
pub(super) async fn open_manifest_repo(
root_uri: &str,
branch: Option<&str>,
) -> Result<(Dataset, ManifestState)> {
let dataset = open_manifest_dataset(root_uri.trim_end_matches('/'), branch).await?;
let known_state = read_manifest_state(&dataset).await?;
Ok((dataset, known_state))
}
pub(super) async fn snapshot_state_at(
root_uri: &str,
branch: Option<&str>,
version: u64,
) -> Result<ManifestState> {
let dataset = open_manifest_dataset(root_uri.trim_end_matches('/'), branch).await?;
let dataset = dataset
.checkout_version(version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
read_manifest_state(&dataset).await
}
async fn build_initial_entries(
root_uri: &str,
catalog: &Catalog,
) -> Result<(Vec<SubTableEntry>, HashMap<String, String>)> {
let mut entries = Vec::new();
let mut version_metadata = HashMap::new();
for (name, node_type) in &catalog.node_types {
let hash = type_name_hash(name);
let table_path = format!("nodes/{}", hash);
let full_path = format!("{}/{}", root_uri, table_path);
let ds = create_empty_dataset(&full_path, &node_type.arrow_schema).await?;
let table_key = format!("node:{}", name);
let metadata = TableVersionMetadata::from_dataset(root_uri, &table_path, &ds)?;
entries.push(SubTableEntry {
table_key: table_key.clone(),
table_path: table_path.clone(),
table_version: ds.version().version,
table_branch: None,
row_count: 0,
version_metadata: metadata.clone(),
});
version_metadata.insert(table_key, metadata.to_json_string()?);
}
for (name, edge_type) in &catalog.edge_types {
let hash = type_name_hash(name);
let table_path = format!("edges/{}", hash);
let full_path = format!("{}/{}", root_uri, table_path);
let ds = create_empty_dataset(&full_path, &edge_type.arrow_schema).await?;
let table_key = format!("edge:{}", name);
let metadata = TableVersionMetadata::from_dataset(root_uri, &table_path, &ds)?;
entries.push(SubTableEntry {
table_key: table_key.clone(),
table_path: table_path.clone(),
table_version: ds.version().version,
table_branch: None,
row_count: 0,
version_metadata: metadata.clone(),
});
version_metadata.insert(table_key, metadata.to_json_string()?);
}
Ok((entries, version_metadata))
}
async fn create_empty_dataset(uri: &str, schema: &SchemaRef) -> Result<Dataset> {
let batch = RecordBatch::new_empty(schema.clone());
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
..Default::default()
};
Dataset::write(reader, uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}

View file

@ -0,0 +1,274 @@
use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{Array, RecordBatch, StringArray, UInt64Array, new_null_array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use futures::TryStreamExt;
use lance::Dataset;
use crate::error::{OmniError, Result};
use super::layout::version_object_id;
use super::metadata::TableVersionMetadata;
use super::{OBJECT_TYPE_TABLE, OBJECT_TYPE_TABLE_VERSION};
#[derive(Debug, Clone)]
pub struct SubTableEntry {
pub table_key: String,
pub table_path: String,
pub table_version: u64,
pub table_branch: Option<String>,
pub row_count: u64,
pub(crate) version_metadata: TableVersionMetadata,
}
#[derive(Debug, Clone)]
pub(super) struct ManifestState {
pub(super) version: u64,
pub(super) entries: Vec<SubTableEntry>,
}
pub(super) fn manifest_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("object_id", DataType::Utf8, false),
Field::new("object_type", DataType::Utf8, false),
Field::new("location", DataType::Utf8, true),
Field::new("metadata", DataType::Utf8, true),
Field::new(
"base_objects",
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
true,
),
Field::new("table_key", DataType::Utf8, false),
Field::new("table_version", DataType::UInt64, true),
Field::new("table_branch", DataType::Utf8, true),
Field::new("row_count", DataType::UInt64, true),
]))
}
pub(super) async fn read_manifest_state(dataset: &Dataset) -> Result<ManifestState> {
let version = dataset.version().version;
let entries = read_manifest_entries(dataset).await?;
let mut latest_versions = HashMap::<String, SubTableEntry>::new();
for entry in entries {
match latest_versions.get(&entry.table_key) {
Some(existing) if existing.table_version >= entry.table_version => {}
_ => {
latest_versions.insert(entry.table_key.clone(), entry);
}
}
}
let mut entries: Vec<SubTableEntry> = latest_versions.into_values().collect();
entries.sort_by(|a, b| a.table_key.cmp(&b.table_key));
Ok(ManifestState { version, entries })
}
pub(super) async fn read_manifest_entries(dataset: &Dataset) -> Result<Vec<SubTableEntry>> {
let batches: Vec<RecordBatch> = dataset
.scan()
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut table_locations = HashMap::new();
let mut version_entries = Vec::new();
for batch in &batches {
let object_types = string_column(batch, "object_type")?;
let locations = string_column(batch, "location")?;
let metadata = string_column(batch, "metadata")?;
let table_keys = string_column(batch, "table_key")?;
let versions = u64_column(batch, "table_version")?;
let branches = string_column(batch, "table_branch")?;
let row_counts = u64_column(batch, "row_count")?;
for row in 0..batch.num_rows() {
let table_key = table_keys.value(row).to_string();
match object_types.value(row) {
OBJECT_TYPE_TABLE => {
if locations.is_null(row) {
return Err(OmniError::manifest_internal(format!(
"manifest table row missing location for {}",
table_key
)));
}
table_locations.insert(table_key, locations.value(row).to_string());
}
OBJECT_TYPE_TABLE_VERSION => {
let table_version = required_u64(versions, row, "table_version")?;
let row_count = required_u64(row_counts, row, "row_count")?;
if metadata.is_null(row) {
return Err(OmniError::manifest_internal(format!(
"manifest table_version row missing metadata for {}",
table_key
)));
}
let table_branch = if branches.is_null(row) {
None
} else {
Some(branches.value(row).to_string())
};
version_entries.push(SubTableEntry {
table_key: table_key.clone(),
table_path: String::new(),
table_version,
table_branch,
row_count,
version_metadata: TableVersionMetadata::from_json_str(metadata.value(row))?,
});
}
_ => {}
}
}
}
let mut entries = version_entries
.into_iter()
.map(|mut entry| {
entry.table_path = table_locations
.get(&entry.table_key)
.cloned()
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"manifest missing table row for {}",
entry.table_key
))
})?;
Ok(entry)
})
.collect::<Result<Vec<_>>>()?;
entries.sort_by(|a, b| {
a.table_key
.cmp(&b.table_key)
.then(a.table_version.cmp(&b.table_version))
});
Ok(entries)
}
pub(super) fn entries_to_batch(
entries: &[SubTableEntry],
version_metadata: &HashMap<String, String>,
) -> Result<RecordBatch> {
let mut object_ids = Vec::with_capacity(entries.len() * 2);
let mut object_types = Vec::with_capacity(entries.len() * 2);
let mut locations = Vec::with_capacity(entries.len() * 2);
let mut metadata = Vec::with_capacity(entries.len() * 2);
let mut table_keys = Vec::with_capacity(entries.len() * 2);
let mut table_versions = Vec::with_capacity(entries.len() * 2);
let mut table_branches = Vec::with_capacity(entries.len() * 2);
let mut row_counts = Vec::with_capacity(entries.len() * 2);
for entry in entries {
object_ids.push(entry.table_key.clone());
object_types.push(OBJECT_TYPE_TABLE.to_string());
locations.push(Some(entry.table_path.clone()));
metadata.push(None);
table_keys.push(entry.table_key.clone());
table_versions.push(None);
table_branches.push(None);
row_counts.push(None);
object_ids.push(version_object_id(&entry.table_key, entry.table_version));
object_types.push(OBJECT_TYPE_TABLE_VERSION.to_string());
locations.push(None);
metadata.push(Some(
version_metadata
.get(&entry.table_key)
.cloned()
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"missing initial version metadata for {}",
entry.table_key
))
})?,
));
table_keys.push(entry.table_key.clone());
table_versions.push(Some(entry.table_version));
table_branches.push(entry.table_branch.clone());
row_counts.push(Some(entry.row_count));
}
manifest_rows_batch(
object_ids,
object_types,
locations,
metadata,
table_keys,
table_versions,
table_branches,
row_counts,
)
}
pub(super) fn manifest_rows_batch(
object_ids: Vec<String>,
object_types: Vec<String>,
locations: Vec<Option<String>>,
metadata: Vec<Option<String>>,
table_keys: Vec<String>,
table_versions: Vec<Option<u64>>,
table_branches: Vec<Option<String>>,
row_counts: Vec<Option<u64>>,
) -> Result<RecordBatch> {
let len = object_ids.len();
RecordBatch::try_new(
manifest_schema(),
vec![
Arc::new(StringArray::from(object_ids)),
Arc::new(StringArray::from(object_types)),
Arc::new(StringArray::from(locations)),
Arc::new(StringArray::from(metadata)),
new_null_array(
&DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
len,
),
Arc::new(StringArray::from(table_keys)),
Arc::new(UInt64Array::from(table_versions)),
Arc::new(StringArray::from(table_branches)),
Arc::new(UInt64Array::from(row_counts)),
],
)
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub(super) fn string_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("manifest batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("manifest column '{name}' is not Utf8"))
})
}
fn u64_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a UInt64Array> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("manifest batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("manifest column '{name}' is not UInt64"))
})
}
fn required_u64(column: &UInt64Array, row: usize, name: &str) -> Result<u64> {
if column.is_null(row) {
return Err(OmniError::manifest_internal(format!(
"manifest column '{name}' is null at row {row}"
)));
}
Ok(column.value(row))
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,13 @@
pub mod commit_graph;
pub mod graph_coordinator;
pub mod manifest;
mod omnigraph;
mod run_registry;
mod schema_state;
pub use commit_graph::GraphCommit;
pub use graph_coordinator::{GraphCoordinator, ReadTarget, ResolvedTarget, SnapshotId};
pub use manifest::{Snapshot, SubTableEntry, SubTableUpdate};
pub use omnigraph::{MergeOutcome, Omnigraph};
pub(crate) use run_registry::is_internal_run_branch;
pub use run_registry::{RunId, RunRecord, RunStatus};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,622 @@
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use arrow_array::{
Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::{WriteMode, WriteParams};
use lance_file::version::LanceFileVersion;
use crate::error::{OmniError, Result};
const GRAPH_RUNS_DIR: &str = "_graph_runs.lance";
const GRAPH_RUN_ACTORS_DIR: &str = "_graph_run_actors.lance";
pub(crate) const INTERNAL_RUN_BRANCH_PREFIX: &str = "__run__";
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RunId(String);
impl RunId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl fmt::Display for RunId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunStatus {
Running,
Published,
Failed,
Aborted,
}
impl RunStatus {
pub fn as_str(self) -> &'static str {
match self {
RunStatus::Running => "running",
RunStatus::Published => "published",
RunStatus::Failed => "failed",
RunStatus::Aborted => "aborted",
}
}
fn parse(value: &str) -> Result<Self> {
match value {
"running" => Ok(Self::Running),
"published" => Ok(Self::Published),
"failed" => Ok(Self::Failed),
"aborted" => Ok(Self::Aborted),
other => Err(OmniError::manifest(format!(
"invalid run status '{}'",
other
))),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunRecord {
pub run_id: RunId,
pub target_branch: String,
pub run_branch: String,
pub base_snapshot_id: String,
pub base_manifest_version: u64,
pub operation_hash: Option<String>,
pub actor_id: Option<String>,
pub status: RunStatus,
pub published_snapshot_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
impl RunRecord {
pub fn new(
target_branch: impl Into<String>,
base_snapshot_id: impl Into<String>,
base_manifest_version: u64,
operation_hash: Option<String>,
actor_id: Option<String>,
) -> Result<Self> {
let now = now_micros()?;
let run_id = RunId::new(ulid::Ulid::new().to_string());
Ok(Self {
run_branch: internal_run_branch_name(&run_id),
run_id,
target_branch: target_branch.into(),
base_snapshot_id: base_snapshot_id.into(),
base_manifest_version,
operation_hash,
actor_id,
status: RunStatus::Running,
published_snapshot_id: None,
created_at: now,
updated_at: now,
})
}
pub fn with_status(
&self,
status: RunStatus,
published_snapshot_id: Option<String>,
) -> Result<Self> {
Ok(Self {
run_id: self.run_id.clone(),
target_branch: self.target_branch.clone(),
run_branch: self.run_branch.clone(),
base_snapshot_id: self.base_snapshot_id.clone(),
base_manifest_version: self.base_manifest_version,
operation_hash: self.operation_hash.clone(),
actor_id: self.actor_id.clone(),
status,
published_snapshot_id,
created_at: self.created_at,
updated_at: now_micros()?,
})
}
}
pub struct RunRegistry {
dataset: Dataset,
actor_dataset: Option<Dataset>,
latest_by_id: HashMap<String, RunRecord>,
actor_by_run_id: HashMap<String, String>,
root_uri: String,
}
impl RunRegistry {
pub async fn init(root_uri: &str) -> Result<Self> {
let uri = graph_runs_uri(root_uri);
let batch = RecordBatch::new_empty(run_registry_schema());
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
..Default::default()
};
let dataset = Dataset::write(reader, &uri as &str, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let actor_dataset = create_run_actor_dataset(root_uri).await?;
Ok(Self {
dataset,
actor_dataset: Some(actor_dataset),
latest_by_id: HashMap::new(),
actor_by_run_id: HashMap::new(),
root_uri: root_uri.to_string(),
})
}
pub async fn open(root_uri: &str) -> Result<Self> {
let dataset = Dataset::open(&graph_runs_uri(root_uri))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
let actor_by_run_id = match &actor_dataset {
Some(dataset) => load_run_actor_cache(dataset).await?,
None => HashMap::new(),
};
let latest_by_id = load_run_cache(&dataset, &actor_by_run_id).await?;
Ok(Self {
dataset,
actor_dataset,
latest_by_id,
actor_by_run_id,
root_uri: root_uri.to_string(),
})
}
pub async fn refresh(&mut self, root_uri: &str) -> Result<()> {
self.dataset = Dataset::open(&graph_runs_uri(root_uri))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
self.actor_by_run_id = match &self.actor_dataset {
Some(dataset) => load_run_actor_cache(dataset).await?,
None => HashMap::new(),
};
self.latest_by_id = load_run_cache(&self.dataset, &self.actor_by_run_id).await?;
self.root_uri = root_uri.to_string();
Ok(())
}
pub async fn append_record(&mut self, record: &RunRecord) -> Result<()> {
let batch = runs_to_batch(&[record.clone()])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
let mut ds = self.dataset.clone();
ds.append(reader, None)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.dataset = ds;
if let Some(actor_id) = &record.actor_id {
self.append_actor(record.run_id.as_str(), actor_id).await?;
}
let mut record = record.clone();
if record.actor_id.is_none() {
record.actor_id = self.actor_by_run_id.get(record.run_id.as_str()).cloned();
}
merge_latest_run(&mut self.latest_by_id, record);
Ok(())
}
pub async fn get_run(&self, run_id: &RunId) -> Result<Option<RunRecord>> {
Ok(self.latest_by_id.get(run_id.as_str()).cloned())
}
pub async fn list_runs(&self) -> Result<Vec<RunRecord>> {
self.load_runs().await
}
pub async fn load_runs(&self) -> Result<Vec<RunRecord>> {
let mut runs = self.latest_by_id.values().cloned().collect::<Vec<_>>();
runs.sort_by(|a, b| {
a.created_at
.cmp(&b.created_at)
.then_with(|| a.run_id.as_str().cmp(b.run_id.as_str()))
});
Ok(runs)
}
async fn append_actor(&mut self, run_id: &str, actor_id: &str) -> Result<()> {
if self
.actor_by_run_id
.get(run_id)
.is_some_and(|existing| existing == actor_id)
{
return Ok(());
}
let record = RunActorRecord {
run_id: run_id.to_string(),
actor_id: actor_id.to_string(),
created_at: now_micros()?,
};
let batch = run_actors_to_batch(&[record])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
let mut dataset = match self.actor_dataset.take() {
Some(dataset) => dataset,
None => create_run_actor_dataset(&self.root_uri).await?,
};
dataset
.append(reader, None)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.actor_by_run_id
.insert(run_id.to_string(), actor_id.to_string());
self.actor_dataset = Some(dataset);
Ok(())
}
}
pub(crate) fn is_internal_run_branch(name: &str) -> bool {
name.trim_start_matches('/')
.starts_with(INTERNAL_RUN_BRANCH_PREFIX)
}
pub(crate) fn internal_run_branch_name(run_id: &RunId) -> String {
format!("{}{}", INTERNAL_RUN_BRANCH_PREFIX, run_id.as_str())
}
pub(crate) fn graph_runs_uri(root_uri: &str) -> String {
format!("{}/{}", root_uri.trim_end_matches('/'), GRAPH_RUNS_DIR)
}
fn graph_run_actors_uri(root_uri: &str) -> String {
format!(
"{}/{}",
root_uri.trim_end_matches('/'),
GRAPH_RUN_ACTORS_DIR
)
}
fn run_registry_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("run_id", DataType::Utf8, false),
Field::new("target_branch", DataType::Utf8, false),
Field::new("run_branch", DataType::Utf8, false),
Field::new("base_snapshot_id", DataType::Utf8, false),
Field::new("base_manifest_version", DataType::UInt64, false),
Field::new("operation_hash", DataType::Utf8, true),
Field::new("status", DataType::Utf8, false),
Field::new("published_snapshot_id", DataType::Utf8, true),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
Field::new(
"updated_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
]))
}
fn run_actor_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("run_id", DataType::Utf8, false),
Field::new("actor_id", DataType::Utf8, false),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
]))
}
async fn create_run_actor_dataset(root_uri: &str) -> Result<Dataset> {
let batch = RecordBatch::new_empty(run_actor_schema());
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
..Default::default()
};
Dataset::write(
reader,
&graph_run_actors_uri(root_uri) as &str,
Some(params),
)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
async fn load_run_cache(
dataset: &Dataset,
actor_by_run_id: &HashMap<String, String>,
) -> Result<HashMap<String, RunRecord>> {
let batches: Vec<RecordBatch> = dataset
.scan()
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut latest_by_id = HashMap::new();
for mut record in load_runs_from_batches(&batches)? {
record.actor_id = actor_by_run_id.get(record.run_id.as_str()).cloned();
merge_latest_run(&mut latest_by_id, record);
}
Ok(latest_by_id)
}
async fn load_run_actor_cache(dataset: &Dataset) -> Result<HashMap<String, String>> {
let batches: Vec<RecordBatch> = dataset
.scan()
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
let mut actors = HashMap::new();
for batch in batches {
let run_ids = string_column(&batch, "run_id", "run actor registry")?;
let actor_ids = string_column(&batch, "actor_id", "run actor registry")?;
for row in 0..batch.num_rows() {
actors.insert(
run_ids.value(row).to_string(),
actor_ids.value(row).to_string(),
);
}
}
Ok(actors)
}
fn load_runs_from_batches(batches: &[RecordBatch]) -> Result<Vec<RunRecord>> {
let mut runs = Vec::new();
for batch in batches {
let run_ids = string_column(batch, "run_id", "run registry")?;
let target_branches = string_column(batch, "target_branch", "run registry")?;
let run_branches = string_column(batch, "run_branch", "run registry")?;
let base_snapshot_ids = string_column(batch, "base_snapshot_id", "run registry")?;
let base_manifest_versions = u64_column(batch, "base_manifest_version", "run registry")?;
let operation_hashes = string_column(batch, "operation_hash", "run registry")?;
let statuses = string_column(batch, "status", "run registry")?;
let published_snapshot_ids = string_column(batch, "published_snapshot_id", "run registry")?;
let created_ats = timestamp_micros_column(batch, "created_at", "run registry")?;
let updated_ats = timestamp_micros_column(batch, "updated_at", "run registry")?;
for row in 0..batch.num_rows() {
runs.push(RunRecord {
run_id: RunId::new(run_ids.value(row)),
target_branch: target_branches.value(row).to_string(),
run_branch: run_branches.value(row).to_string(),
base_snapshot_id: base_snapshot_ids.value(row).to_string(),
base_manifest_version: base_manifest_versions.value(row),
operation_hash: if operation_hashes.is_null(row) {
None
} else {
Some(operation_hashes.value(row).to_string())
},
actor_id: None,
status: RunStatus::parse(statuses.value(row))?,
published_snapshot_id: if published_snapshot_ids.is_null(row) {
None
} else {
Some(published_snapshot_ids.value(row).to_string())
},
created_at: created_ats.value(row),
updated_at: updated_ats.value(row),
});
}
}
Ok(runs)
}
fn merge_latest_run(latest_by_id: &mut HashMap<String, RunRecord>, record: RunRecord) {
match latest_by_id.get(record.run_id.as_str()) {
Some(existing)
if existing.updated_at > record.updated_at
|| (existing.updated_at == record.updated_at
&& existing.created_at >= record.created_at) => {}
_ => {
latest_by_id.insert(record.run_id.as_str().to_string(), record);
}
}
}
fn string_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a StringArray> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} column '{name}' is not Utf8"))
})
}
fn u64_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a UInt64Array> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} column '{name}' is not UInt64"))
})
}
fn timestamp_micros_column<'a>(
batch: &'a RecordBatch,
name: &str,
context: &str,
) -> Result<&'a TimestampMicrosecondArray> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"{context} column '{name}' is not Timestamp(Microsecond)"
))
})
}
fn runs_to_batch(records: &[RunRecord]) -> Result<RecordBatch> {
let run_ids: Vec<&str> = records
.iter()
.map(|record| record.run_id.as_str())
.collect();
let target_branches: Vec<&str> = records
.iter()
.map(|record| record.target_branch.as_str())
.collect();
let run_branches: Vec<&str> = records
.iter()
.map(|record| record.run_branch.as_str())
.collect();
let base_snapshot_ids: Vec<&str> = records
.iter()
.map(|record| record.base_snapshot_id.as_str())
.collect();
let base_manifest_versions: Vec<u64> = records
.iter()
.map(|record| record.base_manifest_version)
.collect();
let operation_hashes: Vec<Option<&str>> = records
.iter()
.map(|record| record.operation_hash.as_deref())
.collect();
let statuses: Vec<&str> = records
.iter()
.map(|record| record.status.as_str())
.collect();
let published_snapshot_ids: Vec<Option<&str>> = records
.iter()
.map(|record| record.published_snapshot_id.as_deref())
.collect();
let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
let updated_ats: Vec<i64> = records.iter().map(|record| record.updated_at).collect();
RecordBatch::try_new(
run_registry_schema(),
vec![
Arc::new(StringArray::from(run_ids)),
Arc::new(StringArray::from(target_branches)),
Arc::new(StringArray::from(run_branches)),
Arc::new(StringArray::from(base_snapshot_ids)),
Arc::new(UInt64Array::from(base_manifest_versions)),
Arc::new(StringArray::from(operation_hashes)),
Arc::new(StringArray::from(statuses)),
Arc::new(StringArray::from(published_snapshot_ids)),
Arc::new(TimestampMicrosecondArray::from(created_ats)),
Arc::new(TimestampMicrosecondArray::from(updated_ats)),
],
)
.map_err(|e| OmniError::Lance(e.to_string()))
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct RunActorRecord {
run_id: String,
actor_id: String,
created_at: i64,
}
fn run_actors_to_batch(records: &[RunActorRecord]) -> Result<RecordBatch> {
let run_ids: Vec<&str> = records
.iter()
.map(|record| record.run_id.as_str())
.collect();
let actor_ids: Vec<&str> = records
.iter()
.map(|record| record.actor_id.as_str())
.collect();
let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
RecordBatch::try_new(
run_actor_schema(),
vec![
Arc::new(StringArray::from(run_ids)),
Arc::new(StringArray::from(actor_ids)),
Arc::new(TimestampMicrosecondArray::from(created_ats)),
],
)
.map_err(|e| OmniError::Lance(e.to_string()))
}
fn now_micros() -> Result<i64> {
let duration = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| OmniError::manifest(format!("system clock error: {}", e)))?;
Ok(duration.as_micros() as i64)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use super::*;
#[test]
fn load_runs_from_batches_returns_error_for_bad_schema() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("run_id", DataType::UInt64, false),
Field::new("target_branch", DataType::Utf8, false),
Field::new("run_branch", DataType::Utf8, false),
Field::new("base_snapshot_id", DataType::Utf8, false),
Field::new("base_manifest_version", DataType::UInt64, false),
Field::new("operation_hash", DataType::Utf8, true),
Field::new("status", DataType::Utf8, false),
Field::new("published_snapshot_id", DataType::Utf8, true),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
Field::new(
"updated_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
])),
vec![
Arc::new(UInt64Array::from(vec![1_u64])),
Arc::new(StringArray::from(vec!["main"])),
Arc::new(StringArray::from(vec!["__run__1"])),
Arc::new(StringArray::from(vec!["snap-1"])),
Arc::new(UInt64Array::from(vec![1_u64])),
Arc::new(StringArray::from(vec![None::<&str>])),
Arc::new(StringArray::from(vec!["running"])),
Arc::new(StringArray::from(vec![None::<&str>])),
Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
],
)
.unwrap();
let err = load_runs_from_batches(&[batch]).unwrap_err();
assert!(err.to_string().contains("run_id"));
}
}

View file

@ -0,0 +1,236 @@
use std::sync::Arc;
use omnigraph_compiler::schema::parser::parse_schema;
use omnigraph_compiler::{SchemaIR, build_schema_ir, schema_ir_hash, schema_ir_pretty_json};
use serde::{Deserialize, Serialize};
use crate::error::{OmniError, Result};
use crate::storage::{StorageAdapter, join_uri};
pub(crate) const SCHEMA_SOURCE_FILENAME: &str = "_schema.pg";
pub(crate) const SCHEMA_IR_FILENAME: &str = "_schema.ir.json";
pub(crate) const SCHEMA_STATE_FILENAME: &str = "__schema_state.json";
const SCHEMA_STATE_FORMAT_VERSION: u32 = 1;
const SCHEMA_IDENTITY_VERSION: u32 = 1;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct SchemaState {
pub(crate) format_version: u32,
pub(crate) schema_ir_hash: String,
pub(crate) schema_identity_version: u32,
}
impl SchemaState {
pub(crate) fn new(schema_ir_hash: String) -> Self {
Self {
format_version: SCHEMA_STATE_FORMAT_VERSION,
schema_ir_hash,
schema_identity_version: SCHEMA_IDENTITY_VERSION,
}
}
}
pub(crate) async fn load_or_bootstrap_schema_contract(
root_uri: &str,
storage: Arc<dyn StorageAdapter>,
public_branches: &[String],
current_source_ir: &SchemaIR,
) -> Result<(SchemaIR, SchemaState)> {
match read_schema_contract(root_uri, storage.as_ref()).await? {
SchemaContractRead::Present { ir, state } => {
validate_persisted_schema_contract(&ir, &state)?;
validate_current_source_matches(&state, current_source_ir)?;
Ok((ir, state))
}
SchemaContractRead::MissingAll => {
let public_non_main = public_branches
.iter()
.filter(|branch| branch.as_str() != "main")
.cloned()
.collect::<Vec<_>>();
if !public_non_main.is_empty() {
return Err(schema_lock_conflict(format!(
"repo is missing persisted schema state and has public branches ({}); public branches block schema evolution entirely",
public_non_main.join(", ")
)));
}
let state =
write_schema_contract(root_uri, storage.as_ref(), current_source_ir).await?;
Ok((current_source_ir.clone(), state))
}
SchemaContractRead::PartialMissing => Err(schema_lock_conflict(
"repo schema state is incomplete (_schema.ir.json and __schema_state.json must either both exist or both be absent)",
)),
}
}
pub(crate) async fn validate_schema_contract(
root_uri: &str,
storage: Arc<dyn StorageAdapter>,
) -> Result<()> {
let current_source_ir = read_current_source_ir(root_uri, storage.as_ref()).await?;
let (persisted_ir, state) = match read_schema_contract(root_uri, storage.as_ref()).await? {
SchemaContractRead::Present { ir, state } => (ir, state),
SchemaContractRead::MissingAll | SchemaContractRead::PartialMissing => {
return Err(schema_lock_conflict(
"repo is missing persisted schema state; manual coordination is required before schema changes are allowed",
));
}
};
validate_persisted_schema_contract(&persisted_ir, &state)?;
validate_current_source_matches(&state, &current_source_ir)
}
pub(crate) async fn write_schema_contract(
root_uri: &str,
storage: &dyn StorageAdapter,
schema_ir: &SchemaIR,
) -> Result<SchemaState> {
let ir_json = schema_ir_pretty_json(schema_ir)
.map_err(|err| OmniError::manifest_internal(err.to_string()))?;
let state = SchemaState::new(
schema_ir_hash(schema_ir).map_err(|err| OmniError::manifest_internal(err.to_string()))?,
);
let state_json = serde_json::to_string_pretty(&state).map_err(|err| {
OmniError::manifest_internal(format!("serialize schema state error: {}", err))
})?;
storage
.write_text(&schema_ir_uri(root_uri), &ir_json)
.await?;
storage
.write_text(&schema_state_uri(root_uri), &state_json)
.await?;
Ok(state)
}
pub(crate) async fn read_current_source_ir(
root_uri: &str,
storage: &dyn StorageAdapter,
) -> Result<SchemaIR> {
let source = storage.read_text(&schema_source_uri(root_uri)).await?;
compile_schema_source(&source)
}
pub(crate) async fn read_accepted_schema_ir(
root_uri: &str,
storage: Arc<dyn StorageAdapter>,
) -> Result<SchemaIR> {
match read_schema_contract(root_uri, storage.as_ref()).await? {
SchemaContractRead::Present { ir, state } => {
validate_persisted_schema_contract(&ir, &state)?;
Ok(ir)
}
SchemaContractRead::MissingAll | SchemaContractRead::PartialMissing => {
Err(schema_lock_conflict(
"repo is missing persisted schema state; manual coordination is required before schema changes are allowed",
))
}
}
}
pub(crate) fn schema_source_uri(root_uri: &str) -> String {
join_uri(root_uri, SCHEMA_SOURCE_FILENAME)
}
pub(crate) fn schema_ir_uri(root_uri: &str) -> String {
join_uri(root_uri, SCHEMA_IR_FILENAME)
}
pub(crate) fn schema_state_uri(root_uri: &str) -> String {
join_uri(root_uri, SCHEMA_STATE_FILENAME)
}
enum SchemaContractRead {
Present { ir: SchemaIR, state: SchemaState },
MissingAll,
PartialMissing,
}
async fn read_schema_contract(
root_uri: &str,
storage: &dyn StorageAdapter,
) -> Result<SchemaContractRead> {
let ir_uri = schema_ir_uri(root_uri);
let state_uri = schema_state_uri(root_uri);
let ir_exists = storage.exists(&ir_uri).await?;
let state_exists = storage.exists(&state_uri).await?;
match (ir_exists, state_exists) {
(false, false) => Ok(SchemaContractRead::MissingAll),
(true, true) => {
let ir_json = storage.read_text(&ir_uri).await?;
let state_json = storage.read_text(&state_uri).await?;
let ir = serde_json::from_str::<SchemaIR>(&ir_json).map_err(|err| {
schema_lock_conflict(format!(
"accepted compiled schema contract in {} is invalid: {}",
SCHEMA_IR_FILENAME, err
))
})?;
let state = serde_json::from_str::<SchemaState>(&state_json).map_err(|err| {
schema_lock_conflict(format!(
"repo schema state in {} is invalid: {}",
SCHEMA_STATE_FILENAME, err
))
})?;
Ok(SchemaContractRead::Present { ir, state })
}
_ => Ok(SchemaContractRead::PartialMissing),
}
}
fn validate_persisted_schema_contract(ir: &SchemaIR, state: &SchemaState) -> Result<()> {
if state.format_version != SCHEMA_STATE_FORMAT_VERSION {
return Err(schema_lock_conflict(format!(
"repo schema state format {} is unsupported",
state.format_version
)));
}
let actual_hash = schema_ir_hash(ir).map_err(|err| schema_lock_conflict(err.to_string()))?;
if actual_hash != state.schema_ir_hash {
return Err(schema_lock_conflict(
"accepted compiled schema does not match the recorded schema state",
));
}
Ok(())
}
fn validate_current_source_matches(
state: &SchemaState,
current_source_ir: &SchemaIR,
) -> Result<()> {
let current_hash =
schema_ir_hash(current_source_ir).map_err(|err| schema_lock_conflict(err.to_string()))?;
if current_hash != state.schema_ir_hash {
return Err(schema_lock_conflict(
"current _schema.pg no longer matches the accepted compiled schema",
));
}
Ok(())
}
fn compile_schema_source(source: &str) -> Result<SchemaIR> {
let schema = parse_schema(source).map_err(|err| {
schema_lock_conflict(format!(
"current _schema.pg is not a valid accepted schema definition: {}",
err
))
})?;
build_schema_ir(&schema).map_err(|err| {
schema_lock_conflict(format!(
"current _schema.pg could not be compiled into the accepted schema contract: {}",
err
))
})
}
fn schema_lock_conflict(detail: impl Into<String>) -> OmniError {
OmniError::manifest_conflict(format!(
"schema evolution is locked down in phase 1: {}; manual coordination is required",
detail.into()
))
}

View file

@ -0,0 +1,489 @@
use std::future::Future;
use std::time::Duration;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::time::sleep;
use crate::error::{OmniError, Result};
const GEMINI_EMBED_MODEL: &str = "gemini-embedding-2-preview";
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
const QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
const DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";
#[derive(Clone, Debug)]
enum EmbeddingTransport {
Mock,
Gemini {
api_key: String,
base_url: String,
http: Client,
},
}
#[derive(Clone, Debug)]
pub struct EmbeddingClient {
retry_attempts: usize,
retry_backoff_ms: u64,
transport: EmbeddingTransport,
}
struct EmbedCallError {
message: String,
retryable: bool,
}
#[derive(Debug, Deserialize)]
struct GeminiEmbedResponse {
embedding: GeminiContentEmbedding,
}
#[derive(Debug, Deserialize)]
struct GeminiContentEmbedding {
values: Vec<f32>,
}
#[derive(Debug, Deserialize)]
struct GoogleErrorEnvelope {
error: GoogleErrorBody,
}
#[derive(Debug, Deserialize)]
struct GoogleErrorBody {
message: String,
}
impl EmbeddingClient {
pub fn from_env() -> Result<Self> {
let retry_attempts =
parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
let retry_backoff_ms =
parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
return Ok(Self {
retry_attempts,
retry_backoff_ms,
transport: EmbeddingTransport::Mock,
});
}
let api_key = std::env::var("GEMINI_API_KEY")
.ok()
.map(|v| v.trim().to_string())
.filter(|v| !v.is_empty())
.ok_or_else(|| {
OmniError::manifest_internal(
"GEMINI_API_KEY is required when nearest() needs a string embedding",
)
})?;
let base_url = std::env::var("OMNIGRAPH_GEMINI_BASE_URL")
.ok()
.map(|v| v.trim_end_matches('/').to_string())
.filter(|v| !v.is_empty())
.unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.to_string());
let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
let http = Client::builder()
.timeout(Duration::from_millis(timeout_ms))
.build()
.map_err(|e| {
OmniError::manifest_internal(format!("failed to initialize HTTP client: {}", e))
})?;
Ok(Self {
retry_attempts,
retry_backoff_ms,
transport: EmbeddingTransport::Gemini {
api_key,
base_url,
http,
},
})
}
#[cfg(test)]
fn mock_for_tests() -> Self {
Self {
retry_attempts: DEFAULT_RETRY_ATTEMPTS,
retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
transport: EmbeddingTransport::Mock,
}
}
pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
self.embed_text(input, expected_dim, QUERY_TASK_TYPE).await
}
pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
self.embed_text(input, expected_dim, DOCUMENT_TASK_TYPE)
.await
}
async fn embed_text(
&self,
input: &str,
expected_dim: usize,
task_type: &'static str,
) -> Result<Vec<f32>> {
if expected_dim == 0 {
return Err(OmniError::manifest_internal(
"embedding dimension must be greater than zero",
));
}
match &self.transport {
EmbeddingTransport::Mock => Ok(mock_embedding(input, expected_dim)),
EmbeddingTransport::Gemini { .. } => {
self.with_retry(|| self.embed_text_gemini_once(input, expected_dim, task_type))
.await
}
}
}
async fn with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = std::result::Result<T, EmbedCallError>>,
{
let max_attempt = self.retry_attempts.max(1);
let mut attempt = 0usize;
loop {
attempt += 1;
match operation().await {
Ok(value) => return Ok(value),
Err(err) => {
if !err.retryable || attempt >= max_attempt {
return Err(OmniError::manifest_internal(err.message));
}
let shift = (attempt - 1).min(10) as u32;
let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
sleep(Duration::from_millis(delay)).await;
}
}
}
}
async fn embed_text_gemini_once(
&self,
input: &str,
expected_dim: usize,
task_type: &'static str,
) -> std::result::Result<Vec<f32>, EmbedCallError> {
let (api_key, base_url, http) = match &self.transport {
EmbeddingTransport::Gemini {
api_key,
base_url,
http,
} => (api_key, base_url, http),
EmbeddingTransport::Mock => unreachable!("mock transport should not call Gemini"),
};
let response = http
.post(gemini_endpoint(base_url))
.header("x-goog-api-key", api_key)
.json(&build_gemini_request(input, expected_dim, task_type))
.send()
.await;
let response = match response {
Ok(response) => response,
Err(err) => {
let retryable = err.is_timeout() || err.is_connect() || err.is_request();
return Err(EmbedCallError {
message: format!("embedding request failed: {}", err),
retryable,
});
}
};
let status = response.status();
let body = match response.text().await {
Ok(body) => body,
Err(err) => {
return Err(EmbedCallError {
message: format!(
"embedding response read failed (status {}): {}",
status, err
),
retryable: status.is_server_error() || status.as_u16() == 429,
});
}
};
if !status.is_success() {
let message = parse_google_error_message(&body).unwrap_or(body);
return Err(EmbedCallError {
message: format!(
"embedding request failed with status {}: {}",
status, message
),
retryable: status.is_server_error() || status.as_u16() == 429,
});
}
let parsed: GeminiEmbedResponse =
serde_json::from_str(&body).map_err(|err| EmbedCallError {
message: format!("embedding response decode failed: {}", err),
retryable: false,
})?;
validate_and_normalize_embedding(parsed.embedding.values, expected_dim).map_err(|message| {
EmbedCallError {
message,
retryable: false,
}
})
}
}
fn gemini_endpoint(base_url: &str) -> String {
format!(
"{}/models/{}:embedContent",
base_url.trim_end_matches('/'),
GEMINI_EMBED_MODEL
)
}
fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static str) -> Value {
json!({
"model": format!("models/{}", GEMINI_EMBED_MODEL),
"content": {
"parts": [
{
"text": input
}
]
},
"taskType": task_type,
"outputDimensionality": expected_dim,
})
}
fn validate_and_normalize_embedding(
values: Vec<f32>,
expected_dim: usize,
) -> std::result::Result<Vec<f32>, String> {
if values.len() != expected_dim {
return Err(format!(
"embedding dimension mismatch: expected {}, got {}",
expected_dim,
values.len()
));
}
Ok(normalize_vector(values))
}
fn normalize_vector(mut values: Vec<f32>) -> Vec<f32> {
let norm = values
.iter()
.map(|v| (*v as f64) * (*v as f64))
.sum::<f64>()
.sqrt() as f32;
if norm > f32::EPSILON {
for value in &mut values {
*value /= norm;
}
}
values
}
fn parse_google_error_message(body: &str) -> Option<String> {
serde_json::from_str::<GoogleErrorEnvelope>(body)
.ok()
.map(|e| e.error.message)
.filter(|msg| !msg.trim().is_empty())
}
fn parse_env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|v| *v > 0)
.unwrap_or(default)
}
fn parse_env_u64(name: &str, default: u64) -> u64 {
std::env::var(name)
.ok()
.and_then(|v| v.parse::<u64>().ok())
.filter(|v| *v > 0)
.unwrap_or(default)
}
fn env_flag(name: &str) -> bool {
std::env::var(name)
.ok()
.map(|v| {
let s = v.trim().to_ascii_lowercase();
s == "1" || s == "true" || s == "yes" || s == "on"
})
.unwrap_or(false)
}
fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
let mut seed = fnv1a64(input.as_bytes());
let mut out = Vec::with_capacity(dim);
for _ in 0..dim {
seed = xorshift64(seed);
let ratio = (seed as f64 / u64::MAX as f64) as f32;
out.push((ratio * 2.0) - 1.0);
}
normalize_vector(out)
}
fn fnv1a64(bytes: &[u8]) -> u64 {
let mut hash = 14695981039346656037u64;
for byte in bytes {
hash ^= *byte as u64;
hash = hash.wrapping_mul(1099511628211u64);
}
hash
}
fn xorshift64(mut x: u64) -> u64 {
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
x
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use serial_test::serial;
use super::*;
struct EnvGuard {
saved: Vec<(&'static str, Option<String>)>,
}
impl EnvGuard {
fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
let saved = vars
.iter()
.map(|(name, _)| (*name, std::env::var(name).ok()))
.collect::<Vec<_>>();
for (name, value) in vars {
unsafe {
match value {
Some(value) => std::env::set_var(name, value),
None => std::env::remove_var(name),
}
}
}
Self { saved }
}
}
impl Drop for EnvGuard {
fn drop(&mut self) {
for (name, value) in self.saved.drain(..) {
unsafe {
match value {
Some(value) => std::env::set_var(name, value),
None => std::env::remove_var(name),
}
}
}
}
}
#[tokio::test]
async fn mock_embeddings_are_deterministic() {
let client = EmbeddingClient::mock_for_tests();
let a = client.embed_query_text("alpha", 8).await.unwrap();
let b = client.embed_query_text("alpha", 8).await.unwrap();
let c = client.embed_query_text("beta", 8).await.unwrap();
assert_eq!(a, b);
assert_ne!(a, c);
assert_eq!(a.len(), 8);
}
#[test]
fn gemini_request_uses_preview_model_retrieval_query_and_dimension() {
let request = build_gemini_request("alpha", 4, QUERY_TASK_TYPE);
assert_eq!(request["model"], "models/gemini-embedding-2-preview");
assert_eq!(request["taskType"], QUERY_TASK_TYPE);
assert_eq!(request["outputDimensionality"], 4);
assert_eq!(request["content"]["parts"][0]["text"], "alpha");
}
#[test]
fn gemini_document_request_uses_retrieval_document_task_type() {
let request = build_gemini_request("alpha", 4, DOCUMENT_TASK_TYPE);
assert_eq!(request["taskType"], DOCUMENT_TASK_TYPE);
}
#[test]
fn validate_and_normalize_embedding_enforces_dimension() {
let normalized = validate_and_normalize_embedding(vec![3.0, 4.0], 2).unwrap();
assert!((normalized[0] - 0.6).abs() < 1e-6);
assert!((normalized[1] - 0.8).abs() < 1e-6);
let err = validate_and_normalize_embedding(vec![1.0, 2.0], 3).unwrap_err();
assert!(err.contains("expected 3, got 2"));
}
#[tokio::test]
async fn with_retry_retries_retryable_failures() {
let client = EmbeddingClient::mock_for_tests();
let attempts = Arc::new(AtomicUsize::new(0));
let attempts_for_call = Arc::clone(&attempts);
let value = client
.with_retry(|| {
let attempts_for_call = Arc::clone(&attempts_for_call);
async move {
let attempt = attempts_for_call.fetch_add(1, Ordering::SeqCst);
if attempt == 0 {
Err(EmbedCallError {
message: "retry me".to_string(),
retryable: true,
})
} else {
Ok("ok")
}
}
})
.await
.unwrap();
assert_eq!(value, "ok");
assert_eq!(attempts.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn with_retry_stops_on_non_retryable_failures() {
let client = EmbeddingClient::mock_for_tests();
let err = client
.with_retry(|| async {
Err::<(), _>(EmbedCallError {
message: "do not retry".to_string(),
retryable: false,
})
})
.await
.unwrap_err();
assert!(err.to_string().contains("do not retry"));
}
#[test]
#[serial]
fn from_env_requires_gemini_api_key_when_not_mocking() {
let _guard = EnvGuard::set(&[
("OMNIGRAPH_EMBEDDINGS_MOCK", None),
("GEMINI_API_KEY", None),
]);
let err = EmbeddingClient::from_env().unwrap_err();
assert!(err.to_string().contains("GEMINI_API_KEY"));
}
}

View file

@ -0,0 +1,80 @@
use thiserror::Error;
pub type Result<T> = std::result::Result<T, OmniError>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ManifestErrorKind {
BadRequest,
NotFound,
Conflict,
Internal,
}
#[derive(Debug, Clone, Error)]
#[error("{message}")]
pub struct ManifestError {
pub kind: ManifestErrorKind,
pub message: String,
}
impl ManifestError {
pub fn new(kind: ManifestErrorKind, message: impl Into<String>) -> Self {
Self {
kind,
message: message.into(),
}
}
}
#[derive(Debug, Clone)]
pub struct MergeConflict {
pub table_key: String,
pub row_id: Option<String>,
pub kind: MergeConflictKind,
pub message: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeConflictKind {
DivergentInsert,
DivergentUpdate,
DeleteVsUpdate,
OrphanEdge,
UniqueViolation,
CardinalityViolation,
ValueConstraintViolation,
}
#[derive(Debug, Error)]
pub enum OmniError {
#[error("{0}")]
Compiler(#[from] omnigraph_compiler::error::NanoError),
#[error("storage: {0}")]
Lance(String),
#[error("query: {0}")]
DataFusion(String),
#[error("io: {0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
Manifest(ManifestError),
#[error("merge conflicts: {0:?}")]
MergeConflicts(Vec<MergeConflict>),
}
impl OmniError {
pub fn manifest(message: impl Into<String>) -> Self {
Self::Manifest(ManifestError::new(ManifestErrorKind::BadRequest, message))
}
pub fn manifest_not_found(message: impl Into<String>) -> Self {
Self::Manifest(ManifestError::new(ManifestErrorKind::NotFound, message))
}
pub fn manifest_conflict(message: impl Into<String>) -> Self {
Self::Manifest(ManifestError::new(ManifestErrorKind::Conflict, message))
}
pub fn manifest_internal(message: impl Into<String>) -> Self {
Self::Manifest(ManifestError::new(ManifestErrorKind::Internal, message))
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,37 @@
use crate::error::Result;
pub(crate) fn maybe_fail(_name: &str) -> Result<()> {
#[cfg(feature = "failpoints")]
{
let name = _name;
fail::fail_point!(name, |_| {
return Err(crate::error::OmniError::manifest(format!(
"injected failpoint triggered: {}",
name
)));
});
}
Ok(())
}
#[cfg(feature = "failpoints")]
pub struct ScopedFailPoint {
name: String,
}
#[cfg(feature = "failpoints")]
impl ScopedFailPoint {
pub fn new(name: &str, action: &str) -> Self {
fail::cfg(name, action).expect("configure failpoint");
Self {
name: name.to_string(),
}
}
}
#[cfg(feature = "failpoints")]
impl Drop for ScopedFailPoint {
fn drop(&mut self) {
fail::remove(&self.name);
}
}

View file

@ -0,0 +1,315 @@
use std::collections::HashMap;
use arrow_array::StringArray;
use futures::TryStreamExt;
use crate::db::Snapshot;
use crate::error::{OmniError, Result};
/// Dense u32 mapping for a single node type: String ID ↔ dense index.
#[derive(Debug, Clone)]
pub struct TypeIndex {
id_to_dense: HashMap<String, u32>,
dense_to_id: Vec<String>,
}
impl TypeIndex {
pub(crate) fn new() -> Self {
Self {
id_to_dense: HashMap::new(),
dense_to_id: Vec::new(),
}
}
/// Get or insert a string ID, returning its dense index.
pub(crate) fn get_or_insert(&mut self, id: &str) -> u32 {
if let Some(&idx) = self.id_to_dense.get(id) {
return idx;
}
let idx = self.dense_to_id.len() as u32;
self.dense_to_id.push(id.to_string());
self.id_to_dense.insert(id.to_string(), idx);
idx
}
pub fn to_dense(&self, id: &str) -> Option<u32> {
self.id_to_dense.get(id).copied()
}
pub fn to_id(&self, dense: u32) -> Option<&str> {
self.dense_to_id.get(dense as usize).map(|s| s.as_str())
}
pub fn len(&self) -> usize {
self.dense_to_id.len()
}
}
/// CSR (Compressed Sparse Row) adjacency index.
#[derive(Debug, Clone)]
pub struct CsrIndex {
/// offsets[i] .. offsets[i+1] gives the neighbor range for node i.
offsets: Vec<u32>,
/// Dense indices of destination nodes.
targets: Vec<u32>,
}
impl CsrIndex {
pub(crate) fn build(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
// Count outgoing edges per source
let mut counts = vec![0u32; num_nodes];
for &(src, _) in edges {
counts[src as usize] += 1;
}
// Build offset array (prefix sum)
let mut offsets = Vec::with_capacity(num_nodes + 1);
offsets.push(0);
for &c in &counts {
offsets.push(offsets.last().unwrap() + c);
}
// Fill targets
let mut targets = vec![0u32; edges.len()];
let mut cursors = vec![0u32; num_nodes];
for &(src, dst) in edges {
let s = src as usize;
let pos = offsets[s] + cursors[s];
targets[pos as usize] = dst;
cursors[s] += 1;
}
Self { offsets, targets }
}
/// Return the dense indices of neighbors for a given dense node index.
pub fn neighbors(&self, node: u32) -> &[u32] {
let start = self.offsets[node as usize] as usize;
let end = self.offsets[node as usize + 1] as usize;
&self.targets[start..end]
}
/// Check if a node has any outgoing edges. O(1), no allocation.
pub fn has_neighbors(&self, node: u32) -> bool {
let n = node as usize;
self.offsets[n + 1] > self.offsets[n]
}
}
/// Topology-only graph index. No node data cached — just adjacency.
#[derive(Debug, Clone)]
pub struct GraphIndex {
/// Dense index per node type (built from edge src/dst columns).
type_indices: HashMap<String, TypeIndex>,
/// Outgoing adjacency per edge type.
csr: HashMap<String, CsrIndex>,
/// Incoming adjacency per edge type.
csc: HashMap<String, CsrIndex>,
}
impl GraphIndex {
/// Build a graph index by scanning edge sub-tables from a snapshot.
pub async fn build(
snapshot: &Snapshot,
edge_types: &HashMap<String, (String, String)>, // edge_name → (from_type, to_type)
) -> Result<Self> {
let mut type_indices: HashMap<String, TypeIndex> = HashMap::new();
let mut csr = HashMap::new();
let mut csc = HashMap::new();
// Phase 1: Scan all edges, build TypeIndices and collect edge pairs
let mut edge_pairs: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
for (edge_name, (from_type, to_type)) in edge_types {
let table_key = format!("edge:{}", edge_name);
if snapshot.entry(&table_key).is_none() {
continue;
}
let ds = snapshot.open(&table_key).await?;
let batches: Vec<arrow_array::RecordBatch> = ds
.scan()
.project(&["src", "dst"])
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
type_indices
.entry(from_type.clone())
.or_insert_with(TypeIndex::new);
type_indices
.entry(to_type.clone())
.or_insert_with(TypeIndex::new);
let mut edges: Vec<(u32, u32)> = Vec::new();
for batch in &batches {
let srcs = string_column(batch, "src")?;
let dsts = string_column(batch, "dst")?;
for i in 0..batch.num_rows() {
let src_dense = type_indices
.get_mut(from_type)
.unwrap()
.get_or_insert(srcs.value(i));
let dst_dense = type_indices
.get_mut(to_type)
.unwrap()
.get_or_insert(dsts.value(i));
edges.push((src_dense, dst_dense));
}
}
edge_pairs.insert(edge_name.clone(), edges);
}
// Phase 2: Build CSR/CSC using final TypeIndex sizes
for (edge_name, (from_type, to_type)) in edge_types {
let Some(edges) = edge_pairs.get(edge_name) else {
continue;
};
let src_count = type_indices[from_type].len();
let dst_count = type_indices[to_type].len();
csr.insert(edge_name.clone(), CsrIndex::build(src_count, edges));
let reversed: Vec<(u32, u32)> = edges.iter().map(|&(s, d)| (d, s)).collect();
csc.insert(edge_name.clone(), CsrIndex::build(dst_count, &reversed));
}
Ok(Self {
type_indices,
csr,
csc,
})
}
pub fn type_index(&self, type_name: &str) -> Option<&TypeIndex> {
self.type_indices.get(type_name)
}
pub fn csr(&self, edge_type: &str) -> Option<&CsrIndex> {
self.csr.get(edge_type)
}
pub fn csc(&self, edge_type: &str) -> Option<&CsrIndex> {
self.csc.get(edge_type)
}
#[cfg(test)]
pub(crate) fn empty_for_test() -> Self {
Self {
type_indices: HashMap::new(),
csr: HashMap::new(),
csc: HashMap::new(),
}
}
}
fn string_column<'a>(batch: &'a arrow_array::RecordBatch, name: &str) -> Result<&'a StringArray> {
batch
.column_by_name(name)
.ok_or_else(|| {
OmniError::manifest_internal(format!("graph index batch missing '{name}' column"))
})?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
OmniError::manifest_internal(format!("graph index column '{name}' is not Utf8"))
})
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::UInt64Array;
use arrow_schema::{DataType, Field, Schema};
use super::*;
#[test]
fn type_index_round_trip() {
let mut idx = TypeIndex::new();
let a = idx.get_or_insert("Alice");
let b = idx.get_or_insert("Bob");
let c = idx.get_or_insert("Charlie");
assert_eq!(idx.to_dense("Alice"), Some(a));
assert_eq!(idx.to_dense("Bob"), Some(b));
assert_eq!(idx.to_dense("Charlie"), Some(c));
assert_eq!(idx.to_id(a), Some("Alice"));
assert_eq!(idx.to_id(b), Some("Bob"));
assert_eq!(idx.to_id(c), Some("Charlie"));
assert_eq!(idx.len(), 3);
}
#[test]
fn type_index_idempotent_insert() {
let mut idx = TypeIndex::new();
let a1 = idx.get_or_insert("Alice");
let a2 = idx.get_or_insert("Alice");
assert_eq!(a1, a2);
assert_eq!(idx.len(), 1);
}
#[test]
fn type_index_unknown_returns_none() {
let idx = TypeIndex::new();
assert_eq!(idx.to_dense("unknown"), None);
assert_eq!(idx.to_id(999), None);
}
#[test]
fn csr_neighbors_correct() {
// Graph: 0→1, 0→2, 1→2
let edges = vec![(0, 1), (0, 2), (1, 2)];
let csr = CsrIndex::build(3, &edges);
let mut n0: Vec<u32> = csr.neighbors(0).to_vec();
n0.sort();
assert_eq!(n0, vec![1, 2]);
assert_eq!(csr.neighbors(1), &[2]);
assert_eq!(csr.neighbors(2), &[] as &[u32]);
}
#[test]
fn csr_empty_graph() {
let csr = CsrIndex::build(3, &[]);
assert_eq!(csr.neighbors(0), &[] as &[u32]);
assert_eq!(csr.neighbors(1), &[] as &[u32]);
assert_eq!(csr.neighbors(2), &[] as &[u32]);
assert!(!csr.has_neighbors(0));
}
#[test]
fn csr_has_neighbors() {
// 0→1, 1→2
let csr = CsrIndex::build(3, &[(0, 1), (1, 2)]);
assert!(csr.has_neighbors(0));
assert!(csr.has_neighbors(1));
assert!(!csr.has_neighbors(2));
}
#[test]
fn string_column_returns_error_for_bad_schema() {
let batch = arrow_array::RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new(
"src",
DataType::UInt64,
false,
)])),
vec![Arc::new(UInt64Array::from(vec![1_u64]))],
)
.unwrap();
let err = string_column(&batch, "src").unwrap_err();
assert!(err.to_string().contains("src"));
}
}

View file

@ -0,0 +1,11 @@
pub mod changes;
pub mod db;
pub mod embedding;
pub mod error;
mod exec;
pub mod failpoints;
pub mod graph_index;
pub mod loader;
pub mod runtime_cache;
pub mod storage;
pub mod table_store;

View file

@ -0,0 +1,476 @@
use std::collections::HashMap;
#[cfg(test)]
use std::collections::HashSet;
use arrow_array::{
Array, ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array,
Int32Array, Int64Array, StringArray, UInt32Array, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema};
use crate::catalog::schema_ir::SchemaIR;
use crate::error::{NanoError, Result};
use super::super::graph::DatasetAccumulator;
#[derive(Debug, Default)]
pub(crate) struct NodeConstraintAnnotations {
pub(crate) key_props: HashMap<String, String>,
pub(crate) unique_props: HashMap<String, Vec<String>>,
}
pub(crate) fn load_node_constraint_annotations(
schema_ir: &SchemaIR,
) -> Result<NodeConstraintAnnotations> {
let mut constraints = NodeConstraintAnnotations::default();
for node in schema_ir.node_types() {
let mut node_key_prop: Option<String> = None;
let mut node_unique_props: Vec<String> = Vec::new();
for prop in &node.properties {
if prop.key && node_key_prop.replace(prop.name.clone()).is_some() {
return Err(NanoError::Storage(format!(
"node type {} has multiple @key properties; only one is currently supported",
node.name
)));
}
if prop.unique {
node_unique_props.push(prop.name.clone());
}
}
if let Some(prop_name) = node_key_prop {
if !node_unique_props.contains(&prop_name) {
node_unique_props.push(prop_name.clone());
}
constraints.key_props.insert(node.name.clone(), prop_name);
}
if !node_unique_props.is_empty() {
node_unique_props.sort();
node_unique_props.dedup();
constraints
.unique_props
.insert(node.name.clone(), node_unique_props);
}
}
Ok(constraints)
}
pub(crate) fn enforce_node_unique_constraints(
storage: &DatasetAccumulator,
unique_props: &HashMap<String, Vec<String>>,
) -> Result<()> {
for (type_name, properties) in unique_props {
let Some(batch) = storage.get_all_nodes(type_name)? else {
continue;
};
for property in properties {
let prop_idx =
node_property_index(batch.schema().as_ref(), property).ok_or_else(|| {
NanoError::Storage(format!(
"node type {} missing @unique property {}",
type_name, property
))
})?;
let arr = batch.column(prop_idx);
let mut seen: HashMap<String, usize> = HashMap::new();
for row in 0..batch.num_rows() {
let Some(value) = unique_value_string(arr, row, type_name, property)? else {
continue;
};
if let Some(prev_row) = seen.insert(value.clone(), row) {
return Err(NanoError::UniqueConstraint {
type_name: type_name.clone(),
property: property.clone(),
value,
first_row: prev_row,
second_row: row,
});
}
}
}
}
Ok(())
}
#[cfg(test)]
pub(crate) fn collect_incoming_node_types(data_source: &str) -> Result<HashSet<String>> {
let mut node_types = HashSet::new();
for line in data_source.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with("//") {
continue;
}
let obj: serde_json::Value = serde_json::from_str(line)
.map_err(|e| NanoError::Storage(format!("JSON parse error: {}", e)))?;
if let Some(type_name) = obj.get("type").and_then(|v| v.as_str()) {
node_types.insert(type_name.to_string());
}
}
Ok(node_types)
}
pub(crate) fn build_name_seed_for_keyed_load(
storage: &DatasetAccumulator,
key_props: &HashMap<String, String>,
) -> Result<HashMap<(String, String), u64>> {
let mut seed = HashMap::new();
for (type_name, key_prop) in key_props {
let Some(batch) = storage.get_all_nodes(type_name)? else {
continue;
};
let key_idx = node_property_index(batch.schema().as_ref(), key_prop).ok_or_else(|| {
NanoError::Storage(format!(
"node type {} missing @key property {}",
type_name, key_prop
))
})?;
let key_arr = batch.column(key_idx).clone();
let id_arr = batch
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| {
NanoError::Storage(format!("node type {} has non-UInt64 id column", type_name))
})?;
for row in 0..batch.num_rows() {
let key = key_value_string(&key_arr, row, key_prop)?;
seed.insert((type_name.clone(), key), id_arr.value(row));
}
}
Ok(seed)
}
pub(crate) fn build_name_seed_for_append(
storage: &DatasetAccumulator,
key_props: &HashMap<String, String>,
) -> Result<HashMap<(String, String), u64>> {
build_name_seed_for_keyed_load(storage, key_props)
}
pub(crate) fn node_property_index(schema: &Schema, prop_name: &str) -> Option<usize> {
schema
.fields()
.iter()
.enumerate()
.skip(1)
.find_map(|(idx, field)| (field.name() == prop_name).then_some(idx))
}
pub(crate) fn node_property_field<'a>(schema: &'a Schema, prop_name: &str) -> Option<&'a Field> {
node_property_index(schema, prop_name).map(|idx| schema.field(idx))
}
pub(crate) fn key_value_string(array: &ArrayRef, row: usize, prop_name: &str) -> Result<String> {
let value = scalar_value_string(array, row, "key", None, prop_name)?;
if let Some(value) = value {
return Ok(value);
}
Err(NanoError::Storage(format!(
"@key property {} cannot be null",
prop_name
)))
}
fn unique_value_string(
array: &ArrayRef,
row: usize,
type_name: &str,
prop_name: &str,
) -> Result<Option<String>> {
scalar_value_string(array, row, "unique", Some(type_name), prop_name)
}
fn scalar_value_string(
array: &ArrayRef,
row: usize,
annotation: &str,
type_name: Option<&str>,
prop_name: &str,
) -> Result<Option<String>> {
if array.is_null(row) {
return Ok(None);
}
let value = match array.data_type() {
DataType::Utf8 => array
.as_any()
.downcast_ref::<StringArray>()
.map(|a| a.value(row).to_string()),
DataType::Boolean => array
.as_any()
.downcast_ref::<BooleanArray>()
.map(|a| a.value(row).to_string()),
DataType::Int32 => array
.as_any()
.downcast_ref::<Int32Array>()
.map(|a| a.value(row).to_string()),
DataType::Int64 => array
.as_any()
.downcast_ref::<Int64Array>()
.map(|a| a.value(row).to_string()),
DataType::UInt32 => array
.as_any()
.downcast_ref::<UInt32Array>()
.map(|a| a.value(row).to_string()),
DataType::UInt64 => array
.as_any()
.downcast_ref::<UInt64Array>()
.map(|a| a.value(row).to_string()),
DataType::Float32 => array
.as_any()
.downcast_ref::<Float32Array>()
.map(|a| a.value(row).to_string()),
DataType::Float64 => array
.as_any()
.downcast_ref::<Float64Array>()
.map(|a| a.value(row).to_string()),
DataType::Date32 => array
.as_any()
.downcast_ref::<Date32Array>()
.map(|a| a.value(row).to_string()),
DataType::Date64 => array
.as_any()
.downcast_ref::<Date64Array>()
.map(|a| a.value(row).to_string()),
_ => None,
};
let value = value.ok_or_else(|| {
let target = match type_name {
Some(name) => format!("{}.{}", name, prop_name),
None => prop_name.to_string(),
};
NanoError::Storage(format!(
"unsupported @{} data type {:?} for {}",
annotation,
array.data_type(),
target
))
})?;
Ok(Some(value))
}
#[cfg(test)]
mod tests {
use std::collections::{HashMap, HashSet};
use arrow_array::StringArray;
use crate::catalog::schema_ir::{build_catalog_from_ir, build_schema_ir};
use crate::schema::parser::parse_schema;
use super::super::jsonl::load_jsonl_data;
use super::*;
fn build_schema_ir_and_storage(schema_src: &str) -> (SchemaIR, DatasetAccumulator) {
let schema = parse_schema(schema_src).unwrap();
let ir = build_schema_ir(&schema).unwrap();
let catalog = build_catalog_from_ir(&ir).unwrap();
(ir, DatasetAccumulator::new(catalog))
}
#[test]
fn load_node_constraint_annotations_collects_key_and_unique() {
let schema = r#"node Person {
name: String @key
email: String @unique
alias: String? @unique
}"#;
let (ir, _) = build_schema_ir_and_storage(schema);
let annotations = load_node_constraint_annotations(&ir).unwrap();
assert_eq!(annotations.key_props.get("Person").unwrap(), "name");
assert_eq!(
annotations.unique_props.get("Person").unwrap(),
&vec!["alias".to_string(), "email".to_string(), "name".to_string()]
);
}
#[test]
fn collect_incoming_node_types_ignores_comments_and_blanks() {
let data = r#"
// comment
{"type":"Person","data":{"name":"Alice"}}
{"edge":"Knows","from":"Alice","to":"Bob"}
{"type":"Company","data":{"name":"Acme"}}
"#;
let types = collect_incoming_node_types(data).unwrap();
assert_eq!(
types,
HashSet::from(["Person".to_string(), "Company".to_string()])
);
}
#[test]
fn enforce_node_unique_constraints_detects_duplicate_non_null() {
let schema = r#"node Person {
name: String
email: String? @unique
}"#;
let (_, mut storage) = build_schema_ir_and_storage(schema);
let key_props = HashMap::new();
load_jsonl_data(
&mut storage,
r#"{"type":"Person","data":{"name":"Alice","email":"dupe@example.com"}}
{"type":"Person","data":{"name":"Bob","email":"dupe@example.com"}}"#,
&key_props,
)
.unwrap();
let unique_props = HashMap::from([("Person".to_string(), vec!["email".to_string()])]);
let err = enforce_node_unique_constraints(&storage, &unique_props).unwrap_err();
match err {
NanoError::UniqueConstraint {
type_name,
property,
value,
..
} => {
assert_eq!(type_name, "Person");
assert_eq!(property, "email");
assert_eq!(value, "dupe@example.com");
}
other => panic!("expected UniqueConstraint, got {other}"),
}
}
#[test]
fn enforce_node_unique_constraints_allows_multiple_nulls() {
let schema = r#"node Person {
name: String
nick: String? @unique
}"#;
let (_, mut storage) = build_schema_ir_and_storage(schema);
let key_props = HashMap::new();
load_jsonl_data(
&mut storage,
r#"{"type":"Person","data":{"name":"Alice","nick":null}}
{"type":"Person","data":{"name":"Bob","nick":null}}"#,
&key_props,
)
.unwrap();
let unique_props = HashMap::from([("Person".to_string(), vec!["nick".to_string()])]);
enforce_node_unique_constraints(&storage, &unique_props).unwrap();
}
#[test]
fn enforce_node_unique_constraints_uses_user_property_named_id() {
let schema = r#"node Person {
id: String @unique
name: String
}"#;
let (_, mut storage) = build_schema_ir_and_storage(schema);
let key_props = HashMap::new();
load_jsonl_data(
&mut storage,
r#"{"type":"Person","data":{"id":"user-1","name":"Alice"}}
{"type":"Person","data":{"id":"user-1","name":"Bob"}}"#,
&key_props,
)
.unwrap();
let unique_props = HashMap::from([("Person".to_string(), vec!["id".to_string()])]);
let err = enforce_node_unique_constraints(&storage, &unique_props).unwrap_err();
match err {
NanoError::UniqueConstraint {
type_name,
property,
value,
..
} => {
assert_eq!(type_name, "Person");
assert_eq!(property, "id");
assert_eq!(value, "user-1");
}
other => panic!("expected UniqueConstraint, got {other}"),
}
}
#[test]
fn build_name_seed_for_keyed_load_uses_declared_key_property() {
let schema = r#"node Person {
uid: String @key
name: String
}
node Company {
name: String
}"#;
let (_, mut storage) = build_schema_ir_and_storage(schema);
let key_props = HashMap::from([("Person".to_string(), "uid".to_string())]);
load_jsonl_data(
&mut storage,
r#"{"type":"Person","data":{"uid":"u1","name":"Alice"}}
{"type":"Company","data":{"name":"Acme"}}"#,
&key_props,
)
.unwrap();
let seed = build_name_seed_for_keyed_load(&storage, &key_props).unwrap();
assert!(seed.contains_key(&("Person".to_string(), "u1".to_string())));
assert!(!seed.contains_key(&("Company".to_string(), "Acme".to_string())));
}
#[test]
fn build_name_seed_for_keyed_load_uses_user_property_named_id() {
let schema = r#"node Person {
id: String @key
name: String
}"#;
let (_, mut storage) = build_schema_ir_and_storage(schema);
let key_props = HashMap::from([("Person".to_string(), "id".to_string())]);
load_jsonl_data(
&mut storage,
r#"{"type":"Person","data":{"id":"user-1","name":"Alice"}}"#,
&key_props,
)
.unwrap();
let seed = build_name_seed_for_keyed_load(&storage, &key_props).unwrap();
assert!(seed.contains_key(&("Person".to_string(), "user-1".to_string())));
}
#[test]
fn build_name_seed_for_append_keeps_all_existing_keyed_nodes() {
let schema = r#"node Person {
uid: String @key
name: String
}
node Company {
name: String
}"#;
let (_, mut storage) = build_schema_ir_and_storage(schema);
let key_props = HashMap::from([("Person".to_string(), "uid".to_string())]);
load_jsonl_data(
&mut storage,
r#"{"type":"Person","data":{"uid":"u1","name":"Alice"}}
{"type":"Company","data":{"name":"Acme"}}"#,
&key_props,
)
.unwrap();
let seed = build_name_seed_for_append(&storage, &key_props).unwrap();
assert!(seed.contains_key(&("Person".to_string(), "u1".to_string())));
assert!(!seed.contains_key(&("Company".to_string(), "Acme".to_string())));
}
#[test]
fn key_value_string_rejects_null() {
let arr: ArrayRef = std::sync::Arc::new(StringArray::from(vec![Some("x"), None]));
assert_eq!(key_value_string(&arr, 0, "name").unwrap(), "x");
let err = key_value_string(&arr, 1, "name").unwrap_err();
assert!(err.to_string().contains("cannot be null"));
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,159 @@
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use omnigraph_compiler::catalog::Catalog;
use tokio::sync::Mutex;
use crate::db::ResolvedTarget;
use crate::error::Result;
use crate::graph_index::GraphIndex;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GraphIndexCacheKey {
snapshot_id: String,
edge_tables: Vec<GraphIndexTableState>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GraphIndexTableState {
table_key: String,
table_version: u64,
table_branch: Option<String>,
}
#[derive(Debug, Default)]
pub struct RuntimeCache {
graph_indices: Mutex<GraphIndexCache>,
}
#[derive(Debug, Default)]
struct GraphIndexCache {
entries: HashMap<GraphIndexCacheKey, Arc<GraphIndex>>,
lru: VecDeque<GraphIndexCacheKey>,
}
impl RuntimeCache {
pub async fn invalidate_all(&self) {
let mut cache = self.graph_indices.lock().await;
cache.entries.clear();
cache.lru.clear();
}
pub async fn graph_index(
&self,
resolved: &ResolvedTarget,
catalog: &Catalog,
) -> Result<Arc<GraphIndex>> {
let key = graph_index_cache_key(resolved, catalog);
{
let mut cache = self.graph_indices.lock().await;
if let Some(index) = cache.entries.get(&key).cloned() {
cache.touch(key.clone());
return Ok(index);
}
}
let edge_types = catalog
.edge_types
.iter()
.map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
.collect();
let index = Arc::new(GraphIndex::build(&resolved.snapshot, &edge_types).await?);
let mut cache = self.graph_indices.lock().await;
if let Some(existing) = cache.entries.get(&key).cloned() {
cache.touch(key);
return Ok(existing);
}
cache.insert(key, Arc::clone(&index));
Ok(index)
}
}
impl GraphIndexCache {
fn insert(&mut self, key: GraphIndexCacheKey, value: Arc<GraphIndex>) {
self.entries.insert(key.clone(), value);
self.touch(key);
while self.entries.len() > 8 {
let Some(oldest) = self.lru.pop_front() else {
break;
};
if self.entries.remove(&oldest).is_some() {
break;
}
}
}
fn touch(&mut self, key: GraphIndexCacheKey) {
self.lru.retain(|existing| existing != &key);
self.lru.push_back(key);
}
}
fn graph_index_cache_key(resolved: &ResolvedTarget, catalog: &Catalog) -> GraphIndexCacheKey {
let mut edge_tables: Vec<GraphIndexTableState> = catalog
.edge_types
.keys()
.filter_map(|edge_name| {
let table_key = format!("edge:{}", edge_name);
resolved
.snapshot
.entry(&table_key)
.map(|entry| GraphIndexTableState {
table_key,
table_version: entry.table_version,
table_branch: entry.table_branch.clone(),
})
})
.collect();
edge_tables.sort_by(|a, b| a.table_key.cmp(&b.table_key));
GraphIndexCacheKey {
snapshot_id: resolved.snapshot_id.as_str().to_string(),
edge_tables,
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
fn key(id: usize) -> GraphIndexCacheKey {
GraphIndexCacheKey {
snapshot_id: format!("snap-{id}"),
edge_tables: Vec::new(),
}
}
fn empty_index() -> Arc<GraphIndex> {
Arc::new(GraphIndex::empty_for_test())
}
#[test]
fn graph_index_cache_evicts_oldest_entry() {
let mut cache = GraphIndexCache::default();
for idx in 0..9 {
cache.insert(key(idx), empty_index());
}
assert_eq!(cache.entries.len(), 8);
assert!(!cache.entries.contains_key(&key(0)));
assert!(cache.entries.contains_key(&key(8)));
}
#[test]
fn graph_index_cache_touch_keeps_recent_entry() {
let mut cache = GraphIndexCache::default();
for idx in 0..8 {
cache.insert(key(idx), empty_index());
}
cache.touch(key(0));
cache.insert(key(8), empty_index());
assert!(cache.entries.contains_key(&key(0)));
assert!(!cache.entries.contains_key(&key(1)));
}
}

View file

@ -0,0 +1,325 @@
use std::env;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use futures::TryStreamExt;
use object_store::aws::AmazonS3Builder;
use object_store::path::Path as ObjectPath;
use object_store::{DynObjectStore, ObjectStore, PutPayload};
use url::Url;
use crate::error::{OmniError, Result};
const FILE_SCHEME_PREFIX: &str = "file://";
const S3_SCHEME_PREFIX: &str = "s3://";
#[async_trait]
pub trait StorageAdapter: Debug + Send + Sync {
async fn read_text(&self, uri: &str) -> Result<String>;
async fn write_text(&self, uri: &str, contents: &str) -> Result<()>;
async fn exists(&self, uri: &str) -> Result<bool>;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
Local,
S3,
}
#[derive(Debug, Default)]
pub struct LocalStorageAdapter;
#[derive(Debug)]
pub struct S3StorageAdapter {
bucket: String,
store: Arc<DynObjectStore>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct S3Location {
bucket: String,
key: String,
}
#[async_trait]
impl StorageAdapter for LocalStorageAdapter {
async fn read_text(&self, uri: &str) -> Result<String> {
let path = local_path_from_uri(uri)?;
Ok(tokio::fs::read_to_string(&path).await?)
}
async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
let path = local_path_from_uri(uri)?;
tokio::fs::write(&path, contents).await?;
Ok(())
}
async fn exists(&self, uri: &str) -> Result<bool> {
Ok(local_path_from_uri(uri)?.exists())
}
}
#[async_trait]
impl StorageAdapter for S3StorageAdapter {
async fn read_text(&self, uri: &str) -> Result<String> {
let location = self.object_path(uri)?;
let bytes = self
.store
.get(&location)
.await
.map_err(|err| storage_backend_error("read", uri, err))?
.bytes()
.await
.map_err(|err| storage_backend_error("read", uri, err))?;
String::from_utf8(bytes.to_vec()).map_err(|err| {
OmniError::manifest_internal(format!("storage read failed for '{}': {}", uri, err))
})
}
async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
let location = self.object_path(uri)?;
self.store
.put(&location, PutPayload::from(contents.as_bytes().to_vec()))
.await
.map_err(|err| storage_backend_error("write", uri, err))?;
Ok(())
}
async fn exists(&self, uri: &str) -> Result<bool> {
let location = self.object_path(uri)?;
match self.store.head(&location).await {
Ok(_) => Ok(true),
Err(object_store::Error::NotFound { .. }) => {
let mut entries = self.store.list(Some(&location));
let has_prefix_entries = entries
.try_next()
.await
.map_err(|err| storage_backend_error("exists", uri, err))?
.is_some();
Ok(has_prefix_entries)
}
Err(err) => Err(storage_backend_error("exists", uri, err)),
}
}
}
impl S3StorageAdapter {
fn from_root_uri(root_uri: &str) -> Result<Self> {
let location = parse_s3_uri(root_uri)?;
let mut builder = AmazonS3Builder::from_env().with_bucket_name(&location.bucket);
if let Some(endpoint) = env::var("AWS_ENDPOINT_URL_S3")
.ok()
.or_else(|| env::var("AWS_ENDPOINT_URL").ok())
{
builder = builder.with_endpoint(&endpoint);
if endpoint.starts_with("http://") || env_var_truthy("AWS_ALLOW_HTTP") {
builder = builder.with_allow_http(true);
}
}
if env_var_truthy("AWS_S3_FORCE_PATH_STYLE") {
builder = builder.with_virtual_hosted_style_request(false);
}
let store = builder.build().map_err(|err| {
OmniError::manifest_internal(format!(
"failed to initialize s3 storage for '{}': {}",
root_uri, err
))
})?;
Ok(Self {
bucket: location.bucket,
store: Arc::new(store),
})
}
fn object_path(&self, uri: &str) -> Result<ObjectPath> {
let location = parse_s3_uri(uri)?;
if location.bucket != self.bucket {
return Err(OmniError::manifest_internal(format!(
"s3 storage bucket mismatch for '{}': expected '{}', found '{}'",
uri, self.bucket, location.bucket
)));
}
if location.key.is_empty() {
return Err(OmniError::manifest_internal(format!(
"s3 storage path is empty for '{}'",
uri
)));
}
ObjectPath::parse(&location.key).map_err(|err| {
OmniError::manifest_internal(format!("invalid s3 object path for '{}': {}", uri, err))
})
}
}
pub fn storage_kind_for_uri(uri: &str) -> StorageKind {
if uri.starts_with(S3_SCHEME_PREFIX) {
StorageKind::S3
} else {
StorageKind::Local
}
}
pub fn storage_for_uri(uri: &str) -> Result<Arc<dyn StorageAdapter>> {
match storage_kind_for_uri(uri) {
StorageKind::Local => Ok(Arc::new(LocalStorageAdapter)),
StorageKind::S3 => Ok(Arc::new(S3StorageAdapter::from_root_uri(uri)?)),
}
}
pub fn normalize_root_uri(uri: &str) -> Result<String> {
match storage_kind_for_uri(uri) {
StorageKind::Local => {
let path = local_path_from_uri(uri)?;
Ok(normalize_local_path(&path))
}
StorageKind::S3 => Ok(trim_trailing_slashes(uri)),
}
}
pub fn join_uri(root_uri: &str, relative_path: &str) -> String {
let relative_path = relative_path.trim_start_matches('/');
match storage_kind_for_uri(root_uri) {
StorageKind::S3 => {
let root = trim_trailing_slashes(root_uri);
if root.is_empty() {
relative_path.to_string()
} else {
format!("{}/{}", root, relative_path)
}
}
StorageKind::Local => {
let root = if root_uri.starts_with(FILE_SCHEME_PREFIX) {
local_path_from_file_uri(root_uri)
.map(|path| normalize_local_path(&path))
.unwrap_or_else(|_| trim_trailing_slashes(root_uri))
} else {
normalize_local_path(Path::new(root_uri))
};
let joined = Path::new(&root).join(relative_path);
normalize_local_path(&joined)
}
}
}
fn local_path_from_uri(uri: &str) -> Result<PathBuf> {
if uri.starts_with(FILE_SCHEME_PREFIX) {
return local_path_from_file_uri(uri);
}
Ok(PathBuf::from(uri))
}
fn local_path_from_file_uri(uri: &str) -> Result<PathBuf> {
let url = Url::parse(uri).map_err(|err| {
OmniError::manifest_internal(format!("invalid file uri '{}': {}", uri, err))
})?;
url.to_file_path()
.map_err(|_| OmniError::manifest_internal(format!("invalid file uri '{}'", uri)))
}
fn parse_s3_uri(uri: &str) -> Result<S3Location> {
let url = Url::parse(uri).map_err(|err| {
OmniError::manifest_internal(format!("invalid s3 uri '{}': {}", uri, err))
})?;
if url.scheme() != "s3" {
return Err(OmniError::manifest_internal(format!(
"unsupported s3 uri '{}'",
uri
)));
}
let bucket = url
.host_str()
.ok_or_else(|| OmniError::manifest_internal(format!("missing s3 bucket in '{}'", uri)))?;
Ok(S3Location {
bucket: bucket.to_string(),
key: url.path().trim_start_matches('/').to_string(),
})
}
fn storage_backend_error(action: &str, uri: &str, err: impl std::fmt::Display) -> OmniError {
OmniError::manifest_internal(format!("storage {} failed for '{}': {}", action, uri, err))
}
fn normalize_local_path(path: &Path) -> String {
let raw = path.as_os_str().to_string_lossy();
if raw == "/" {
return raw.to_string();
}
trim_trailing_slashes(&raw)
}
fn trim_trailing_slashes(value: &str) -> String {
let trimmed = value.trim_end_matches('/');
if trimmed.is_empty() {
value.to_string()
} else {
trimmed.to_string()
}
}
fn env_var_truthy(key: &str) -> bool {
matches!(
env::var(key).ok().as_deref(),
Some("1" | "true" | "TRUE" | "True" | "yes" | "YES" | "on" | "ON")
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn storage_backend_selection_is_scheme_aware() {
assert_eq!(storage_kind_for_uri("/tmp/repo"), StorageKind::Local);
assert_eq!(storage_kind_for_uri("file:///tmp/repo"), StorageKind::Local);
assert_eq!(
storage_kind_for_uri("s3://omnigraph-preview/repo"),
StorageKind::S3
);
}
#[test]
fn normalize_root_uri_preserves_local_and_s3_shapes() {
assert_eq!(
normalize_root_uri("/tmp/omnigraph/").unwrap(),
"/tmp/omnigraph"
);
assert_eq!(
normalize_root_uri("file:///tmp/omnigraph/").unwrap(),
"/tmp/omnigraph"
);
assert_eq!(
normalize_root_uri("s3://bucket/prefix/").unwrap(),
"s3://bucket/prefix"
);
}
#[test]
fn join_uri_handles_local_file_and_s3_roots() {
assert_eq!(
join_uri("/tmp/omnigraph", "_schema.pg"),
"/tmp/omnigraph/_schema.pg"
);
assert_eq!(
join_uri("file:///tmp/omnigraph", "_schema.pg"),
"/tmp/omnigraph/_schema.pg"
);
assert_eq!(
join_uri("s3://bucket/prefix", "_schema.pg"),
"s3://bucket/prefix/_schema.pg"
);
}
#[test]
fn parse_s3_uri_splits_bucket_and_key() {
let location = parse_s3_uri("s3://bucket/repo/_schema.pg").unwrap();
assert_eq!(location.bucket, "bucket");
assert_eq!(location.key, "repo/_schema.pg");
}
}

View file

@ -0,0 +1,603 @@
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use arrow_select::concat::concat_batches;
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
use lance::dataset::{MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode, WriteParams};
use lance::datatypes::BlobHandling;
use lance::index::scalar::IndexDetails;
use lance_file::version::LanceFileVersion;
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
use lance_index::{DatasetIndexExt, IndexType, is_system_index};
use lance_linalg::distance::MetricType;
use lance_table::format::IndexMetadata;
use std::sync::Arc;
use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
use crate::db::{Snapshot, SubTableEntry};
use crate::error::{OmniError, Result};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableState {
pub version: u64,
pub row_count: u64,
pub(crate) version_metadata: TableVersionMetadata,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeleteState {
pub version: u64,
pub row_count: u64,
pub deleted_rows: usize,
pub(crate) version_metadata: TableVersionMetadata,
}
#[derive(Debug, Clone)]
pub struct TableStore {
root_uri: String,
}
impl TableStore {
pub fn new(root_uri: &str) -> Self {
Self {
root_uri: root_uri.trim_end_matches('/').to_string(),
}
}
pub fn root_uri(&self) -> &str {
&self.root_uri
}
pub fn dataset_uri(&self, table_path: &str) -> String {
format!("{}/{}", self.root_uri, table_path)
}
fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
let prefix = format!("{}/", self.root_uri.trim_end_matches('/'));
let table_path = dataset_uri
.strip_prefix(&prefix)
.map(|path| path.to_string())
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"dataset uri '{}' is not under root '{}'",
dataset_uri, self.root_uri
))
})?;
Ok(table_path
.split_once("/tree/")
.map(|(path, _)| path.to_string())
.unwrap_or(table_path))
}
fn dataset_version_metadata(
&self,
dataset_uri: &str,
ds: &Dataset,
) -> Result<TableVersionMetadata> {
let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
}
pub async fn open_snapshot_table(
&self,
snapshot: &Snapshot,
table_key: &str,
) -> Result<Dataset> {
snapshot.open(table_key).await
}
pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
entry.open(&self.root_uri).await
}
pub async fn open_dataset_head(
&self,
dataset_uri: &str,
branch: Option<&str>,
) -> Result<Dataset> {
let ds = Dataset::open(dataset_uri)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
match branch {
Some(branch) if branch != "main" => ds
.checkout_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string())),
_ => Ok(ds),
}
}
pub async fn open_dataset_head_for_write(
&self,
table_key: &str,
dataset_uri: &str,
branch: Option<&str>,
) -> Result<Dataset> {
let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await
}
pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
let mut ds = Dataset::open(dataset_uri)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
ds.delete_branch(branch)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn open_dataset_at_state(
&self,
table_path: &str,
branch: Option<&str>,
version: u64,
) -> Result<Dataset> {
let ds = self
.open_dataset_head(&self.dataset_uri(table_path), branch)
.await?;
ds.checkout_version(version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub fn ensure_expected_version(
&self,
ds: &Dataset,
table_key: &str,
expected_version: u64,
) -> Result<()> {
if ds.version().version != expected_version {
return Err(OmniError::manifest_conflict(format!(
"version drift on {}: snapshot pinned v{} but dataset is at v{} — call sync_branch() and retry",
table_key,
expected_version,
ds.version().version
)));
}
Ok(())
}
pub async fn reopen_for_mutation(
&self,
dataset_uri: &str,
branch: Option<&str>,
table_key: &str,
expected_version: u64,
) -> Result<Dataset> {
let ds = self
.open_dataset_head_for_write(table_key, dataset_uri, branch)
.await?;
self.ensure_expected_version(&ds, table_key, expected_version)?;
Ok(ds)
}
pub async fn fork_branch_from_state(
&self,
dataset_uri: &str,
source_branch: Option<&str>,
table_key: &str,
source_version: u64,
target_branch: &str,
) -> Result<Dataset> {
let mut source_ds = self
.open_dataset_head(dataset_uri, source_branch)
.await?
.checkout_version(source_version)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.ensure_expected_version(&source_ds, table_key, source_version)?;
match source_ds
.create_branch(target_branch, source_version, None)
.await
{
Ok(_) => {}
Err(create_err) => match self
.open_dataset_head(dataset_uri, Some(target_branch))
.await
{
Ok(ds) => {
self.ensure_expected_version(&ds, table_key, source_version)?;
return Ok(ds);
}
Err(_) => return Err(OmniError::Lance(create_err.to_string())),
},
}
let ds = self
.open_dataset_head(dataset_uri, Some(target_branch))
.await?;
self.ensure_expected_version(&ds, table_key, source_version)?;
Ok(ds)
}
pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
self.scan(ds, None, None, None).await
}
pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
if !has_blob_columns {
return self.scan_batches(ds).await;
}
let mut scanner = ds.scan();
scanner.blob_handling(BlobHandling::AllBinary);
scanner
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn scan_stream(
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
) -> Result<DatasetRecordBatchStream> {
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
}
pub async fn scan_stream_with<F>(
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
configure: F,
) -> Result<DatasetRecordBatchStream>
where
F: FnOnce(&mut Scanner) -> Result<()>,
{
let mut scanner = ds.scan();
if with_row_id {
scanner.with_row_id();
}
if let Some(columns) = projection {
scanner
.project(columns)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
if let Some(filter_sql) = filter {
scanner
.filter(filter_sql)
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
if let Some(ordering) = order_by {
scanner
.order_by(Some(ordering))
.map_err(|e| OmniError::Lance(e.to_string()))?;
}
configure(&mut scanner)?;
scanner
.try_into_stream()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn scan(
&self,
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
) -> Result<Vec<RecordBatch>> {
Self::scan_stream(ds, projection, filter, order_by, false)
.await?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn scan_with<F>(
&self,
ds: &Dataset,
projection: Option<&[&str]>,
filter: Option<&str>,
order_by: Option<Vec<ColumnOrdering>>,
with_row_id: bool,
configure: F,
) -> Result<Vec<RecordBatch>>
where
F: FnOnce(&mut Scanner) -> Result<()>,
{
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
.await?
.try_collect()
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
ds.count_rows(filter)
.await
.map(|count| count as usize)
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub fn dataset_version(&self, ds: &Dataset) -> u64 {
ds.version().version
}
pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
Ok(TableState {
version: self.dataset_version(ds),
row_count: self.count_rows(ds, None).await? as u64,
version_metadata: self.dataset_version_metadata(dataset_uri, ds)?,
})
}
pub async fn append_batch(
&self,
dataset_uri: &str,
ds: &mut Dataset,
batch: RecordBatch,
) -> Result<TableState> {
if batch.num_rows() == 0 {
return self.table_state(dataset_uri, ds).await;
}
let schema = batch.schema();
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let params = WriteParams {
mode: WriteMode::Append,
allow_external_blob_outside_bases: true,
..Default::default()
};
ds.append(reader, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.table_state(dataset_uri, ds).await
}
pub async fn append_or_create_batch(
dataset_uri: &str,
dataset: Option<Dataset>,
batch: RecordBatch,
) -> Result<Dataset> {
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
match dataset {
Some(mut ds) => {
let params = WriteParams {
mode: WriteMode::Append,
allow_external_blob_outside_bases: true,
..Default::default()
};
ds.append(reader, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(ds)
}
None => {
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
..Default::default()
};
Dataset::write(reader, dataset_uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
}
}
pub async fn overwrite_batch(
&self,
dataset_uri: &str,
ds: &mut Dataset,
batch: RecordBatch,
) -> Result<TableState> {
ds.truncate_table()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.append_batch(dataset_uri, ds, batch).await
}
pub async fn merge_insert_batch(
&self,
dataset_uri: &str,
ds: Dataset,
batch: RecordBatch,
key_columns: Vec<String>,
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<TableState> {
if batch.num_rows() == 0 {
return self.table_state(dataset_uri, &ds).await;
}
// TODO(lance-upstream): MergeInsertBuilder does not accept WriteParams,
// so allow_external_blob_outside_bases cannot be set here. External URI
// blobs via merge_insert (LoadMode::Merge, mutations) are unsupported
// until Lance exposes WriteParams on MergeInsertBuilder.
let ds = Arc::new(ds);
let job = MergeInsertBuilder::try_new(ds, key_columns)
.map_err(|e| OmniError::Lance(e.to_string()))?
.when_matched(when_matched)
.when_not_matched(when_not_matched)
.try_build()
.map_err(|e| OmniError::Lance(e.to_string()))?;
let schema = batch.schema();
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
let (new_ds, _stats) = job
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
self.table_state(dataset_uri, &new_ds).await
}
pub async fn merge_insert_batches(
&self,
dataset_uri: &str,
ds: Dataset,
batches: Vec<RecordBatch>,
key_columns: Vec<String>,
when_matched: WhenMatched,
when_not_matched: WhenNotMatched,
) -> Result<TableState> {
if batches.is_empty() {
return self.table_state(dataset_uri, &ds).await;
}
let batch = if batches.len() == 1 {
batches.into_iter().next().unwrap()
} else {
let schema = batches[0].schema();
concat_batches(&schema, &batches).map_err(|e| OmniError::Lance(e.to_string()))?
};
self.merge_insert_batch(
dataset_uri,
ds,
batch,
key_columns,
when_matched,
when_not_matched,
)
.await
}
pub async fn delete_where(
&self,
dataset_uri: &str,
ds: &mut Dataset,
filter: &str,
) -> Result<DeleteState> {
let delete_result = ds
.delete(filter)
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(DeleteState {
version: delete_result.new_dataset.version().version,
row_count: self.count_rows(&delete_result.new_dataset, None).await? as u64,
deleted_rows: delete_result.num_deleted_rows as usize,
version_metadata: self
.dataset_version_metadata(dataset_uri, &delete_result.new_dataset)?,
})
}
async fn user_indices_for_column(
&self,
ds: &Dataset,
column: &str,
) -> Result<Vec<IndexMetadata>> {
let field_id = ds
.schema()
.field(column)
.map(|field| field.id)
.ok_or_else(|| {
OmniError::manifest_internal(format!(
"dataset is missing expected index column '{}'",
column
))
})?;
let indices = ds
.load_indices()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(indices
.iter()
.filter(|index| !is_system_index(index))
.filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
.cloned()
.collect())
}
pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
let indices = self.user_indices_for_column(ds, column).await?;
Ok(indices.iter().any(|index| {
index
.index_details
.as_ref()
.map(|details| details.type_url.ends_with("BTreeIndexDetails"))
.unwrap_or(false)
}))
}
pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
let indices = self.user_indices_for_column(ds, column).await?;
Ok(indices.iter().any(|index| {
index
.index_details
.as_ref()
.map(|details| IndexDetails(details.clone()).supports_fts())
.unwrap_or(false)
}))
}
pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
let indices = self.user_indices_for_column(ds, column).await?;
Ok(indices.iter().any(|index| {
index
.index_details
.as_ref()
.map(|details| IndexDetails(details.clone()).is_vector())
.unwrap_or(false)
}))
}
pub async fn create_btree_index(&self, ds: &mut Dataset, columns: &[&str]) -> Result<()> {
let params = ScalarIndexParams::default();
ds.create_index_builder(columns, IndexType::BTree, &params)
.replace(true)
.await
.map(|_| ())
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_inverted_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
let params = InvertedIndexParams::default();
ds.create_index_builder(&[column], IndexType::Inverted, &params)
.replace(true)
.await
.map(|_| ())
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
ds.create_index_builder(&[column], IndexType::Vector, &params)
.replace(true)
.await
.map(|_| ())
.map_err(|e| OmniError::Lance(e.to_string()))
}
pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
let batch = RecordBatch::new_empty(schema.clone());
Self::write_dataset(dataset_uri, batch).await
}
pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
let batches = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true)
.await?
.try_collect::<Vec<RecordBatch>>()
.await
.map_err(|e| OmniError::Lance(e.to_string()))?;
Ok(batches.iter().find_map(|batch| {
batch
.column_by_name("_rowid")
.and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
.and_then(|arr| (arr.len() > 0).then(|| arr.value(0)))
}))
}
pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
allow_external_blob_outside_bases: true,
..Default::default()
};
Dataset::write(reader, dataset_uri, Some(params))
.await
.map_err(|e| OmniError::Lance(e.to_string()))
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,677 @@
mod helpers;
use omnigraph::changes::{ChangeFilter, ChangeOp, EntityKind};
use omnigraph::db::commit_graph::CommitGraph;
use omnigraph::db::{MergeOutcome, Omnigraph, ReadTarget};
use helpers::*;
async fn head_commit_id(uri: &str, branch: Option<&str>) -> String {
let commit_graph = match branch {
Some(branch) => CommitGraph::open_at_branch(uri, branch).await.unwrap(),
None => CommitGraph::open(uri).await.unwrap(),
};
commit_graph.head_commit_id().await.unwrap().unwrap()
}
fn change_tuples(change_set: &omnigraph::changes::ChangeSet) -> Vec<(String, String, ChangeOp)> {
let mut tuples: Vec<_> = change_set
.changes
.iter()
.map(|change| (change.table_key.clone(), change.id.clone(), change.op))
.collect();
tuples.sort_by(|a, b| {
a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)).then_with(|| {
let a_op = match a.2 {
ChangeOp::Insert => 0,
ChangeOp::Update => 1,
ChangeOp::Delete => 2,
};
let b_op = match b.2 {
ChangeOp::Insert => 0,
ChangeOp::Update => 1,
ChangeOp::Delete => 2,
};
a_op.cmp(&b_op)
})
});
tuples
}
// ─── Same-branch diff tests ────────────────────────────────────────────────
#[tokio::test]
async fn diff_empty_when_nothing_changed() {
let dir = tempfile::tempdir().unwrap();
let db = init_and_load(&dir).await;
let v = snapshot_id(&db, "main").await.unwrap();
let cs = db
.diff_between(
ReadTarget::Snapshot(v.clone()),
ReadTarget::Snapshot(v),
&ChangeFilter::default(),
)
.await
.unwrap();
assert!(cs.changes.is_empty());
assert_eq!(cs.stats.inserts, 0);
assert_eq!(cs.stats.updates, 0);
assert_eq!(cs.stats.deletes, 0);
}
#[tokio::test]
async fn diff_detects_node_insert() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = snapshot_id(&db, "main").await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
.await
.unwrap();
let inserts: Vec<_> = cs
.changes
.iter()
.filter(|c| c.op == ChangeOp::Insert && c.table_key == "node:Person")
.collect();
assert!(
!inserts.is_empty(),
"Should detect the Person insert. Got changes: {:?}",
cs.changes
.iter()
.map(|c| (&c.table_key, &c.id, c.op))
.collect::<Vec<_>>()
);
assert!(
inserts.iter().any(|c| c.id == "Eve"),
"Insert should contain Eve. Got: {:?}",
inserts.iter().map(|c| &c.id).collect::<Vec<_>>()
);
assert_eq!(inserts[0].kind, EntityKind::Node);
assert_eq!(inserts[0].endpoints, None);
}
#[tokio::test]
async fn diff_detects_node_update() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = snapshot_id(&db, "main").await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 99)]),
)
.await
.unwrap();
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
.await
.unwrap();
let updates: Vec<_> = cs
.changes
.iter()
.filter(|c| c.op == ChangeOp::Update && c.table_key == "node:Person")
.collect();
assert!(
!updates.is_empty(),
"Should detect the Person update. Got changes: {:?}",
cs.changes
.iter()
.map(|c| (&c.table_key, &c.id, c.op))
.collect::<Vec<_>>()
);
}
#[tokio::test]
async fn diff_detects_node_delete_with_cascade() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = snapshot_id(&db, "main").await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"remove_person",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
.await
.unwrap();
// Should have node:Person delete
let person_deletes: Vec<_> = cs
.changes
.iter()
.filter(|c| c.op == ChangeOp::Delete && c.table_key == "node:Person")
.collect();
assert!(
!person_deletes.is_empty(),
"Should detect Person delete. Changes: {:?}",
cs.changes
.iter()
.map(|c| (&c.table_key, &c.id, c.op))
.collect::<Vec<_>>()
);
// Should also have edge:Knows cascade deletes
let edge_deletes: Vec<_> = cs
.changes
.iter()
.filter(|c| c.op == ChangeOp::Delete && c.table_key == "edge:Knows")
.collect();
assert!(
!edge_deletes.is_empty(),
"Should detect cascaded Knows edge deletes. Changes: {:?}",
cs.changes
.iter()
.map(|c| (&c.table_key, &c.id, c.op))
.collect::<Vec<_>>()
);
// Cascaded edge deletes should have endpoints
for edge_del in &edge_deletes {
assert!(
edge_del.endpoints.is_some(),
"Deleted edge should have endpoint context"
);
}
}
#[tokio::test]
async fn diff_detects_edge_insert_with_endpoints() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = snapshot_id(&db, "main").await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Bob"), ("$to", "Charlie")]),
)
.await
.unwrap();
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
.await
.unwrap();
let edge_inserts: Vec<_> = cs
.changes
.iter()
.filter(|c| c.op == ChangeOp::Insert && c.table_key == "edge:Knows")
.collect();
assert!(
!edge_inserts.is_empty(),
"Should detect Knows edge insert. Changes: {:?}",
cs.changes
.iter()
.map(|c| (&c.table_key, &c.id, c.op))
.collect::<Vec<_>>()
);
let e = &edge_inserts[0];
assert_eq!(e.kind, EntityKind::Edge);
let ep = e
.endpoints
.as_ref()
.expect("Edge insert should have endpoints");
assert!(!ep.src.is_empty(), "src should not be empty");
assert!(!ep.dst.is_empty(), "dst should not be empty");
}
// ─── Filter tests ──────────────────────────────────────────────────────────
#[tokio::test]
async fn filter_by_type_name_skips_non_matching() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = snapshot_id(&db, "main").await.unwrap();
// Insert a person (node:Person) and add a friend (edge:Knows)
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "FilterTest")], &[("$age", 30)]),
)
.await
.unwrap();
// Filter to Company only — should not see Person changes
let filter = ChangeFilter {
type_names: Some(vec!["Company".to_string()]),
..Default::default()
};
let cs = diff_since_branch(&db, "main", v_before, &filter)
.await
.unwrap();
assert!(
cs.changes.is_empty(),
"Filter to Company should skip Person changes. Got: {:?}",
cs.changes
.iter()
.map(|c| (&c.table_key, &c.id, c.op))
.collect::<Vec<_>>()
);
}
#[tokio::test]
async fn filter_by_op_skips_unwanted_operations() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = snapshot_id(&db, "main").await.unwrap();
// Insert Eve, update Bob, delete Alice
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 99)]),
)
.await
.unwrap();
// Filter to Insert only
let filter = ChangeFilter {
ops: Some(vec![ChangeOp::Insert]),
..Default::default()
};
let cs = diff_since_branch(&db, "main", v_before, &filter)
.await
.unwrap();
// Should only have inserts, no updates or deletes
for c in &cs.changes {
assert_eq!(
c.op,
ChangeOp::Insert,
"Filter for Insert-only should not include {:?} for {} ({})",
c.op,
c.table_key,
c.id
);
}
}
// ─── Cross-branch diff tests ──────────────────────────────────────────────
#[tokio::test]
async fn diff_after_merge_reports_actual_changes() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.ensure_indices().await.unwrap();
let v_before_branch = snapshot_id(&main, "main").await.unwrap();
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
// Main updates Bob
mutate_main(
&mut main,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 26)]),
)
.await
.unwrap();
// Feature inserts Eve
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let outcome = main.branch_merge("feature", "main").await.unwrap();
assert_eq!(outcome, MergeOutcome::Merged);
// Diff from pre-branch to post-merge on main
let cs = diff_since_branch(&main, "main", v_before_branch, &ChangeFilter::default())
.await
.unwrap();
// Should have:
// - Person insert (Eve) — from the merge
// - Person update (Bob) — from the main write
// Should NOT have: all original persons re-reported as inserts
let person_changes: Vec<_> = cs
.changes
.iter()
.filter(|c| c.table_key == "node:Person")
.collect();
let person_inserts: Vec<_> = person_changes
.iter()
.filter(|c| c.op == ChangeOp::Insert)
.collect();
let person_updates: Vec<_> = person_changes
.iter()
.filter(|c| c.op == ChangeOp::Update)
.collect();
// There should be exactly 1 insert (Eve) not all persons
assert!(
person_inserts.len() <= 2,
"After surgical merge, should not re-report all persons as inserts. \
Got {} inserts: {:?}",
person_inserts.len(),
person_inserts.iter().map(|c| &c.id).collect::<Vec<_>>()
);
// Bob's update should be detected
assert!(
!person_updates.is_empty() || person_inserts.len() > 0,
"Should detect Bob's age update or Eve's insert"
);
}
#[tokio::test]
async fn diff_commits_resolves_feature_commit_from_main_handle() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let main_head = CommitGraph::open(uri)
.await
.unwrap()
.head_commit()
.await
.unwrap()
.unwrap()
.graph_commit_id;
let feature_head = CommitGraph::open_at_branch(uri, "feature")
.await
.unwrap()
.head_commit()
.await
.unwrap()
.unwrap()
.graph_commit_id;
let cs = main
.diff_commits(&main_head, &feature_head, &ChangeFilter::default())
.await
.unwrap();
assert!(
cs.changes
.iter()
.any(|change| change.op == ChangeOp::Insert && change.id == "Eve"),
"expected feature-only insert to be diffable from a main handle"
);
}
#[tokio::test]
async fn cross_branch_diff_honors_insert_only_filter() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let main_head = CommitGraph::open(uri)
.await
.unwrap()
.head_commit()
.await
.unwrap()
.unwrap()
.graph_commit_id;
let feature_head = CommitGraph::open_at_branch(uri, "feature")
.await
.unwrap()
.head_commit()
.await
.unwrap()
.unwrap()
.graph_commit_id;
let filter = ChangeFilter {
ops: Some(vec![ChangeOp::Insert]),
..Default::default()
};
let cs = main
.diff_commits(&main_head, &feature_head, &filter)
.await
.unwrap();
assert!(!cs.changes.is_empty());
assert!(
cs.changes
.iter()
.all(|change| change.op == ChangeOp::Insert)
);
}
#[tokio::test]
async fn diff_commits_resolves_commits_across_branches_from_any_handle() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
let base_commit = head_commit_id(uri, None).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let feature_commit = head_commit_id(uri, Some("feature")).await;
let from_main = main
.diff_commits(&base_commit, &feature_commit, &ChangeFilter::default())
.await
.unwrap();
let from_feature = feature
.diff_commits(&base_commit, &feature_commit, &ChangeFilter::default())
.await
.unwrap();
assert_eq!(change_tuples(&from_main), change_tuples(&from_feature));
assert!(from_main.changes.iter().any(|change| {
change.table_key == "node:Person" && change.id == "Eve" && change.op == ChangeOp::Insert
}));
}
#[tokio::test]
async fn cross_lineage_diff_honors_delete_only_filter() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
let before = snapshot_id(&feature, "feature").await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 99)]),
)
.await
.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"remove_person",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let filter = ChangeFilter {
ops: Some(vec![ChangeOp::Delete]),
..Default::default()
};
let change_set = diff_since_branch(&feature, "feature", before, &filter)
.await
.unwrap();
assert!(
!change_set.changes.is_empty(),
"expected delete changes after removing Alice"
);
assert!(
change_set
.changes
.iter()
.all(|change| change.op == ChangeOp::Delete)
);
}
#[tokio::test]
async fn same_branch_diff_across_first_lazy_fork_detects_update() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
let before = snapshot_id(&feature, "feature").await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 77)]),
)
.await
.unwrap();
let change_set = diff_since_branch(&feature, "feature", before, &ChangeFilter::default())
.await
.unwrap();
assert!(change_set.changes.iter().any(|change| {
change.table_key == "node:Person" && change.id == "Bob" && change.op == ChangeOp::Update
}));
}
#[tokio::test]
async fn diff_commits_cross_branch_reports_property_only_updates() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
let base_commit = head_commit_id(uri, None).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 55)]),
)
.await
.unwrap();
let feature_commit = head_commit_id(uri, Some("feature")).await;
let change_set = main
.diff_commits(&base_commit, &feature_commit, &ChangeFilter::default())
.await
.unwrap();
assert!(change_set.changes.iter().any(|change| {
change.table_key == "node:Person" && change.id == "Bob" && change.op == ChangeOp::Update
}));
assert!(!change_set.changes.iter().any(|change| {
change.table_key == "node:Person" && change.id == "Bob" && change.op == ChangeOp::Insert
}));
}
#[tokio::test]
async fn diff_commits_ignores_row_version_only_differences() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(uri).await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 55)]),
)
.await
.unwrap();
let feature_commit = head_commit_id(uri, Some("feature")).await;
mutate_main(
&mut main,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 55)]),
)
.await
.unwrap();
let main_commit = head_commit_id(uri, None).await;
let change_set = main
.diff_commits(&main_commit, &feature_commit, &ChangeFilter::default())
.await
.unwrap();
assert!(
change_set.changes.is_empty(),
"identical user-visible state should not produce diff entries: {:?}",
change_set.changes
);
}

View file

@ -0,0 +1,574 @@
mod helpers;
use arrow_array::{Array, Date32Array, Int32Array, StringArray};
use futures::TryStreamExt;
use omnigraph::db::Omnigraph;
use omnigraph::loader::{LoadMode, load_jsonl};
use omnigraph_compiler::ir::ParamMap;
use omnigraph_compiler::query::ast::Literal;
use helpers::*;
// ─── Snapshot data-level isolation ──────────────────────────────────────────
#[tokio::test]
async fn snapshot_returns_stale_data_after_write() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Snapshot BEFORE mutation
let snap_before = snapshot_main(&db).await.unwrap();
// Insert a new person
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
// Snapshot AFTER mutation
let snap_after = snapshot_main(&db).await.unwrap();
// Old snapshot should still see 4 persons
let ds_before = snap_before.open("node:Person").await.unwrap();
assert_eq!(ds_before.count_rows(None).await.unwrap(), 4);
// New snapshot should see 5 persons
let ds_after = snap_after.open("node:Person").await.unwrap();
assert_eq!(ds_after.count_rows(None).await.unwrap(), 5);
// Verify Eve is NOT in old snapshot's data
let batches_before: Vec<arrow_array::RecordBatch> = ds_before
.scan()
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let ids_before = collect_column_strings(&batches_before, "id");
assert!(!ids_before.contains(&"Eve".to_string()));
// Verify Eve IS in new snapshot's data
let batches_after: Vec<arrow_array::RecordBatch> = ds_after
.scan()
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let ids_after = collect_column_strings(&batches_after, "id");
assert!(ids_after.contains(&"Eve".to_string()));
}
// ─── LoadMode::Merge ────────────────────────────────────────────────────────
#[tokio::test]
async fn load_merge_upserts_existing_and_inserts_new() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
// Load Alice(30) and Bob(25) via Overwrite
let initial = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
{"type": "Person", "data": {"name": "Bob", "age": 25}}"#;
load_jsonl(&mut db, initial, LoadMode::Overwrite)
.await
.unwrap();
assert_eq!(count_rows(&db, "node:Person").await, 2);
// Merge: Alice updated to age=31, Charlie is new
let merge_data = r#"{"type": "Person", "data": {"name": "Alice", "age": 31}}
{"type": "Person", "data": {"name": "Charlie", "age": 35}}"#;
load_jsonl(&mut db, merge_data, LoadMode::Merge)
.await
.unwrap();
// Should have 3 persons total (not 4)
assert_eq!(count_rows(&db, "node:Person").await, 3);
// Verify individual values
let batches = read_table(&db, "node:Person").await;
let batch = &batches[0];
let ids = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let ages = batch
.column_by_name("age")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
for i in 0..batch.num_rows() {
match ids.value(i) {
"Alice" => assert_eq!(ages.value(i), 31, "Alice should be updated to 31"),
"Bob" => assert_eq!(ages.value(i), 25, "Bob should be unchanged"),
"Charlie" => assert_eq!(ages.value(i), 35, "Charlie should be inserted"),
other => panic!("unexpected person: {}", other),
}
}
}
#[tokio::test]
async fn cross_type_traversal_deduplicates_duplicate_edges() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let schema = r#"
node Person { name: String @key }
node Company { name: String @key }
edge WorksAt: Person -> Company
"#;
let data = r#"{"type":"Person","data":{"name":"Alice"}}
{"type":"Company","data":{"name":"Acme"}}
{"edge":"WorksAt","from":"Alice","to":"Acme"}
{"edge":"WorksAt","from":"Alice","to":"Acme"}"#;
let query = r#"
query company($name: String) {
match {
$p: Person { name: $name }
$p worksAt $c
}
return { $c.name }
}
"#;
let mut db = Omnigraph::init(uri, schema).await.unwrap();
load_jsonl(&mut db, data, LoadMode::Overwrite)
.await
.unwrap();
let result = query_main(&mut db, query, "company", &params(&[("$name", "Alice")]))
.await
.unwrap();
assert_eq!(result.num_rows(), 1);
}
// ─── Multi-writer refresh ───────────────────────────────────────────────────
#[tokio::test]
async fn explicit_target_query_sees_other_writer_commits_without_refresh() {
let dir = tempfile::tempdir().unwrap();
let _db = init_and_load(&dir).await;
drop(_db);
let uri = dir.path().to_str().unwrap();
// Two independent handles to the same repo
let mut db1 = Omnigraph::open(uri).await.unwrap();
let mut db2 = Omnigraph::open(uri).await.unwrap();
// Writer 1 inserts Eve
mutate_main(
&mut db1,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
// Explicit-target reads resolve the latest branch head and should see Eve
let qr = query_main(
&mut db2,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(qr.num_rows(), 1, "explicit target reads should see Eve");
}
#[tokio::test]
async fn explicit_target_query_rebuilds_graph_index_after_external_edge_write() {
let dir = tempfile::tempdir().unwrap();
let _db = init_and_load(&dir).await;
drop(_db);
let uri = dir.path().to_str().unwrap();
let mut db1 = Omnigraph::open(uri).await.unwrap();
let mut db2 = Omnigraph::open(uri).await.unwrap();
let warm = query_main(
&mut db2,
TEST_QUERIES,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(warm.num_rows(), 2);
mutate_main(
&mut db1,
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Alice"), ("$to", "Diana")]),
)
.await
.unwrap();
let refreshed = query_main(
&mut db2,
TEST_QUERIES,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(
refreshed.num_rows(),
3,
"explicit target reads should rebuild topology after edge change"
);
let batch = refreshed.concat_batches().unwrap();
let names = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let values: Vec<&str> = (0..names.len()).map(|i| names.value(i)).collect();
assert!(values.contains(&"Bob"));
assert!(values.contains(&"Diana"));
}
// ─── Null handling ──────────────────────────────────────────────────────────
#[tokio::test]
async fn null_values_in_filter_and_projection() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
// Load data: Alice has age, Bob has null age, Charlie has age
let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
{"type": "Person", "data": {"name": "Bob"}}
{"type": "Person", "data": {"name": "Charlie", "age": 35}}"#;
load_jsonl(&mut db, data, LoadMode::Overwrite)
.await
.unwrap();
// Filter: age > 30 should exclude Bob (null) and Alice (30), keep Charlie (35)
let queries = r#"
query older_than_30() {
match {
$p: Person
$p.age > 30
}
return { $p.name, $p.age }
order { $p.age desc }
}
query all_persons() {
match { $p: Person }
return { $p.name, $p.age }
order { $p.age desc }
}
"#;
let result = query_main(&mut db, queries, "older_than_30", &ParamMap::new())
.await
.unwrap();
assert_eq!(result.num_rows(), 1);
let batch = &result.batches()[0];
let names = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "Charlie");
// Projection: Bob's age should be null
let all = query_main(&mut db, queries, "all_persons", &ParamMap::new())
.await
.unwrap();
let batch = &all.batches()[0];
let ids = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let ages = batch
.column(1)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
for i in 0..batch.num_rows() {
if ids.value(i) == "Bob" {
assert!(ages.is_null(i), "Bob's age should be null");
}
}
}
// ─── Graph index after node+edge insert ─────────────────────────────────────
#[tokio::test]
async fn traversal_works_after_node_then_edge_insert() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Warm up the graph index cache by running a traversal
let _ = query_main(
&mut db,
TEST_QUERIES,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
// Insert a new node (does NOT invalidate graph index)
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Frank")], &[("$age", 40)]),
)
.await
.unwrap();
// Insert an edge from Frank → Alice (DOES invalidate graph index)
mutate_main(
&mut db,
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Frank"), ("$to", "Alice")]),
)
.await
.unwrap();
// Traversal should work: Frank → Alice
let result = query_main(
&mut db,
TEST_QUERIES,
"friends_of",
&params(&[("$name", "Frank")]),
)
.await
.unwrap();
assert_eq!(result.num_rows(), 1);
let batch = result.concat_batches().unwrap();
let names = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "Alice");
}
// ─── Edge property insert ───────────────────────────────────────────────────
#[tokio::test]
async fn insert_edge_with_property() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Knows has `since: Date?` property
let queries = r#"
query add_friend_since($from: String, $to: String, $since: Date) {
insert Knows { from: $from, to: $to, since: $since }
}
"#;
let mut p = params(&[("$from", "Diana"), ("$to", "Bob")]);
p.insert("since".to_string(), Literal::Date("2024-06-15".to_string()));
let result = mutate_main(&mut db, queries, "add_friend_since", &p)
.await
.unwrap();
assert_eq!(result.affected_edges, 1);
// Verify the edge property was stored
let batches = read_table(&db, "edge:Knows").await;
let mut found = false;
for batch in &batches {
let srcs = batch
.column_by_name("src")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let dsts = batch
.column_by_name("dst")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let since = batch
.column_by_name("since")
.unwrap()
.as_any()
.downcast_ref::<Date32Array>()
.unwrap();
for i in 0..batch.num_rows() {
if srcs.value(i) == "Diana" && dsts.value(i) == "Bob" {
assert!(!since.is_null(i), "since should not be null");
found = true;
}
}
}
assert!(found, "should find Diana→Bob edge");
}
// ─── Update / delete no-match ───────────────────────────────────────────────
#[tokio::test]
async fn update_nonexistent_returns_zero_affected() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let result = mutate_main(
&mut db,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Nobody")], &[("$age", 99)]),
)
.await
.unwrap();
assert_eq!(result.affected_nodes, 0);
}
#[tokio::test]
async fn delete_nonexistent_returns_zero_affected() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let result = mutate_main(
&mut db,
MUTATION_QUERIES,
"remove_person",
&params(&[("$name", "Nobody")]),
)
.await
.unwrap();
assert_eq!(result.affected_nodes, 0);
assert_eq!(result.affected_edges, 0);
// All 4 persons still intact
assert_eq!(count_rows(&db, "node:Person").await, 4);
}
// ─── Large batch load ───────────────────────────────────────────────────────
#[tokio::test]
async fn large_batch_load_and_query() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let schema = r#"
node Item {
name: String @key
value: I32
}
"#;
let mut db = Omnigraph::init(uri, schema).await.unwrap();
// Generate 500 items
let mut lines = Vec::with_capacity(500);
for i in 0..500 {
lines.push(format!(
r#"{{"type": "Item", "data": {{"name": "item_{:04}", "value": {}}}}}"#,
i, i
));
}
let data = lines.join("\n");
load_jsonl(&mut db, &data, LoadMode::Overwrite)
.await
.unwrap();
assert_eq!(count_rows(&db, "node:Item").await, 500);
// Query with filter — value > 490
let queries = r#"
query high_value() {
match {
$i: Item
$i.value > 490
}
return { $i.name, $i.value }
order { $i.value asc }
}
"#;
let result = query_main(&mut db, queries, "high_value", &ParamMap::new())
.await
.unwrap();
// Items 491..499 = 9 items
assert_eq!(result.num_rows(), 9);
let batch = &result.batches()[0];
let values = batch
.column(1)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_eq!(values.value(0), 491);
assert_eq!(values.value(8), 499);
}
// ─── Regression: public mutation on stale handle still applies to latest head ──────────────
#[tokio::test]
async fn stale_handle_public_mutation_uses_latest_target_head() {
let dir = tempfile::tempdir().unwrap();
let _db = init_and_load(&dir).await;
drop(_db);
let uri = dir.path().to_str().unwrap();
let mut db1 = Omnigraph::open(uri).await.unwrap();
let mut db2 = Omnigraph::open(uri).await.unwrap();
// Writer 1 inserts — advances the Person sub-table version
mutate_main(
&mut db1,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
// Writer 2 (stale) mutates through the public transactional path.
// It should stage from the latest target head rather than replaying a stale write.
mutate_main(
&mut db2,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Alice")], &[("$age", 99)]),
)
.await
.unwrap();
let result = query_main(
&mut db2,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(result.num_rows(), 1);
assert_eq!(result.to_rust_json()[0]["p.age"], serde_json::json!(99));
let eve = query_main(
&mut db2,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(eve.num_rows(), 1, "concurrent insert should be preserved");
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,183 @@
mod helpers;
use arrow_array::{Array, StringArray};
use omnigraph::db::{Omnigraph, ReadTarget};
use omnigraph::loader::{LoadMode, load_jsonl};
use helpers::*;
const EXPORT_MUTATIONS: &str = r#"
query insert_person($name: String, $age: I32) {
insert Person { name: $name, age: $age }
}
query add_friend($from: String, $to: String) {
insert Knows { from: $from, to: $to }
}
"#;
const NOTE_SCHEMA: &str = r#"
node Note {
text: String
}
edge References: Note -> Note
"#;
const NOTE_DATA: &str = r#"
{"type":"Note","data":{"id":"note-1","text":"Alpha"}}
{"type":"Note","data":{"id":"note-2","text":"Beta"}}
{"edge":"References","from":"note-1","to":"note-2","data":{"id":"edge-1"}}
"#;
#[tokio::test]
async fn export_jsonl_round_trips_branch_snapshot() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
db.branch_create_from(ReadTarget::branch("main"), "feature")
.await
.unwrap();
db.mutate(
"feature",
EXPORT_MUTATIONS,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 29)]),
)
.await
.unwrap();
db.mutate(
"feature",
EXPORT_MUTATIONS,
"add_friend",
&params(&[("$from", "Eve"), ("$to", "Alice")]),
)
.await
.unwrap();
let main_jsonl = db.export_jsonl("main", &[], &[]).await.unwrap();
let feature_jsonl = db.export_jsonl("feature", &[], &[]).await.unwrap();
let imported_main_dir = tempfile::tempdir().unwrap();
let imported_feature_dir = tempfile::tempdir().unwrap();
let mut imported_main =
Omnigraph::init(imported_main_dir.path().to_str().unwrap(), TEST_SCHEMA)
.await
.unwrap();
let mut imported_feature =
Omnigraph::init(imported_feature_dir.path().to_str().unwrap(), TEST_SCHEMA)
.await
.unwrap();
load_jsonl(&mut imported_main, &main_jsonl, LoadMode::Overwrite)
.await
.unwrap();
load_jsonl(&mut imported_feature, &feature_jsonl, LoadMode::Overwrite)
.await
.unwrap();
assert_eq!(count_rows(&db, "node:Person").await, 4);
assert_eq!(count_rows_branch(&db, "feature", "node:Person").await, 5);
assert_eq!(count_rows(&imported_main, "node:Person").await, 4);
assert_eq!(count_rows(&imported_feature, "node:Person").await, 5);
assert_eq!(count_rows(&imported_main, "edge:Knows").await, 3);
assert_eq!(count_rows(&imported_feature, "edge:Knows").await, 4);
}
#[tokio::test]
async fn export_jsonl_preserves_explicit_ids_for_non_key_graphs() {
let dir = tempfile::tempdir().unwrap();
let mut db = Omnigraph::init(dir.path().to_str().unwrap(), NOTE_SCHEMA)
.await
.unwrap();
load_jsonl(&mut db, NOTE_DATA, LoadMode::Overwrite)
.await
.unwrap();
let exported = db.export_jsonl("main", &[], &[]).await.unwrap();
let imported_dir = tempfile::tempdir().unwrap();
let mut imported = Omnigraph::init(imported_dir.path().to_str().unwrap(), NOTE_SCHEMA)
.await
.unwrap();
load_jsonl(&mut imported, &exported, LoadMode::Overwrite)
.await
.unwrap();
let node_batches = read_table(&imported, "node:Note").await;
let node_ids = collect_column_strings(&node_batches, "id");
assert_eq!(node_ids, vec!["note-1".to_string(), "note-2".to_string()]);
let edge_batches = read_table(&imported, "edge:References").await;
let edge_ids = collect_column_strings(&edge_batches, "id");
assert_eq!(edge_ids, vec!["edge-1".to_string()]);
let srcs = edge_batches[0]
.column_by_name("src")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let dsts = edge_batches[0]
.column_by_name("dst")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(srcs.value(0), "note-1");
assert_eq!(dsts.value(0), "note-2");
}
// ─── Regression: export with blob columns ────────────────────────────────────
#[tokio::test]
async fn export_jsonl_with_blob_type() {
// Regression: export on types with blob columns failed with
// "Schema error: Can not append column _rowaddr on schema" because
// Lance 4's take_blobs duplicated _rowaddr on the unsorted path.
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
const BLOB_SCHEMA: &str = r#"
node Document {
title: String @key
content: Blob?
}
"#;
let mut db = Omnigraph::init(uri, BLOB_SCHEMA).await.unwrap();
let data = concat!(
"{\"type\": \"Document\", \"data\": {\"title\": \"readme\", \"content\": \"base64:SGVsbG8=\"}}\n",
"{\"type\": \"Document\", \"data\": {\"title\": \"empty\"}}\n",
);
load_jsonl(&mut db, data, LoadMode::Overwrite)
.await
.unwrap();
// Export should succeed
let exported = db.export_jsonl("main", &[], &[]).await.unwrap();
assert!(
exported.contains("readme"),
"export should contain readme doc"
);
// Verify blob value is in the export
assert!(
exported.contains("base64:") || exported.contains("SGVsbG8"),
"export should contain blob data as base64"
);
// Round-trip: re-import and verify blob data survives
let imported_dir = tempfile::tempdir().unwrap();
let imported_uri = imported_dir.path().to_str().unwrap();
let mut imported = Omnigraph::init(imported_uri, BLOB_SCHEMA).await.unwrap();
load_jsonl(&mut imported, &exported, LoadMode::Overwrite)
.await
.unwrap();
let blob = imported
.read_blob("Document", "readme", "content")
.await
.unwrap();
let bytes = blob.read().await.unwrap();
assert_eq!(&bytes[..], b"Hello");
}

View file

@ -0,0 +1,47 @@
#![cfg(feature = "failpoints")]
mod helpers;
use fail::FailScenario;
use omnigraph::db::Omnigraph;
use omnigraph::failpoints::ScopedFailPoint;
use helpers::{MUTATION_QUERIES, mixed_params};
#[tokio::test]
async fn branch_create_failpoint_triggers() {
let _scenario = FailScenario::setup();
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, helpers::TEST_SCHEMA).await.unwrap();
let _failpoint = ScopedFailPoint::new("branch_create.after_manifest_branch_create", "return");
let err = db.branch_create("feature").await.unwrap_err();
assert!(
err.to_string()
.contains("injected failpoint triggered: branch_create.after_manifest_branch_create")
);
}
#[tokio::test]
async fn graph_publish_failpoint_triggers_before_commit_append() {
let _scenario = FailScenario::setup();
let dir = tempfile::tempdir().unwrap();
let mut db = Omnigraph::init(dir.path().to_str().unwrap(), helpers::TEST_SCHEMA)
.await
.unwrap();
let _failpoint = ScopedFailPoint::new("graph_publish.before_commit_append", "return");
let err = mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap_err();
assert!(
err.to_string()
.contains("injected failpoint triggered: graph_publish.before_commit_append")
);
}

View file

@ -0,0 +1,13 @@
{"type": "Actor", "data": {"slug": "aaron", "name": "Aaron"}}
{"type": "Actor", "data": {"slug": "bruno", "name": "Bruno"}}
{"type": "Actor", "data": {"slug": "jorge", "name": "Jorge"}}
{"type": "Actor", "data": {"slug": "muneeb", "name": "Muneeb"}}
{"type": "Actor", "data": {"slug": "ragnor", "name": "Ragnor"}}
{"type": "Actor", "data": {"slug": "andrew", "name": "Andrew"}}
{"type": "Signal", "data": {"slug": "zylon-private-ai-platform", "title": "Zylon.ai positions as complete on-premise enterprise AI platform for regulated industries", "body": "Zylon.ai positions itself as a complete, fully private (100% on-premise) enterprise AI platform built explicitly for regulated industries, emphasizing air-gapped deployability, data sovereignty (no external cloud dependency), and predictable fixed-cost economics (no per-token pricing).", "category": "competitor", "strength": "strong", "observed_at": "2026-03-27", "source": "https://www.zylon.ai/"}}
{"type": "Decision", "data": {"slug": "create-360-ai-infra", "title": "Create 360 AI Infra offering", "body": "Build a comprehensive 360-degree AI infrastructure offering in response to competitors like Zylon positioning complete on-premise AI platforms for regulated industries.", "status": "proposed", "urgency": "high", "decided_at": "2026-03-27"}}
{"type": "Trace", "data": {"slug": "jorge-spots-zylon", "title": "Jorge spots Zylon.ai competitor signal", "body": "Jorge identified Zylon.ai as a new competitor positioning a fully private enterprise AI platform targeting regulated industries with air-gapped deployment and fixed-cost pricing.", "kind": "note", "recorded_at": "2026-03-27", "source": "https://www.zylon.ai/"}}
{"edge": "OwnedBy", "from": "create-360-ai-infra", "to": "andrew"}
{"edge": "RecordedBy", "from": "jorge-spots-zylon", "to": "jorge"}
{"edge": "Triggered", "from": "zylon-private-ai-platform", "to": "create-360-ai-infra"}
{"edge": "Supports", "from": "jorge-spots-zylon", "to": "create-360-ai-infra"}

View file

@ -0,0 +1,78 @@
// Context graph: decisions, the people behind them,
// the evidence trail, and market signals that inform them.
// ── Nodes ────────────────────────────────────────────
node Actor {
slug: String @key
name: String
email: String? @unique
}
node Decision {
slug: String @key
title: String @index
body: String?
status: enum(proposed, accepted, rejected, superseded)
urgency: enum(low, normal, high, critical)
decided_at: Date?
}
node Trace {
slug: String @key
title: String @index
body: String?
kind: enum(note, discussion, experiment, review, meeting, document)
recorded_at: Date
source: String?
}
node Signal {
slug: String @key
title: String @index
body: String?
category: enum(competitor, market, regulatory, technology, customer)
strength: enum(strong, moderate, weak)
observed_at: Date
source: String?
}
node Artifact {
slug: String @key
title: String @index
kind: enum(doc, presentation, proposal, spec, report, memo)
url: String?
created_at: Date
}
// ── Ownership / participation ────────────────────────
edge OwnedBy: Decision -> Actor @card(1..1)
edge ParticipatedIn: Actor -> Decision
edge RecordedBy: Trace -> Actor @card(1..1)
edge AuthoredBy: Artifact -> Actor @card(1..1)
// ── Evidence trail ───────────────────────────────────
edge Supports: Trace -> Decision
edge Attached: Artifact -> Decision
edge CitedIn: Artifact -> Trace
// ── Signal linkage ───────────────────────────────────
edge Triggered: Signal -> Decision
edge Correlates: Signal -> Signal {
@unique(src, dst)
}
// ── Decision lineage ─────────────────────────────────
edge Supersedes: Decision -> Decision {
@unique(src, dst)
}

View file

@ -0,0 +1,48 @@
# Enterprise Procurement Risk Memo
## Situation
The buyer entered procurement for annual renewal and asked for a bundled proposal that combines platform licensing, implementation support, and SLA uplift.
Legal requested two rounds of redlines and now requires explicit language for data residency, deletion timelines, and subprocessors.
Security asked for the full questionnaire, pen-test summary, SOC evidence package, and a named escalation owner for incident response coordination.
Finance requested price hold terms through quarter close and requires a clean net amount with no conditional side letters.
## Current Friction
The commercial owner reports that each team is operating on a different timeline.
Procurement prefers a single consolidated response packet, but legal and security are still updating separate drafts.
The buyer champion is supportive but cannot route final approval until the redline and security sections are complete.
Two approvers are out next week, which introduces a calendar risk for final signoff.
## Evidence From Recent Calls
- Buyer said the risk is not product fit, it is internal process load.
- Procurement requested one owner for all responses to avoid thread drift.
- Legal asked for an explicit breach notification interval in the MSA.
- Security flagged third-party dependency disclosure as incomplete.
- Finance asked for forecast certainty before they release PO authority.
## Operational Notes
1. The account team should treat this as a coordination problem, not a persuasion problem.
2. Every open item needs owner, due date, and blocking dependency.
3. Replies should be centralized in one tracker to avoid inconsistent statements.
4. Escalation should happen early when legal language depends on security attestations.
5. The champion should get a concise status summary after each workday.
## Risk Register
- **Timeline risk:** medium-high due to calendar compression and approver availability.
- **Compliance risk:** medium due to unresolved security questionnaire fields.
- **Commercial risk:** medium because procurement is requesting fixed pricing through quarter end.
- **Execution risk:** high if response ownership remains fragmented across teams.
## Recommended Plan
Create a single response packet and assign one coordinator.
Pre-fill all known legal and security answers from existing templates.
Schedule a thirty-minute cross-functional triage with legal, security, and sales operations.
Lock a daily cutoff time for updates and send one canonical status note to stakeholders.
Escalate unresolved blockers to leadership forty-eight hours before target sign date.
## Success Criteria
- Security questionnaire submitted with no unresolved critical fields.
- Redline package accepted or narrowed to non-blocking items.
- Pricing terms approved by finance for the requested window.
- Purchase order process initiated before the internal close date.
- Owner confirms all blocker tickets are either resolved or explicitly waived.

View file

@ -0,0 +1,44 @@
query text_search($q: String) {
match {
$d: Doc
search($d.title, $q)
}
return { $d.slug, $d.title }
}
query fuzzy_search($q: String) {
match {
$d: Doc
fuzzy($d.title, $q, 2)
}
return { $d.slug, $d.title }
}
query phrase_search($q: String) {
match {
$d: Doc
match_text($d.body, $q)
}
return { $d.slug, $d.title }
}
query vector_search($q: Vector(4)) {
match { $d: Doc }
return { $d.slug, $d.title }
order { nearest($d.embedding, $q) }
limit 3
}
query bm25_search($q: String) {
match { $d: Doc }
return { $d.slug, $d.title }
order { bm25($d.title, $q) }
limit 3
}
query hybrid_search($vq: Vector(4), $tq: String) {
match { $d: Doc }
return { $d.slug, $d.title }
order { rrf(nearest($d.embedding, $vq), bm25($d.title, $tq)) }
limit 3
}

View file

@ -0,0 +1,5 @@
{"type": "Doc", "data": {"slug": "ml-intro", "title": "Introduction to Machine Learning", "body": "Machine learning is a subset of artificial intelligence that focuses on algorithms", "embedding": [0.1, 0.2, 0.3, 0.4]}}
{"type": "Doc", "data": {"slug": "dl-basics", "title": "Deep Learning Basics", "body": "Deep learning uses neural networks with many layers to learn representations", "embedding": [0.5, 0.6, 0.7, 0.8]}}
{"type": "Doc", "data": {"slug": "nlp-guide", "title": "Natural Language Processing Guide", "body": "NLP applies machine learning to understand and generate human language", "embedding": [0.2, 0.3, 0.4, 0.5]}}
{"type": "Doc", "data": {"slug": "cv-overview", "title": "Computer Vision Overview", "body": "Computer vision enables machines to interpret visual information from images", "embedding": [0.8, 0.7, 0.6, 0.5]}}
{"type": "Doc", "data": {"slug": "rl-intro", "title": "Reinforcement Learning Introduction", "body": "Reinforcement learning trains agents through reward and punishment signals", "embedding": [0.3, 0.4, 0.5, 0.6]}}

View file

@ -0,0 +1,6 @@
node Doc {
slug: String @key
title: String @index
body: String @index
embedding: Vector(4)
}

View file

@ -0,0 +1,46 @@
{"type": "Company", "data": {"slug": "aws", "name": "AWS", "sector": "hyperscaler"}}
{"type": "Company", "data": {"slug": "cerebras", "name": "Cerebras", "sector": "chipmaker"}}
{"type": "Company", "data": {"slug": "vast", "name": "VAST Data", "sector": "startup"}}
{"type": "Company", "data": {"slug": "oracle", "name": "Oracle", "sector": "hyperscaler"}}
{"type": "Company", "data": {"slug": "benchmark", "name": "Benchmark", "sector": "investor"}}
{"type": "Company", "data": {"slug": "xai", "name": "xAI", "sector": "lab"}}
{"type": "Company", "data": {"slug": "openai", "name": "OpenAI", "sector": "lab"}}
{"type": "Company", "data": {"slug": "anthropic", "name": "Anthropic", "sector": "lab"}}
{"type": "Company", "data": {"slug": "nvidia", "name": "NVIDIA", "sector": "chipmaker"}}
{"type": "Tech", "data": {"slug": "cs3", "name": "CS-3", "kind": "infra", "tier": "growth"}}
{"type": "Tech", "data": {"slug": "trainium", "name": "Trainium", "kind": "infra", "tier": "growth"}}
{"type": "Tech", "data": {"slug": "grok", "name": "Grok", "kind": "model", "tier": "emerging"}}
{"type": "Tech", "data": {"slug": "vast-ai-os", "name": "VAST AI OS", "kind": "platform", "tier": "growth"}}
{"type": "Signal", "data": {"slug": "aws-cerebras-inference", "title": "AWS and Cerebras collaborate on disaggregated inference: Trainium for prefill, CS-3 for decode, 5x capacity", "source": "press-release", "strength": "strong", "observed": "2026-03-13"}}
{"type": "Signal", "data": {"slug": "vast-1b-raise", "title": "VAST Data raises $1B at $30B, unveils AI OS bundling storage, compute, and agent runtimes into one stack", "source": "funding-round", "strength": "strong", "observed": "2026-03-12"}}
{"type": "Signal", "data": {"slug": "oracle-cerebras-mention", "title": "Oracle names Cerebras alongside NVIDIA and AMD as enterprise AI chip option", "source": "earnings-call", "strength": "moderate", "observed": "2026-03-10"}}
{"type": "Signal", "data": {"slug": "cerebras-23b-round", "title": "Cerebras valued at $23B after $1B round; Benchmark raises $225M in special vehicles to double down", "source": "funding-round", "strength": "strong", "observed": "2026-02-04"}}
{"type": "Signal", "data": {"slug": "xai-field-engineers", "title": "xAI deploys engineers on-site at enterprise clients to win deals from OpenAI and Anthropic", "source": "reporting", "strength": "strong", "observed": "2026-03-20"}}
{"type": "Pattern", "data": {"slug": "rise-of-fde", "name": "Rise of FDE", "category": "expansion"}}
{"type": "Pattern", "data": {"slug": "alt-chip-breakout", "name": "Alt-chip breakout", "category": "adoption"}}
{"type": "Pattern", "data": {"slug": "stack-collapse", "name": "Stack collapse", "category": "convergence"}}
{"type": "Pattern", "data": {"slug": "inference-specialization", "name": "Inference specialization", "category": "convergence"}}
{"edge": "Builds", "from": "cerebras", "to": "cs3"}
{"edge": "Builds", "from": "aws", "to": "trainium"}
{"edge": "Builds", "from": "xai", "to": "grok"}
{"edge": "Builds", "from": "vast", "to": "vast-ai-os"}
{"edge": "FundedBy", "from": "cerebras", "to": "benchmark", "data": {"amount": "$225M"}}
{"edge": "PartnersWith", "from": "aws", "to": "cerebras"}
{"edge": "PartnersWith", "from": "cerebras", "to": "openai"}
{"edge": "Mentions", "from": "aws-cerebras-inference", "to": "cs3"}
{"edge": "Mentions", "from": "aws-cerebras-inference", "to": "trainium"}
{"edge": "Mentions", "from": "vast-1b-raise", "to": "vast-ai-os"}
{"edge": "Mentions", "from": "cerebras-23b-round", "to": "cs3"}
{"edge": "Mentions", "from": "xai-field-engineers", "to": "grok"}
{"edge": "Indicates", "from": "aws-cerebras-inference", "to": "alt-chip-breakout"}
{"edge": "Indicates", "from": "aws-cerebras-inference", "to": "inference-specialization"}
{"edge": "Indicates", "from": "vast-1b-raise", "to": "stack-collapse"}
{"edge": "Indicates", "from": "oracle-cerebras-mention", "to": "alt-chip-breakout"}
{"edge": "Indicates", "from": "cerebras-23b-round", "to": "alt-chip-breakout"}
{"edge": "Indicates", "from": "xai-field-engineers", "to": "rise-of-fde"}
{"edge": "Involves", "from": "rise-of-fde", "to": "grok"}
{"edge": "Involves", "from": "alt-chip-breakout", "to": "cs3"}
{"edge": "Involves", "from": "alt-chip-breakout", "to": "trainium"}
{"edge": "Involves", "from": "stack-collapse", "to": "vast-ai-os"}
{"edge": "Involves", "from": "inference-specialization", "to": "cs3"}
{"edge": "Involves", "from": "inference-specialization", "to": "trainium"}

View file

@ -0,0 +1,44 @@
// Industry signals around AI models and major players.
// Branch use case: main tracks confirmed signals, branches
// model speculative or emerging interpretations.
node Signal {
slug: String @key
title: String
source: String
strength: enum(strong, moderate, weak)
observed: Date?
}
node Pattern {
slug: String @key
name: String
category: enum(adoption, churn, expansion, contraction, convergence)
}
node Tech {
slug: String @key
name: String
kind: enum(model, platform, infra, framework, tool)
tier: enum(emerging, growth, mature, declining)
}
node Company {
slug: String @key
name: String
sector: enum(lab, hyperscaler, chipmaker, investor, startup)
}
edge Indicates: Signal -> Pattern
edge Mentions: Signal -> Tech
edge Involves: Pattern -> Tech
edge Builds: Company -> Tech
edge FundedBy: Company -> Company {
amount: String?
}
edge PartnersWith: Company -> Company

78
crates/omnigraph/tests/fixtures/test.gq vendored Normal file
View file

@ -0,0 +1,78 @@
// Basic: find person by name
query get_person($name: String) {
match {
$p: Person { name: $name }
}
return { $p.name, $p.age }
}
// Filter by age
query adults() {
match {
$p: Person
$p.age > 30
}
return { $p.name, $p.age }
order { $p.age desc }
}
// One hop traversal
query friends_of($name: String) {
match {
$p: Person { name: $name }
$p knows $f
}
return { $f.name, $f.age }
}
// Reverse traversal: who works at a company
query employees_of($company: String) {
match {
$c: Company { name: $company }
$p worksAt $c
}
return { $p.name }
}
// Two hop: friends of friends
query friends_of_friends($name: String) {
match {
$p: Person { name: $name }
$p knows $mid
$mid knows $fof
}
return { $fof.name }
}
// Negation: people who don't work anywhere
query unemployed() {
match {
$p: Person
not { $p worksAt $_ }
}
return { $p.name }
}
// Aggregation: friend count
query friend_counts() {
match {
$p: Person
$p knows $f
}
return {
$p.name
count($f) as friends
}
order { friends desc }
limit 20
}
// Order and limit
query top_by_age() {
match {
$p: Person
}
return { $p.name, $p.age }
order { $p.age desc }
limit 2
}

View file

@ -0,0 +1,11 @@
{"type": "Person", "data": {"name": "Alice", "age": 30}}
{"type": "Person", "data": {"name": "Bob", "age": 25}}
{"type": "Person", "data": {"name": "Charlie", "age": 35}}
{"type": "Person", "data": {"name": "Diana", "age": 28}}
{"type": "Company", "data": {"name": "Acme"}}
{"type": "Company", "data": {"name": "Globex"}}
{"edge": "Knows", "from": "Alice", "to": "Bob"}
{"edge": "Knows", "from": "Alice", "to": "Charlie"}
{"edge": "Knows", "from": "Bob", "to": "Diana"}
{"edge": "WorksAt", "from": "Alice", "to": "Acme"}
{"edge": "WorksAt", "from": "Bob", "to": "Globex"}

14
crates/omnigraph/tests/fixtures/test.pg vendored Normal file
View file

@ -0,0 +1,14 @@
node Person {
name: String @key
age: I32?
}
node Company {
name: String @key
}
edge Knows: Person -> Person {
since: Date?
}
edge WorksAt: Person -> Company

View file

@ -0,0 +1,256 @@
#![allow(dead_code)]
use arrow_array::{Array, RecordBatch, StringArray};
use futures::TryStreamExt;
use omnigraph::changes::{ChangeFilter, ChangeSet};
use omnigraph::db::{Omnigraph, ReadTarget, Snapshot, SnapshotId};
use omnigraph::error::Result;
use omnigraph::loader::{LoadMode, load_jsonl};
use omnigraph_compiler::ir::ParamMap;
use omnigraph_compiler::query::ast::Literal;
use omnigraph_compiler::result::{MutationResult, QueryResult};
pub const TEST_SCHEMA: &str = include_str!("../fixtures/test.pg");
pub const TEST_DATA: &str = include_str!("../fixtures/test.jsonl");
pub const TEST_QUERIES: &str = include_str!("../fixtures/test.gq");
pub const MUTATION_QUERIES: &str = r#"
query insert_person($name: String, $age: I32) {
insert Person { name: $name, age: $age }
}
query add_friend($from: String, $to: String) {
insert Knows { from: $from, to: $to }
}
query set_age($name: String, $age: I32) {
update Person set { age: $age } where name = $name
}
query remove_person($name: String) {
delete Person where name = $name
}
query remove_friendship($from: String) {
delete Knows where from = $from
}
"#;
/// Init a repo and load the standard test data.
pub async fn init_and_load(dir: &tempfile::TempDir) -> Omnigraph {
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
.await
.unwrap();
db
}
/// Read all rows from a sub-table by table_key.
pub async fn read_table(db: &Omnigraph, table_key: &str) -> Vec<RecordBatch> {
let snap = snapshot_main(db).await.unwrap();
let ds = snap.open(table_key).await.unwrap();
ds.scan()
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap()
}
/// Read all rows from a branch-local sub-table by table_key.
pub async fn read_table_branch(db: &Omnigraph, branch: &str, table_key: &str) -> Vec<RecordBatch> {
let snap = snapshot_branch(db, branch).await.unwrap();
let ds = snap.open(table_key).await.unwrap();
ds.scan()
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap()
}
/// Count rows in a sub-table.
pub async fn count_rows(db: &Omnigraph, table_key: &str) -> usize {
let snap = snapshot_main(db).await.unwrap();
let ds = snap.open(table_key).await.unwrap();
ds.count_rows(None).await.unwrap()
}
/// Count rows in a branch-local sub-table.
pub async fn count_rows_branch(db: &Omnigraph, branch: &str, table_key: &str) -> usize {
let snap = snapshot_branch(db, branch).await.unwrap();
let ds = snap.open(table_key).await.unwrap();
ds.count_rows(None).await.unwrap()
}
/// Collect all string values from a named column across batches.
pub fn collect_column_strings(batches: &[RecordBatch], col: &str) -> Vec<String> {
let mut out = Vec::new();
for batch in batches {
let arr = batch
.column_by_name(col)
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
for i in 0..arr.len() {
if !arr.is_null(i) {
out.push(arr.value(i).to_string());
}
}
}
out
}
pub async fn query_main(
db: &mut Omnigraph,
query_source: &str,
query_name: &str,
params: &ParamMap,
) -> Result<QueryResult> {
db.query(ReadTarget::branch("main"), query_source, query_name, params)
.await
}
pub async fn query_branch(
db: &mut Omnigraph,
branch: &str,
query_source: &str,
query_name: &str,
params: &ParamMap,
) -> Result<QueryResult> {
db.query(ReadTarget::branch(branch), query_source, query_name, params)
.await
}
pub async fn mutate_main(
db: &mut Omnigraph,
query_source: &str,
query_name: &str,
params: &ParamMap,
) -> Result<MutationResult> {
db.mutate("main", query_source, query_name, params).await
}
pub async fn mutate_branch(
db: &mut Omnigraph,
branch: &str,
query_source: &str,
query_name: &str,
params: &ParamMap,
) -> Result<MutationResult> {
db.mutate(branch, query_source, query_name, params).await
}
pub async fn snapshot_main(db: &Omnigraph) -> Result<Snapshot> {
db.snapshot_of(ReadTarget::branch("main")).await
}
pub async fn snapshot_branch(db: &Omnigraph, branch: &str) -> Result<Snapshot> {
db.snapshot_of(ReadTarget::branch(branch)).await
}
pub async fn version_main(db: &Omnigraph) -> Result<u64> {
db.version_of(ReadTarget::branch("main")).await
}
pub async fn version_branch(db: &Omnigraph, branch: &str) -> Result<u64> {
db.version_of(ReadTarget::branch(branch)).await
}
pub async fn sync_main(db: &mut Omnigraph) -> Result<()> {
db.sync_branch("main").await
}
pub async fn sync_named_branch(db: &mut Omnigraph, branch: &str) -> Result<()> {
db.sync_branch(branch).await
}
pub async fn snapshot_id(db: &Omnigraph, branch: &str) -> Result<SnapshotId> {
db.resolve_snapshot(branch).await
}
pub async fn diff_since_branch(
db: &Omnigraph,
branch: &str,
from_snapshot: SnapshotId,
filter: &ChangeFilter,
) -> Result<ChangeSet> {
db.diff_between(
ReadTarget::Snapshot(from_snapshot),
ReadTarget::branch(branch),
filter,
)
.await
}
/// Build a ParamMap from string key-value pairs.
pub fn params(pairs: &[(&str, &str)]) -> ParamMap {
pairs
.iter()
.map(|(k, v)| {
let key = k.strip_prefix('$').unwrap_or(k);
(key.to_string(), Literal::String(v.to_string()))
})
.collect()
}
/// Build a ParamMap from integer key-value pairs.
pub fn int_params(pairs: &[(&str, i64)]) -> ParamMap {
pairs
.iter()
.map(|(k, v)| {
let key = k.strip_prefix('$').unwrap_or(k);
(key.to_string(), Literal::Integer(*v))
})
.collect()
}
/// Build a ParamMap from mixed string + integer pairs.
pub fn mixed_params(str_pairs: &[(&str, &str)], int_pairs: &[(&str, i64)]) -> ParamMap {
let mut map = params(str_pairs);
for (k, v) in int_pairs {
let key = k.strip_prefix('$').unwrap_or(k);
map.insert(key.to_string(), Literal::Integer(*v));
}
map
}
/// Build a ParamMap with a single vector parameter.
pub fn vector_param(name: &str, values: &[f32]) -> ParamMap {
let key = name.strip_prefix('$').unwrap_or(name).to_string();
let lit = Literal::List(values.iter().map(|v| Literal::Float(*v as f64)).collect());
let mut map = ParamMap::new();
map.insert(key, lit);
map
}
/// Build a ParamMap with a vector param and a string param.
pub fn vector_and_string_params(
vec_name: &str,
vec_values: &[f32],
str_name: &str,
str_value: &str,
) -> ParamMap {
let mut map = vector_param(vec_name, vec_values);
let key = str_name.strip_prefix('$').unwrap_or(str_name).to_string();
map.insert(key, Literal::String(str_value.to_string()));
map
}
pub fn s3_test_repo_uri(suite: &str) -> Option<String> {
let bucket = std::env::var("OMNIGRAPH_S3_TEST_BUCKET").ok()?;
let prefix = std::env::var("OMNIGRAPH_S3_TEST_PREFIX")
.ok()
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| "omnigraph-itests".to_string());
let unique = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()?
.as_nanos();
Some(format!("s3://{}/{}/{}/{}", bucket, prefix, suite, unique))
}

View file

@ -0,0 +1,268 @@
/// Investigation test: understand how Lance stamps `_row_created_at_version` and
/// `_row_last_updated_at_version` for different write modes (append, merge_insert new,
/// merge_insert update).
use std::sync::Arc;
use arrow_array::{Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::{WriteMode, WriteParams};
use lance_file::version::LanceFileVersion;
async fn create_test_dataset(uri: &str) -> Dataset {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("value", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["alice", "bob"])),
Arc::new(Int32Array::from(vec![1, 2])),
],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let params = WriteParams {
mode: WriteMode::Create,
enable_stable_row_ids: true,
data_storage_version: Some(LanceFileVersion::V2_2),
..Default::default()
};
Dataset::write(reader, uri, Some(params)).await.unwrap()
}
fn read_version_columns(batches: &[RecordBatch]) -> Vec<(String, i32, u64, u64)> {
let mut rows = Vec::new();
for batch in batches {
let ids = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let vals = batch
.column_by_name("value")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
let created = batch
.column_by_name("_row_created_at_version")
.unwrap()
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
let updated = batch
.column_by_name("_row_last_updated_at_version")
.unwrap()
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
for i in 0..ids.len() {
rows.push((
ids.value(i).to_string(),
vals.value(i),
created.value(i),
updated.value(i),
));
}
}
rows.sort_by(|a, b| a.0.cmp(&b.0));
rows
}
async fn scan_with_versions(ds: &Dataset) -> Vec<(String, i32, u64, u64)> {
let mut scanner = ds.scan();
scanner
.project(&[
"id",
"value",
"_row_created_at_version",
"_row_last_updated_at_version",
])
.unwrap();
let batches: Vec<RecordBatch> = scanner
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
read_version_columns(&batches)
}
#[tokio::test]
async fn lance_append_stamps_created_at_version_correctly() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().join("test.lance");
let uri_str = uri.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("value", DataType::Int32, false),
]));
let ds = create_test_dataset(uri_str).await;
let v1 = ds.version().version;
// Append a new row
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["charlie"])),
Arc::new(Int32Array::from(vec![3])),
],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let mut ds = ds;
ds.append(reader, None).await.unwrap();
let v2 = ds.version().version;
let rows = scan_with_versions(&ds).await;
eprintln!("After append (v1={}, v2={}):", v1, v2);
for (id, val, created, updated) in &rows {
eprintln!(
" id={:<10} val={:<4} created_v={:<4} updated_v={}",
id, val, created, updated
);
}
// Alice and Bob: created at v1
let alice = rows.iter().find(|r| r.0 == "alice").unwrap();
assert_eq!(alice.2, v1, "alice created_at should be v1");
// Charlie: created at v2 (the append version)
let charlie = rows.iter().find(|r| r.0 == "charlie").unwrap();
assert_eq!(
charlie.2, v2,
"charlie created_at should be v2 (append version)"
);
}
#[tokio::test]
async fn lance_merge_insert_new_row_stamps_created_at_version() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().join("test.lance");
let uri_str = uri.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("value", DataType::Int32, false),
]));
let ds = create_test_dataset(uri_str).await;
let v1 = ds.version().version;
// merge_insert a NEW row (eve doesn't exist)
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["eve"])),
Arc::new(Int32Array::from(vec![4])),
],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let ds_arc = Arc::new(ds);
let job = lance::dataset::MergeInsertBuilder::try_new(ds_arc, vec!["id".to_string()])
.unwrap()
.when_matched(lance::dataset::WhenMatched::UpdateAll)
.when_not_matched(lance::dataset::WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let (new_ds, _) = job
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
.await
.unwrap();
let v2 = new_ds.version().version;
let rows = scan_with_versions(&new_ds).await;
eprintln!("After merge_insert NEW eve (v1={}, v2={}):", v1, v2);
for (id, val, created, updated) in &rows {
eprintln!(
" id={:<10} val={:<4} created_v={:<4} updated_v={}",
id, val, created, updated
);
}
let eve = rows.iter().find(|r| r.0 == "eve").unwrap();
eprintln!("Eve: created_at_version={}, v1={}, v2={}", eve.2, v1, v2);
// Lance behavior (as of 3.0.1): merge_insert stamps new rows with
// _row_created_at_version = dataset_creation_version (v1), NOT the
// merge_insert commit version (v2). This is why Omnigraph's change
// detection uses _row_last_updated_at_version + ID set membership
// to classify inserts vs updates, not _row_created_at_version alone.
assert_eq!(
eve.2, v1,
"Lance merge_insert stamps new rows with created_at = dataset creation version, not commit version"
);
assert_eq!(
eve.3, v2,
"Lance merge_insert stamps new rows with last_updated_at = commit version"
);
}
#[tokio::test]
async fn lance_merge_insert_update_preserves_created_at_version() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().join("test.lance");
let uri_str = uri.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("value", DataType::Int32, false),
]));
let ds = create_test_dataset(uri_str).await;
let v1 = ds.version().version;
// merge_insert an EXISTING row (update bob's value)
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["bob"])),
Arc::new(Int32Array::from(vec![99])),
],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let ds_arc = Arc::new(ds);
let job = lance::dataset::MergeInsertBuilder::try_new(ds_arc, vec!["id".to_string()])
.unwrap()
.when_matched(lance::dataset::WhenMatched::UpdateAll)
.when_not_matched(lance::dataset::WhenNotMatched::InsertAll)
.try_build()
.unwrap();
let (new_ds, _) = job
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
.await
.unwrap();
let v2 = new_ds.version().version;
let rows = scan_with_versions(&new_ds).await;
eprintln!("After merge_insert UPDATE bob (v1={}, v2={}):", v1, v2);
for (id, val, created, updated) in &rows {
eprintln!(
" id={:<10} val={:<4} created_v={:<4} updated_v={}",
id, val, created, updated
);
}
let alice = rows.iter().find(|r| r.0 == "alice").unwrap();
let bob = rows.iter().find(|r| r.0 == "bob").unwrap();
// Alice: untouched, should keep original versions
assert_eq!(alice.2, v1, "alice created_at should still be v1");
assert_eq!(alice.3, v1, "alice updated_at should still be v1");
// Bob: updated via merge_insert
// created_at should be preserved (v1), updated_at should be bumped (v2)
eprintln!(
"Bob: created_at={}, updated_at={}, v1={}, v2={}",
bob.2, bob.3, v1, v2
);
assert_eq!(bob.1, 99, "bob's value should be updated to 99");
}

View file

@ -0,0 +1,736 @@
mod helpers;
use arrow_array::{Array, Int32Array};
use helpers::*;
use omnigraph::db::Omnigraph;
use omnigraph_compiler::ir::ParamMap;
// ─── Inline queries for point-in-time tests ─────────────────────────────────
const ALL_PERSONS_QUERY: &str = r#"
query all_persons() {
match {
$p: Person
}
return { $p.name, $p.age }
order { $p.name asc }
}
"#;
const FRIENDS_QUERY: &str = r#"
query friends_of($name: String) {
match {
$p: Person { name: $name }
$p knows $f
}
return { $f.name }
order { $f.name asc }
}
"#;
const UNEMPLOYED_QUERY: &str = r#"
query unemployed() {
match {
$p: Person
not { $p worksAt $_ }
}
return { $p.name }
order { $p.name asc }
}
"#;
const FILTERED_QUERY: &str = r#"
query older_than($min_age: I32) {
match {
$p: Person
$p.age > $min_age
}
return { $p.name, $p.age }
order { $p.name asc }
}
"#;
const GET_PERSON_QUERY: &str = r#"
query get_person($name: String) {
match {
$p: Person { name: $name }
}
return { $p.name, $p.age }
}
"#;
// ─── Morphological matrix ───────────────────────────────────────────────────
//
// Dimensions:
// Query type: Tabular | Traversal | Negation (AntiJoin) | Filtered | Aggregation
// Mutation: Insert | Update | Delete node | Delete edge
// Branch: Main | Named branch
// Result shape: Empty→non-empty | Non-empty→empty | Count changes | Value changes
//
// Existing coverage (4 tests):
// Tabular × Insert × Main (returns_historical_data)
// Traversal × Insert × Main (traversal_uses_historical_graph_index)
// Tabular × Update × Main (multiple_versions_sees_correct_state)
// Error case (snapshot_at_version_fails_for_nonexistent_version)
//
// New coverage (9 tests below):
// Tabular × Delete node × Main → non-empty becomes smaller
// Traversal × Delete edge × Main → edge disappears from historical
// Negation × Insert edge × Main → anti-join result shrinks after insert
// Negation × Delete edge × Main → anti-join result grows after delete
// Filtered × Update × Main → entity enters/exits filter after age change
// Multi-hop × Insert(n+e) × Main → friends-of-friends grows after new path
// Traversal × Delete node × Main → cascade removes edges from traversal
// Tabular × Insert × Branch → branch isolation for point-in-time
// Tabular × Multi-step × Main → 4-version chain: insert, update, delete
// ─── Original tests ────────────────────────────────────────────────────────
#[tokio::test]
async fn run_query_at_returns_historical_data() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = version_main(&db).await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let historical = db
.run_query_at(v_before, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
assert_eq!(historical.num_rows(), 4, "historical should have 4 persons");
assert_eq!(current.num_rows(), 5, "current should have 5 persons");
let historical_names = collect_column_strings(historical.batches(), "p.name");
assert!(!historical_names.contains(&"Eve".to_string()));
let current_names = collect_column_strings(current.batches(), "p.name");
assert!(current_names.contains(&"Eve".to_string()));
}
#[tokio::test]
async fn run_query_at_traversal_uses_historical_graph_index() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v_before = version_main(&db).await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Eve"), ("$to", "Alice")]),
)
.await
.unwrap();
let historical = db
.run_query_at(
v_before,
FRIENDS_QUERY,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let current = query_main(
&mut db,
FRIENDS_QUERY,
"friends_of",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(historical.num_rows(), 2);
let hist_names = collect_column_strings(historical.batches(), "f.name");
assert!(hist_names.contains(&"Bob".to_string()));
assert!(hist_names.contains(&"Charlie".to_string()));
assert_eq!(current.num_rows(), 1);
let cur_names = collect_column_strings(current.batches(), "f.name");
assert!(cur_names.contains(&"Alice".to_string()));
}
#[tokio::test]
async fn snapshot_at_version_fails_for_nonexistent_version() {
let dir = tempfile::tempdir().unwrap();
let db = init_and_load(&dir).await;
let result = db.snapshot_at_version(99999).await;
assert!(result.is_err(), "non-existent version should return error");
}
#[tokio::test]
async fn run_query_at_multiple_versions_sees_correct_state() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let v1 = version_main(&db).await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Alice")], &[("$age", 99)]),
)
.await
.unwrap();
let v2 = version_main(&db).await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Frank")], &[("$age", 40)]),
)
.await
.unwrap();
let at_v1 = db
.run_query_at(v1, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
assert_eq!(at_v1.num_rows(), 4, "v1 should have 4 persons");
let v1_names = collect_column_strings(at_v1.batches(), "p.name");
assert!(!v1_names.contains(&"Frank".to_string()));
let at_v2 = db
.run_query_at(v2, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
assert_eq!(at_v2.num_rows(), 4, "v2 should have 4 persons");
let v2_names = collect_column_strings(at_v2.batches(), "p.name");
assert!(!v2_names.contains(&"Frank".to_string()));
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
assert_eq!(current.num_rows(), 5, "current should have 5 persons");
let cur_names = collect_column_strings(current.batches(), "p.name");
assert!(cur_names.contains(&"Frank".to_string()));
}
// ─── Tabular × Delete node ─────────────────────────────────────────────────
#[tokio::test]
async fn tabular_delete_node_invisible_at_historical_version() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice, Bob, Charlie, Diana
let v_before = version_main(&db).await.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"remove_person",
&params(&[("$name", "Charlie")]),
)
.await
.unwrap();
// Historical: Charlie still exists
let historical = db
.run_query_at(v_before, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
let hist_names = collect_column_strings(historical.batches(), "p.name");
assert_eq!(historical.num_rows(), 4);
assert!(hist_names.contains(&"Charlie".to_string()));
// Current: Charlie is gone
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
let cur_names = collect_column_strings(current.batches(), "p.name");
assert_eq!(current.num_rows(), 3);
assert!(!cur_names.contains(&"Charlie".to_string()));
}
// ─── Traversal × Delete edge ───────────────────────────────────────────────
#[tokio::test]
async fn traversal_delete_edge_invisible_at_historical_version() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice knows Bob, Alice knows Charlie
let v_before = version_main(&db).await.unwrap();
// Remove all Knows edges FROM Alice
mutate_main(
&mut db,
MUTATION_QUERIES,
"remove_friendship",
&params(&[("$from", "Alice")]),
)
.await
.unwrap();
// Historical traversal: Alice's friends at v_before = Bob, Charlie
let historical = db
.run_query_at(
v_before,
FRIENDS_QUERY,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(historical.num_rows(), 2);
let hist_names = collect_column_strings(historical.batches(), "f.name");
assert!(hist_names.contains(&"Bob".to_string()));
assert!(hist_names.contains(&"Charlie".to_string()));
// Current: Alice has no friends (edges deleted)
let current = query_main(
&mut db,
FRIENDS_QUERY,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(
current.num_rows(),
0,
"Alice should have no friends after edge deletion"
);
}
// ─── Negation (AntiJoin) × Insert ──────────────────────────────────────────
#[tokio::test]
async fn negation_insert_shrinks_antijoin_result() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice worksAt Acme, Bob worksAt Globex
// Unemployed: Charlie, Diana
let v_before = version_main(&db).await.unwrap();
// Give Charlie a job
mutate_main(
&mut db,
r#"
query hire($from: String, $to: String) {
insert WorksAt { from: $from, to: $to }
}
"#,
"hire",
&params(&[("$from", "Charlie"), ("$to", "Acme")]),
)
.await
.unwrap();
// Historical: Charlie and Diana were unemployed
let historical = db
.run_query_at(v_before, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
.await
.unwrap();
let hist_names = collect_column_strings(historical.batches(), "p.name");
assert_eq!(historical.num_rows(), 2);
assert!(hist_names.contains(&"Charlie".to_string()));
assert!(hist_names.contains(&"Diana".to_string()));
// Current: only Diana is unemployed
let current = query_main(&mut db, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
.await
.unwrap();
let cur_names = collect_column_strings(current.batches(), "p.name");
assert_eq!(current.num_rows(), 1);
assert!(cur_names.contains(&"Diana".to_string()));
assert!(!cur_names.contains(&"Charlie".to_string()));
}
// ─── Negation (AntiJoin) × Delete edge ─────────────────────────────────────
#[tokio::test]
async fn negation_delete_edge_grows_antijoin_result() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice worksAt Acme, Bob worksAt Globex
// Unemployed at start: Charlie, Diana
let v_before = version_main(&db).await.unwrap();
// Fire Alice (delete WorksAt edge)
mutate_main(
&mut db,
r#"
query fire($from: String) {
delete WorksAt where from = $from
}
"#,
"fire",
&params(&[("$from", "Alice")]),
)
.await
.unwrap();
// Historical: 2 unemployed (Charlie, Diana)
let historical = db
.run_query_at(v_before, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
.await
.unwrap();
assert_eq!(historical.num_rows(), 2);
let hist_names = collect_column_strings(historical.batches(), "p.name");
assert!(!hist_names.contains(&"Alice".to_string()));
// Current: 3 unemployed (Alice, Charlie, Diana)
let current = query_main(&mut db, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
.await
.unwrap();
assert_eq!(current.num_rows(), 3);
let cur_names = collect_column_strings(current.batches(), "p.name");
assert!(cur_names.contains(&"Alice".to_string()));
}
// ─── Filtered × Update (value enters/exits filter) ─────────────────────────
#[tokio::test]
async fn filtered_update_entity_crosses_filter_boundary() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice(30), Bob(25), Charlie(35), Diana(28)
// older_than(30): Charlie(35) only
let v_before = version_main(&db).await.unwrap();
// Update Bob's age from 25 to 40 → enters the filter
mutate_main(
&mut db,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Bob")], &[("$age", 40)]),
)
.await
.unwrap();
// Historical: only Charlie is older than 30
let historical = db
.run_query_at(
v_before,
FILTERED_QUERY,
"older_than",
&int_params(&[("$min_age", 30)]),
)
.await
.unwrap();
assert_eq!(historical.num_rows(), 1);
let hist_names = collect_column_strings(historical.batches(), "p.name");
assert_eq!(hist_names, vec!["Charlie"]);
// Current: Bob(40) and Charlie(35) are older than 30
let current = query_main(
&mut db,
FILTERED_QUERY,
"older_than",
&int_params(&[("$min_age", 30)]),
)
.await
.unwrap();
assert_eq!(current.num_rows(), 2);
let cur_names = collect_column_strings(current.batches(), "p.name");
assert!(cur_names.contains(&"Bob".to_string()));
assert!(cur_names.contains(&"Charlie".to_string()));
}
// ─── Multi-hop traversal × Insert ──────────────────────────────────────────
#[tokio::test]
async fn multi_hop_traversal_historical_version() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice→Bob, Alice→Charlie, Bob→Diana
// friends_of_friends(Alice) = Diana (Alice→Bob→Diana)
let v_before = version_main(&db).await.unwrap();
// Insert Eve and edge: Charlie→Eve
// Now friends_of_friends(Alice) = Diana + Eve (Alice→Charlie→Eve)
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
mutate_main(
&mut db,
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Charlie"), ("$to", "Eve")]),
)
.await
.unwrap();
let fof_query = r#"
query fof($name: String) {
match {
$p: Person { name: $name }
$p knows $mid
$mid knows $f
}
return { $f.name }
order { $f.name asc }
}
"#;
// Historical: friends-of-friends of Alice = Diana only
let historical = db
.run_query_at(v_before, fof_query, "fof", &params(&[("$name", "Alice")]))
.await
.unwrap();
assert_eq!(historical.num_rows(), 1);
let hist_names = collect_column_strings(historical.batches(), "f.name");
assert_eq!(hist_names, vec!["Diana"]);
// Current: friends-of-friends of Alice = Diana + Eve
let current = query_main(&mut db, fof_query, "fof", &params(&[("$name", "Alice")]))
.await
.unwrap();
assert_eq!(current.num_rows(), 2);
let cur_names = collect_column_strings(current.batches(), "f.name");
assert!(cur_names.contains(&"Diana".to_string()));
assert!(cur_names.contains(&"Eve".to_string()));
}
// ─── Traversal × Delete node (cascade removes edges) ───────────────────────
#[tokio::test]
async fn traversal_delete_node_cascade_removes_edges() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// Fixture: Alice knows Bob, Alice knows Charlie, Bob knows Diana
let v_before = version_main(&db).await.unwrap();
// Delete Bob → cascades to Knows edges involving Bob
mutate_main(
&mut db,
MUTATION_QUERIES,
"remove_person",
&params(&[("$name", "Bob")]),
)
.await
.unwrap();
// Historical: Alice's friends = Bob, Charlie
let historical = db
.run_query_at(
v_before,
FRIENDS_QUERY,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(historical.num_rows(), 2);
let hist_names = collect_column_strings(historical.batches(), "f.name");
assert!(hist_names.contains(&"Bob".to_string()));
assert!(hist_names.contains(&"Charlie".to_string()));
// Current: Alice's friends = Charlie only (Bob was deleted, edge cascaded)
let current = query_main(
&mut db,
FRIENDS_QUERY,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(current.num_rows(), 1);
let cur_names = collect_column_strings(current.batches(), "f.name");
assert_eq!(cur_names, vec!["Charlie"]);
}
// ─── Branch isolation for point-in-time ────────────────────────────────────
#[tokio::test]
async fn branch_point_in_time_isolated_from_main() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut main = init_and_load(&dir).await;
main.branch_create("feature").await.unwrap();
let v_main_before = version_main(&main).await.unwrap();
// Insert Eve on main
mutate_main(
&mut main,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
// Insert Frank on feature branch
let mut feature = Omnigraph::open(uri).await.unwrap();
mutate_branch(
&mut feature,
"feature",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Frank")], &[("$age", 33)]),
)
.await
.unwrap();
// Historical main at v_main_before: 4 persons, no Eve, no Frank
let hist_main = main
.run_query_at(
v_main_before,
ALL_PERSONS_QUERY,
"all_persons",
&ParamMap::new(),
)
.await
.unwrap();
assert_eq!(hist_main.num_rows(), 4);
let hist_names = collect_column_strings(hist_main.batches(), "p.name");
assert!(!hist_names.contains(&"Eve".to_string()));
assert!(!hist_names.contains(&"Frank".to_string()));
// Current main: 5 persons (Eve present, Frank not visible on main)
let cur_main = query_main(
&mut main,
ALL_PERSONS_QUERY,
"all_persons",
&ParamMap::new(),
)
.await
.unwrap();
assert_eq!(cur_main.num_rows(), 5);
let cur_names = collect_column_strings(cur_main.batches(), "p.name");
assert!(cur_names.contains(&"Eve".to_string()));
assert!(!cur_names.contains(&"Frank".to_string()));
// Feature branch: 5 persons (Frank present, Eve not visible on feature)
let cur_feature = query_branch(
&mut feature,
"feature",
ALL_PERSONS_QUERY,
"all_persons",
&ParamMap::new(),
)
.await
.unwrap();
assert_eq!(cur_feature.num_rows(), 5);
let feat_names = collect_column_strings(cur_feature.batches(), "p.name");
assert!(feat_names.contains(&"Frank".to_string()));
assert!(!feat_names.contains(&"Eve".to_string()));
}
// ─── Multi-step version chain: insert → update → delete ────────────────────
#[tokio::test]
async fn four_version_chain_insert_update_delete() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
// v1: baseline (Alice=30, Bob=25, Charlie=35, Diana=28)
let v1 = version_main(&db).await.unwrap();
// v2: insert Eve(22)
mutate_main(
&mut db,
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let v2 = version_main(&db).await.unwrap();
// v3: update Eve's age to 50
mutate_main(
&mut db,
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Eve")], &[("$age", 50)]),
)
.await
.unwrap();
let v3 = version_main(&db).await.unwrap();
// v4: delete Eve
mutate_main(
&mut db,
MUTATION_QUERIES,
"remove_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
// v1: no Eve, 4 persons
let at_v1 = db
.run_query_at(v1, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
assert_eq!(at_v1.num_rows(), 4);
let v1_names = collect_column_strings(at_v1.batches(), "p.name");
assert!(!v1_names.contains(&"Eve".to_string()));
// v2: Eve exists with age 22, 5 persons
let at_v2 = db
.run_query_at(
v2,
GET_PERSON_QUERY,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(at_v2.num_rows(), 1);
let v2_batch = at_v2.concat_batches().unwrap();
let v2_ages = v2_batch
.column_by_name("p.age")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_eq!(v2_ages.value(0), 22);
// v3: Eve exists with age 50
let at_v3 = db
.run_query_at(
v3,
GET_PERSON_QUERY,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(at_v3.num_rows(), 1);
let v3_batch = at_v3.concat_batches().unwrap();
let v3_ages = v3_batch
.column_by_name("p.age")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_eq!(v3_ages.value(0), 50);
// v4 (current): Eve is gone, back to 4
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
.await
.unwrap();
assert_eq!(current.num_rows(), 4);
let cur_names = collect_column_strings(current.batches(), "p.name");
assert!(!cur_names.contains(&"Eve".to_string()));
}

View file

@ -0,0 +1,533 @@
mod helpers;
use std::collections::HashMap;
use arrow_array::{Array, RecordBatch, StringArray, TimestampMicrosecondArray};
use futures::TryStreamExt;
use lance::Dataset;
use omnigraph::db::commit_graph::CommitGraph;
use omnigraph::db::{Omnigraph, ReadTarget, RunStatus};
use omnigraph::error::OmniError;
use omnigraph::loader::{LoadMode, load_jsonl};
use helpers::*;
#[derive(Debug, Clone)]
struct PersistedRun {
run_id: String,
target_branch: String,
run_branch: String,
status: String,
updated_at: i64,
}
async fn latest_runs(uri: &str) -> Vec<PersistedRun> {
let runs_uri = format!("{}/_graph_runs.lance", uri);
let ds = Dataset::open(&runs_uri).await.unwrap();
let batches: Vec<RecordBatch> = ds
.scan()
.try_into_stream()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let mut latest: HashMap<String, PersistedRun> = HashMap::new();
for batch in batches {
let run_ids = batch
.column_by_name("run_id")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let target_branches = batch
.column_by_name("target_branch")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let run_branches = batch
.column_by_name("run_branch")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let statuses = batch
.column_by_name("status")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let updated_ats = batch
.column_by_name("updated_at")
.unwrap()
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.unwrap();
for row in 0..batch.num_rows() {
let record = PersistedRun {
run_id: run_ids.value(row).to_string(),
target_branch: target_branches.value(row).to_string(),
run_branch: run_branches.value(row).to_string(),
status: statuses.value(row).to_string(),
updated_at: updated_ats.value(row),
};
match latest.get(record.run_id.as_str()) {
Some(existing) if existing.updated_at >= record.updated_at => {}
_ => {
latest.insert(record.run_id.clone(), record);
}
}
}
}
let mut records = latest.into_values().collect::<Vec<_>>();
records.sort_by(|a, b| a.run_id.cmp(&b.run_id));
records
}
#[tokio::test]
async fn begin_run_creates_hidden_internal_branch_and_isolates_writes() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let base_snapshot = db.resolve_snapshot("main").await.unwrap();
let run = db.begin_run("main", Some("test-load")).await.unwrap();
assert!(run.run_branch.starts_with("__run__"));
assert_eq!(run.target_branch, "main");
assert_eq!(run.base_snapshot_id, base_snapshot.as_str());
assert_eq!(run.status, RunStatus::Running);
assert_eq!(db.branch_list().await.unwrap(), vec!["main"]);
db.load(
&run.run_branch,
r#"{"type":"Person","data":{"name":"Eve","age":22}}"#,
LoadMode::Append,
)
.await
.unwrap();
let main_qr = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(main_qr.num_rows(), 0);
let run_qr = db
.query(
ReadTarget::branch(run.run_branch.as_str()),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(run_qr.num_rows(), 1);
}
#[tokio::test]
async fn publish_run_merges_internal_branch_into_target_and_marks_record() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let run = db.begin_run("main", Some("publish-test")).await.unwrap();
db.load(
&run.run_branch,
r#"{"type":"Person","data":{"name":"Eve","age":22}}"#,
LoadMode::Append,
)
.await
.unwrap();
let published_snapshot = db.publish_run(&run.run_id).await.unwrap();
let record = db.get_run(&run.run_id).await.unwrap();
assert_eq!(record.status, RunStatus::Published);
assert_eq!(
record.published_snapshot_id.as_deref(),
Some(published_snapshot.as_str())
);
let main_qr = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(main_qr.num_rows(), 1);
}
#[tokio::test]
async fn abort_run_keeps_target_unchanged_and_preserves_hidden_branch_for_inspection() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let run = db.begin_run("main", Some("abort-test")).await.unwrap();
db.load(
&run.run_branch,
r#"{"type":"Person","data":{"name":"Eve","age":22}}"#,
LoadMode::Append,
)
.await
.unwrap();
let aborted = db.abort_run(&run.run_id).await.unwrap();
assert_eq!(aborted.status, RunStatus::Aborted);
let main_qr = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(main_qr.num_rows(), 0);
let run_qr = db
.query(
ReadTarget::branch(run.run_branch.as_str()),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(run_qr.num_rows(), 1);
}
#[tokio::test]
async fn public_branch_apis_reject_internal_run_refs() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let run = db.begin_run("main", Some("guard-test")).await.unwrap();
let merge_err = db.branch_merge(&run.run_branch, "main").await.unwrap_err();
match merge_err {
OmniError::Manifest(message) => assert!(message.message.contains("internal run refs")),
other => panic!("unexpected error: {}", other),
}
let create_err = db.branch_create(&run.run_branch).await.unwrap_err();
match create_err {
OmniError::Manifest(message) => assert!(message.message.contains("internal run ref")),
other => panic!("unexpected error: {}", other),
}
let delete_err = db.branch_delete(&run.run_branch).await.unwrap_err();
match delete_err {
OmniError::Manifest(message) => assert!(message.message.contains("internal run ref")),
other => panic!("unexpected error: {}", other),
}
let fork_err = db
.branch_create_from(ReadTarget::branch(run.run_branch.as_str()), "child")
.await
.unwrap_err();
match fork_err {
OmniError::Manifest(message) => assert!(message.message.contains("internal run ref")),
other => panic!("unexpected error: {}", other),
}
}
#[tokio::test]
async fn branch_delete_rejects_target_branches_with_active_runs() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
db.branch_create("feature").await.unwrap();
let run = db.begin_run("feature", Some("delete-guard")).await.unwrap();
let err = db.branch_delete("feature").await.unwrap_err();
assert!(err.to_string().contains(run.run_id.as_str()));
assert!(err.to_string().contains("targeting it is running"));
}
#[tokio::test]
async fn public_load_uses_hidden_transactional_run_and_publishes_it() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
let result = load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
.await
.unwrap();
assert_eq!(result.nodes_loaded.len(), 2);
assert_eq!(result.edges_loaded.len(), 2);
assert_eq!(db.branch_list().await.unwrap(), vec!["main"]);
let runs = latest_runs(uri).await;
assert_eq!(runs.len(), 1);
assert_eq!(runs[0].target_branch, "main");
assert_eq!(runs[0].status, "published");
assert!(runs[0].run_branch.starts_with("__run__"));
let qr = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(qr.num_rows(), 1);
}
#[tokio::test]
async fn public_load_preserves_staged_edge_ids_on_publish() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
.await
.unwrap();
let runs = latest_runs(uri).await;
let run_branch = runs[0].run_branch.clone();
let mut main_ids = collect_column_strings(&read_table(&db, "edge:Knows").await, "id");
let mut run_ids = collect_column_strings(
&read_table_branch(&db, run_branch.as_str(), "edge:Knows").await,
"id",
);
main_ids.sort();
run_ids.sort();
assert_eq!(main_ids, run_ids);
}
#[tokio::test]
async fn failed_public_load_marks_run_failed_and_leaves_target_unchanged() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
let bad = r#"{"type":"Person","data":{"name":"Alice","age":30}}
{"edge":"Knows","from":"Alice","to":"Missing"}"#;
let err = load_jsonl(&mut db, bad, LoadMode::Overwrite)
.await
.unwrap_err();
match err {
OmniError::Manifest(message) => assert!(message.message.contains("not found in Person")),
other => panic!("unexpected error: {}", other),
}
let runs = latest_runs(uri).await;
assert_eq!(runs.len(), 1);
assert_eq!(runs[0].status, "failed");
assert!(runs[0].run_branch.starts_with("__run__"));
let snap = snapshot_main(&db).await.unwrap();
let person_count = snap
.open("node:Person")
.await
.unwrap()
.count_rows(None)
.await
.unwrap();
assert_eq!(person_count, 0);
}
#[tokio::test]
async fn public_mutation_uses_hidden_transactional_run_and_publishes_it() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = init_and_load(&dir).await;
let result = db
.mutate(
"main",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
)
.await
.unwrap();
assert_eq!(result.affected_nodes, 1);
assert_eq!(result.affected_edges, 0);
let runs = latest_runs(uri).await;
assert!(!runs.is_empty());
let latest = runs.last().unwrap();
assert_eq!(latest.target_branch, "main");
assert_eq!(latest.status, "published");
assert!(latest.run_branch.starts_with("__run__"));
let qr = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Eve")]),
)
.await
.unwrap();
assert_eq!(qr.num_rows(), 1);
}
#[tokio::test]
async fn public_mutation_preserves_staged_edge_ids_on_publish() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = init_and_load(&dir).await;
db.mutate(
"main",
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Alice"), ("$to", "Diana")]),
)
.await
.unwrap();
let runs = latest_runs(uri).await;
let latest = runs.last().unwrap();
let mut main_ids = collect_column_strings(&read_table(&db, "edge:Knows").await, "id");
let mut run_ids = collect_column_strings(
&read_table_branch(&db, latest.run_branch.as_str(), "edge:Knows").await,
"id",
);
main_ids.sort();
run_ids.sort();
assert_eq!(main_ids, run_ids);
}
#[tokio::test]
async fn failed_public_mutation_marks_run_failed_and_leaves_target_unchanged() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = init_and_load(&dir).await;
let err = db
.mutate(
"main",
MUTATION_QUERIES,
"add_friend",
&params(&[("$from", "Alice"), ("$to", "Missing")]),
)
.await
.unwrap_err();
match err {
OmniError::Manifest(message) => assert!(message.message.contains("not found")),
other => panic!("unexpected error: {}", other),
}
let runs = latest_runs(uri).await;
assert!(!runs.is_empty());
let latest = runs.last().unwrap();
assert_eq!(latest.status, "failed");
assert!(latest.run_branch.starts_with("__run__"));
let qr = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"friends_of",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
assert_eq!(qr.num_rows(), 2);
}
#[tokio::test]
async fn concurrent_conflicting_run_publish_fails_cleanly() {
let dir = tempfile::tempdir().unwrap();
let mut db = init_and_load(&dir).await;
let run_a = db.begin_run("main", Some("conflict-a")).await.unwrap();
let run_b = db.begin_run("main", Some("conflict-b")).await.unwrap();
db.mutate(
run_a.run_branch.as_str(),
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Alice")], &[("$age", 31)]),
)
.await
.unwrap();
db.mutate(
run_b.run_branch.as_str(),
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Alice")], &[("$age", 32)]),
)
.await
.unwrap();
db.publish_run(&run_a.run_id).await.unwrap();
let publish_b = db.publish_run(&run_b.run_id).await;
assert!(publish_b.is_err(), "second conflicting publish should fail");
let err = publish_b.unwrap_err().to_string();
assert!(
err.contains("conflict") || err.contains("divergent") || err.contains("Alice"),
"unexpected conflict error: {}",
err
);
let alice = db
.query(
ReadTarget::branch("main"),
TEST_QUERIES,
"get_person",
&params(&[("$name", "Alice")]),
)
.await
.unwrap();
let rows = alice.to_rust_json();
assert_eq!(alice.num_rows(), 1);
assert_eq!(rows[0]["p.age"], serde_json::json!(31));
let run_a_record = db.get_run(&run_a.run_id).await.unwrap();
assert_eq!(run_a_record.status, RunStatus::Published);
let run_b_record = db.get_run(&run_b.run_id).await.unwrap();
assert_eq!(run_b_record.status, RunStatus::Running);
}
#[tokio::test]
async fn public_mutation_records_actor_on_run_and_published_commit() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let mut db = init_and_load(&dir).await;
db.mutate_as(
"main",
MUTATION_QUERIES,
"set_age",
&mixed_params(&[("$name", "Alice")], &[("$age", 31)]),
Some("act-andrew"),
)
.await
.unwrap();
let runs = db.list_runs().await.unwrap();
let run = runs
.iter()
.find(|run| run.operation_hash.as_deref() == Some("mutation:set_age:branch=main"))
.expect("published mutation run should exist");
assert_eq!(run.actor_id.as_deref(), Some("act-andrew"));
assert_eq!(run.status, RunStatus::Published);
let head = CommitGraph::open(uri)
.await
.unwrap()
.head_commit()
.await
.unwrap()
.unwrap();
assert_eq!(head.actor_id.as_deref(), Some("act-andrew"));
}

View file

@ -0,0 +1,187 @@
mod helpers;
use omnigraph::db::MergeOutcome;
use omnigraph::db::Omnigraph;
use omnigraph::loader::{LoadMode, load_jsonl};
use helpers::*;
#[tokio::test(flavor = "multi_thread")]
async fn s3_compatible_repo_lifecycle_works() {
let Some(uri) = s3_test_repo_uri("omnigraph-runtime") else {
eprintln!("skipping s3 runtime test: OMNIGRAPH_S3_TEST_BUCKET is not set");
return;
};
let mut db = Omnigraph::init(&uri, TEST_SCHEMA).await.unwrap();
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
.await
.unwrap();
let mut reopened = Omnigraph::open(&uri).await.unwrap();
let snapshot = reopened.snapshot_of("main").await.unwrap();
assert!(snapshot.entry("node:Person").is_some());
assert!(snapshot.entry("edge:Knows").is_some());
let alice = query_main(
&mut reopened,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Alice")]),
)
.await
.unwrap()
.to_rust_json();
assert_eq!(alice[0]["p.name"], "Alice");
reopened
.mutate(
"main",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "RustFS-Eve")], &[("$age", 29)]),
)
.await
.unwrap();
let run = reopened
.begin_run("main", Some("s3-runtime-run"))
.await
.unwrap();
reopened
.load(
&run.run_branch,
r#"{"type":"Person","data":{"name":"RunOnly","age":31}}"#,
LoadMode::Append,
)
.await
.unwrap();
reopened.publish_run(&run.run_id).await.unwrap();
let runs = reopened.list_runs().await.unwrap();
assert!(
runs.iter()
.any(|record| { record.run_id == run.run_id && record.status.as_str() == "published" }),
"expected published run record in {:?}",
runs
);
let mut reopened_again = Omnigraph::open(&uri).await.unwrap();
let eve = query_main(
&mut reopened_again,
TEST_QUERIES,
"get_person",
&params(&[("$name", "RustFS-Eve")]),
)
.await
.unwrap()
.to_rust_json();
assert_eq!(eve[0]["p.name"], "RustFS-Eve");
let run_only = query_main(
&mut reopened_again,
TEST_QUERIES,
"get_person",
&params(&[("$name", "RunOnly")]),
)
.await
.unwrap()
.to_rust_json();
assert_eq!(run_only[0]["p.name"], "RunOnly");
}
#[tokio::test(flavor = "multi_thread")]
async fn s3_branch_change_merge_flow_works() {
let Some(uri) = s3_test_repo_uri("omnigraph-branching") else {
eprintln!("skipping s3 branch test: OMNIGRAPH_S3_TEST_BUCKET is not set");
return;
};
let mut main = Omnigraph::init(&uri, TEST_SCHEMA).await.unwrap();
load_jsonl(&mut main, TEST_DATA, LoadMode::Overwrite)
.await
.unwrap();
main.branch_create("feature").await.unwrap();
let mut feature = Omnigraph::open(&uri).await.unwrap();
feature
.mutate(
"feature",
MUTATION_QUERIES,
"insert_person",
&mixed_params(&[("$name", "Feature-Eve")], &[("$age", 22)]),
)
.await
.unwrap();
let before_merge = query_main(
&mut main,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Feature-Eve")]),
)
.await
.unwrap();
assert_eq!(before_merge.num_rows(), 0);
let outcome = main.branch_merge("feature", "main").await.unwrap();
assert_eq!(outcome, MergeOutcome::FastForward);
let mut reopened = Omnigraph::open(&uri).await.unwrap();
let after_merge = query_main(
&mut reopened,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Feature-Eve")]),
)
.await
.unwrap()
.to_rust_json();
assert_eq!(after_merge[0]["p.name"], "Feature-Eve");
assert_eq!(
reopened.branch_list().await.unwrap(),
vec!["main".to_string(), "feature".to_string()]
);
}
#[tokio::test(flavor = "multi_thread")]
async fn s3_public_load_uses_hidden_run_and_publishes() {
let Some(uri) = s3_test_repo_uri("omnigraph-public-load") else {
eprintln!("skipping s3 public load test: OMNIGRAPH_S3_TEST_BUCKET is not set");
return;
};
let mut db = Omnigraph::init(&uri, TEST_SCHEMA).await.unwrap();
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
.await
.unwrap();
db.load(
"main",
r#"{"type":"Person","data":{"name":"Loaded-Over-S3","age":34}}"#,
LoadMode::Append,
)
.await
.unwrap();
let runs = db.list_runs().await.unwrap();
assert!(
runs.iter().any(|record| {
record.target_branch == "main" && record.status.as_str() == "published"
}),
"expected published transactional run in {:?}",
runs
);
let mut reopened = Omnigraph::open(&uri).await.unwrap();
let loaded = query_main(
&mut reopened,
TEST_QUERIES,
"get_person",
&params(&[("$name", "Loaded-Over-S3")]),
)
.await
.unwrap()
.to_rust_json();
assert_eq!(loaded[0]["p.name"], "Loaded-Over-S3");
}

Some files were not shown because too many files have changed in this diff Show more