mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-10 08:05:14 +02:00
Initial open-source release
This commit is contained in:
commit
1a42152e6f
1199 changed files with 257054 additions and 0 deletions
442
python/klo-daemon/tests/test_app.py
Normal file
442
python/klo-daemon/tests/test_app.py
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from klo_daemon.app import create_app
|
||||
from klo_daemon.database_introspection import (
|
||||
DatabaseIntrospectionResponse,
|
||||
LiveDatabaseColumn,
|
||||
LiveDatabaseTable,
|
||||
)
|
||||
|
||||
|
||||
# Minimal semantic-layer source definition used by the semantic-layer
# endpoint tests in this module.
ORDERS_SOURCE = {
    "name": "orders",
    "table": "public.orders",
    "grain": ["id"],
    "columns": [
        {"name": "id", "type": "number"},
        {"name": "status", "type": "string"},
        {"name": "amount", "type": "number"},
    ],
    "joins": [],
    "measures": [{"name": "order_count", "expr": "count(*)"}],
}
|
||||
|
||||
# LookML fixture: a single table-backed view with two dimensions and one
# count measure. Indentation inside the literal is conventional LookML style.
LOOKML_ORDER_VIEW = """
view: orders {
  sql_table_name: public.orders ;;

  dimension: id {
    primary_key: yes
    type: number
    sql: ${TABLE}.id ;;
  }

  dimension: status {
    type: string
    sql: ${TABLE}.status ;;
  }

  measure: order_count {
    type: count
  }
}
"""
|
||||
|
||||
|
||||
class FakeEmbeddingProvider:
    """Deterministic in-memory embedding provider used as a test double.

    Each text is encoded as ``[len(text), index_in_batch, 1.0]`` and every
    batch passed to :meth:`encode` is recorded in ``calls`` so tests can
    assert on how the provider was invoked.
    """

    name = "fake"
    dimensions = 3
    max_batch_size = 2

    def __init__(self) -> None:
        # One entry per encode() call, holding a copy of the texts passed in.
        self.calls: list[list[str]] = []

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Record the batch and return one 3-element vector per text."""
        self.calls.append(list(texts))
        return [
            [float(len(text)), float(index), 1.0] for index, text in enumerate(texts)
        ]
|
||||
|
||||
|
||||
def test_health_endpoint_returns_healthy() -> None:
    """GET /health answers 200 with a ``healthy`` status payload."""
    client = TestClient(create_app())

    response = client.get("/health")

    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}
|
||||
|
||||
|
||||
def test_database_introspect_endpoint_returns_snapshot() -> None:
    """POST /database/introspect returns the injected introspector's snapshot."""
    calls = []

    def fake_introspector(request):
        # Record the parsed request so the test can check what the app passed in.
        calls.append(request)
        return DatabaseIntrospectionResponse(
            connection_id=request.connection_id,
            extracted_at="2026-04-28T10:00:00+00:00",
            metadata={"driver": request.driver, "schemas": request.schemas},
            tables=[
                LiveDatabaseTable(
                    catalog="warehouse",
                    db="public",
                    name="orders",
                    columns=[
                        LiveDatabaseColumn(
                            name="id",
                            type="integer",
                            nullable=False,
                            primary_key=True,
                        )
                    ],
                )
            ],
        )

    client = TestClient(create_app(database_introspector=fake_introspector))

    response = client.post(
        "/database/introspect",
        json={
            "connection_id": "warehouse",
            "driver": "postgres",
            "url": "postgresql://readonly@example.test/warehouse",
            "schemas": ["public"],
        },
    )

    assert response.status_code == 200
    assert response.json()["connection_id"] == "warehouse"
    assert response.json()["tables"][0]["name"] == "orders"
    assert calls[0].connection_id == "warehouse"
|
||||
|
||||
|
||||
def test_database_introspect_endpoint_maps_value_error_to_400() -> None:
    """A ValueError raised by the introspector surfaces as HTTP 400 ``detail``."""

    def fake_introspector(request):
        raise ValueError('database introspection supports only driver "postgres"')

    client = TestClient(create_app(database_introspector=fake_introspector))

    response = client.post(
        "/database/introspect",
        json={
            "connection_id": "warehouse",
            "driver": "snowflake",
            "url": "snowflake://example",
        },
    )

    assert response.status_code == 400
    assert response.json() == {
        "detail": 'database introspection supports only driver "postgres"'
    }
|
||||
|
||||
|
||||
def test_embedding_compute_endpoint_returns_embedding() -> None:
    """POST /embeddings/compute encodes a single text through the provider."""
    provider = FakeEmbeddingProvider()
    client = TestClient(create_app(embedding_provider=provider))

    response = client.post("/embeddings/compute", json={"text": "hello"})

    assert response.status_code == 200
    assert response.json() == {"embedding": [5.0, 0.0, 1.0]}
    assert provider.calls == [["hello"]]
|
||||
|
||||
|
||||
def test_embedding_compute_bulk_endpoint_returns_embeddings() -> None:
    """POST /embeddings/compute-bulk encodes the whole batch in one provider call."""
    provider = FakeEmbeddingProvider()
    client = TestClient(create_app(embedding_provider=provider))

    response = client.post(
        "/embeddings/compute-bulk",
        json={"texts": ["one", "three"]},
    )

    assert response.status_code == 200
    assert response.json() == {"embeddings": [[3.0, 0.0, 1.0], [5.0, 1.0, 1.0]]}
    assert provider.calls == [["one", "three"]]
|
||||
|
||||
|
||||
def test_embedding_compute_bulk_endpoint_maps_value_error_to_400() -> None:
    """A batch larger than the provider's ``max_batch_size`` is rejected with 400."""
    provider = FakeEmbeddingProvider()
    client = TestClient(create_app(embedding_provider=provider))

    # Three texts exceed FakeEmbeddingProvider.max_batch_size (2).
    response = client.post(
        "/embeddings/compute-bulk",
        json={"texts": ["one", "two", "three"]},
    )

    assert response.status_code == 400
    assert response.json() == {"detail": "Maximum 2 texts allowed per batch"}
    # The provider must not be invoked for a rejected batch.
    assert provider.calls == []
|
||||
|
||||
|
||||
def test_code_execute_endpoint_is_not_registered_by_default() -> None:
    """Without ``enable_code_execution`` the /code/execute route answers 404."""
    client = TestClient(create_app())

    response = client.post("/code/execute", json={"code": "result = 7"})

    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_code_execute_endpoint_returns_result_when_enabled() -> None:
    """With code execution enabled, the endpoint returns result and console output."""
    client = TestClient(create_app(enable_code_execution=True))

    response = client.post(
        "/code/execute",
        json={"code": 'print("ran")\nresult = {"value": 7}'},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["result"] == {"value": 7}
    assert body["console_output"] == "ran\n"
    assert body["error"] is None
    assert body["message"] is None
    assert body["visualizations"] is None
    assert "=== Console Output ===" in body["formatted_result"]
    assert "=== Result ===" in body["formatted_result"]
|
||||
|
||||
|
||||
def test_code_execute_endpoint_serializes_numpy_result_when_enabled() -> None:
    """A numpy scalar in ``result`` comes back as a plain JSON number."""
    client = TestClient(create_app(enable_code_execution=True))

    response = client.post(
        "/code/execute",
        json={"code": "import numpy as np\nresult = {'value': np.float64(1.25)}"},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["result"] == {"value": 1.25}
    assert body["error"] is None
|
||||
|
||||
|
||||
def test_code_execute_endpoint_uses_host_free_boundary_when_enabled() -> None:
    """Scratchpad helpers fail with a clear error even if an Authorization header is sent.

    The request carries ``source_id``/``message_id`` and an Authorization
    header, yet the response reports that the scratchpad prerequisites are
    missing.
    """
    client = TestClient(create_app(enable_code_execution=True))

    response = client.post(
        "/code/execute",
        json={
            "source_id": "chat_123",
            "message_id": "message_456",
            "code": (
                "import pandas as pd\n"
                "result = save_df_to_scratchpad(pd.DataFrame({'value': [1]}), 'out.json')"
            ),
        },
        headers={"Authorization": "Bearer should-not-forward"},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["result"] is None
    assert (
        body["error"]
        == "nest_api_url, Authorization header, and source_id are required for scratchpad operations"
    )
    assert "=== Error ===" in body["formatted_result"]
|
||||
|
||||
|
||||
def test_sql_parse_table_identifier_endpoint() -> None:
    """Parses a plain ``schema.table`` name and flags an unresolved Looker template."""
    client = TestClient(create_app())

    response = client.post(
        "/sql/parse-table-identifier",
        json={
            "items": [
                {
                    "key": "orders",
                    "sql_table_name": "public.orders",
                    "dialect": "postgres",
                },
                {
                    "key": "template",
                    "sql_table_name": "${orders.SQL_TABLE_NAME}",
                    "dialect": "postgres",
                },
            ]
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["results"]["orders"]["ok"] is True
    assert body["results"]["orders"]["schema"] == "public"
    assert body["results"]["orders"]["name"] == "orders"
    assert body["results"]["template"]["ok"] is False
    assert body["results"]["template"]["reason"] == "looker_template_unresolved"
|
||||
|
||||
|
||||
def test_semantic_query_endpoint_returns_sql() -> None:
    """POST /semantic-layer/query compiles a measure + dimension query to SQL."""
    client = TestClient(create_app())

    response = client.post(
        "/semantic-layer/query",
        json={
            "sources": [ORDERS_SOURCE],
            "dialect": "postgres",
            "query": {
                "measures": ["orders.order_count"],
                "dimensions": ["orders.status"],
            },
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["dialect"] == "postgres"
    assert "public.orders" in body["sql"]
    assert body["columns"][0]["name"] == "orders.status"
|
||||
|
||||
|
||||
def test_semantic_query_endpoint_maps_value_error_to_400() -> None:
    """Querying a measure on an unknown source yields 400 naming that measure."""
    client = TestClient(create_app())

    response = client.post(
        "/semantic-layer/query",
        json={
            "sources": [ORDERS_SOURCE],
            "dialect": "postgres",
            "query": {
                "measures": ["missing.order_count"],
                "dimensions": [],
            },
        },
    )

    assert response.status_code == 400
    assert "missing.order_count" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_semantic_validate_endpoint_returns_structured_validation() -> None:
    """Validation reports a duplicate measure as an error in a 200 response."""
    client = TestClient(create_app())
    # Same measure name twice: the validator is expected to flag the duplicate.
    invalid_source = {
        **ORDERS_SOURCE,
        "measures": [
            {"name": "revenue", "expr": "sum(amount)"},
            {"name": "revenue", "expr": "sum(amount)"},
        ],
    }

    response = client.post(
        "/semantic-layer/validate",
        json={"sources": [invalid_source], "dialect": "postgres"},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["valid"] is False
    assert any("Duplicate measure" in error for error in body["errors"])
    assert body["warnings"] == []
    assert body["per_source_warnings"] == {}
|
||||
|
||||
|
||||
def test_semantic_generate_sources_endpoint_returns_sources() -> None:
    """Generates one source per table, including joins and default measures."""
    client = TestClient(create_app())

    response = client.post(
        "/semantic-layer/generate-sources",
        json={
            "tables": [
                {
                    "name": "orders",
                    "db": "public",
                    "comment": "Orders table",
                    "columns": [
                        {
                            "name": "id",
                            "type": "integer",
                            "primary_key": True,
                            "nullable": False,
                            "comment": "Order ID",
                        },
                        {"name": "customer_id", "type": "integer"},
                        {
                            "name": "amount",
                            "type": "decimal",
                            "comment": "Order amount",
                        },
                    ],
                },
                {
                    "name": "customers",
                    "db": "public",
                    "columns": [
                        {"name": "id", "type": "integer", "primary_key": True},
                        {"name": "email", "type": "varchar"},
                    ],
                },
            ],
            "links": [
                {
                    "from_table": "orders",
                    "from_column": "customer_id",
                    "to_table": "customers",
                    "to_column": "id",
                    "relationship_type": "MANY_TO_ONE",
                }
            ],
            "dialect": "postgres",
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["source_count"] == 2
    # Index by name so the assertions do not depend on response ordering.
    sources = {source["name"]: source for source in body["sources"]}
    assert sources["orders"]["table"] == "public.orders"
    assert sources["orders"]["description"] == "Orders table"
    assert sources["orders"]["grain"] == ["id"]
    assert sources["orders"]["joins"] == [
        {
            "to": "customers",
            "on": "customer_id = customers.id",
            "relationship": "many_to_one",
        }
    ]
    assert [measure["name"] for measure in sources["orders"]["measures"]] == [
        "record_count",
        "total_amount",
        "avg_amount",
    ]
|
||||
|
||||
|
||||
def test_lookml_parse_endpoint_returns_resolved_views() -> None:
    """POST /lookml/parse resolves the fixture view into columns and measures."""
    client = TestClient(create_app())

    response = client.post(
        "/lookml/parse",
        json={
            "files": [
                {
                    "path": "views/orders.view.lkml",
                    "content": LOOKML_ORDER_VIEW,
                }
            ],
            "dialect": "postgres",
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["joins"] == []
    assert body["skipped_views"] == []
    assert body["warnings"] == []
    assert len(body["views"]) == 1
    view = body["views"][0]
    assert view["name"] == "orders"
    assert view["source_type"] == "table"
    assert view["table_ref"] == "public.orders"
    assert view["grain"] == ["id"]
    assert [column["name"] for column in view["columns"]] == ["id", "status"]
    assert view["measures"] == [
        {
            "name": "order_count",
            "expr": "count(*)",
            "filter": None,
            "description": None,
        }
    ]
|
||||
426
python/klo-daemon/tests/test_cli.py
Normal file
426
python/klo-daemon/tests/test_cli.py
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Minimal semantic-layer source definition sent as JSON payload by the CLI
# tests in this module.
ORDERS_SOURCE = {
    "name": "orders",
    "table": "public.orders",
    "grain": ["id"],
    "columns": [
        {"name": "id", "type": "number"},
        {"name": "status", "type": "string"},
        {"name": "amount", "type": "number"},
    ],
    "joins": [],
    "measures": [{"name": "order_count", "expr": "count(*)"}],
}
|
||||
|
||||
|
||||
def run_daemon_command(
    command: str, payload: dict[str, object]
) -> subprocess.CompletedProcess[str]:
    """Run ``python -m klo_daemon <command>`` with *payload* as JSON on stdin.

    The package's ``src`` directory (sibling of this tests directory) is
    prepended to ``PYTHONPATH`` so the subprocess imports the in-tree code.

    Args:
        command: Sub-command name passed to the ``klo_daemon`` module.
        payload: JSON-serialisable request written to the process's stdin.

    Returns:
        The completed process with text-mode ``stdout``/``stderr`` captured.
        A non-zero exit code does not raise (``check=False``).
    """
    env = os.environ.copy()
    src_path = str(Path(__file__).resolve().parents[1] / "src")
    # Only append the inherited PYTHONPATH when it is non-empty, so no
    # trailing separator (an empty path entry) is produced.
    inherited = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        src_path + os.pathsep + inherited if inherited else src_path
    )
    return subprocess.run(
        [sys.executable, "-m", "klo_daemon", command],
        input=json.dumps(payload),
        text=True,
        capture_output=True,
        check=False,
        env=env,
    )
|
||||
|
||||
|
||||
def test_semantic_query_command_reads_stdin_and_writes_json() -> None:
    """``semantic-query`` reads the request from stdin and prints JSON with SQL."""
    result = run_daemon_command(
        "semantic-query",
        {
            "sources": [ORDERS_SOURCE],
            "dialect": "postgres",
            "query": {
                "measures": ["orders.order_count"],
                "dimensions": ["orders.status"],
            },
        },
    )

    # Include stderr in the failure message to make subprocess errors visible.
    assert result.returncode == 0, result.stderr
    parsed = json.loads(result.stdout)
    assert "public.orders" in parsed["sql"]
    assert parsed["columns"][0]["name"] == "orders.status"
|
||||
|
||||
|
||||
def test_semantic_validate_command_reads_stdin_and_writes_json() -> None:
    """``semantic-validate`` reports a valid source with no errors or warnings."""
    result = run_daemon_command(
        "semantic-validate",
        {"sources": [ORDERS_SOURCE], "dialect": "postgres"},
    )

    assert result.returncode == 0, result.stderr
    parsed = json.loads(result.stdout)
    assert parsed == {
        "valid": True,
        "errors": [],
        "warnings": [],
        "per_source_warnings": {},
    }
|
||||
|
||||
|
||||
def test_command_returns_nonzero_for_invalid_json() -> None:
    """Malformed JSON on stdin exits with code 1 and a JSON error on stderr."""
    # The subprocess is spawned directly (not via run_daemon_command) because
    # the input must be raw, invalid JSON rather than a serialised payload.
    env = os.environ.copy()
    src_path = str(Path(__file__).resolve().parents[1] / "src")
    # Only append the inherited PYTHONPATH when it is non-empty, so no
    # trailing separator (an empty path entry) is produced.
    inherited = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        src_path + os.pathsep + inherited if inherited else src_path
    )
    result = subprocess.run(
        [sys.executable, "-m", "klo_daemon", "semantic-query"],
        input="{",
        text=True,
        capture_output=True,
        check=False,
        env=env,
    )

    assert result.returncode == 1
    assert "Expecting property name enclosed in double quotes" in result.stderr
|
||||
|
||||
|
||||
def test_serve_http_command_starts_uvicorn_without_reading_stdin(
    monkeypatch,
) -> None:
    """``serve-http`` forwards its CLI options to the server and never reads stdin."""
    from klo_daemon import __main__ as daemon_main

    calls: list[dict[str, object]] = []

    class FailingStdin:
        """Stdin replacement that fails the test if anything reads from it."""

        def read(self) -> str:
            raise AssertionError("serve-http must not read stdin JSON")

    def fake_run_http_server(
        *,
        host: str,
        port: int,
        log_level: str,
        enable_code_execution: bool,
    ) -> None:
        # Capture the keyword arguments instead of starting a real server.
        calls.append(
            {
                "host": host,
                "port": port,
                "log_level": log_level,
                "enable_code_execution": enable_code_execution,
            }
        )

    monkeypatch.setattr(sys, "stdin", FailingStdin())
    monkeypatch.setattr(daemon_main, "run_http_server", fake_run_http_server)

    exit_code = daemon_main.main(
        [
            "serve-http",
            "--host",
            "127.0.0.1",
            "--port",
            "9191",
            "--log-level",
            "warning",
        ]
    )

    assert exit_code == 0
    assert calls == [
        {
            "host": "127.0.0.1",
            "port": 9191,
            "log_level": "warning",
            "enable_code_execution": False,
        }
    ]
|
||||
|
||||
|
||||
def test_serve_http_command_defaults_to_loopback(monkeypatch) -> None:
    """``serve-http`` without options binds 127.0.0.1:8765 at ``info`` log level."""
    from klo_daemon import __main__ as daemon_main

    calls: list[dict[str, object]] = []

    def fake_run_http_server(
        *,
        host: str,
        port: int,
        log_level: str,
        enable_code_execution: bool,
    ) -> None:
        # Capture the keyword arguments instead of starting a real server.
        calls.append(
            {
                "host": host,
                "port": port,
                "log_level": log_level,
                "enable_code_execution": enable_code_execution,
            }
        )

    monkeypatch.setattr(daemon_main, "run_http_server", fake_run_http_server)

    assert daemon_main.main(["serve-http"]) == 0
    assert calls == [
        {
            "host": "127.0.0.1",
            "port": 8765,
            "log_level": "info",
            "enable_code_execution": False,
        }
    ]
|
||||
|
||||
|
||||
def test_serve_http_command_can_enable_code_execution(monkeypatch) -> None:
    """``--enable-code-execution`` is passed through to the HTTP server as ``True``."""
    from klo_daemon import __main__ as daemon_main

    calls: list[dict[str, object]] = []

    def fake_run_http_server(
        *,
        host: str,
        port: int,
        log_level: str,
        enable_code_execution: bool,
    ) -> None:
        # Capture the keyword arguments instead of starting a real server.
        calls.append(
            {
                "host": host,
                "port": port,
                "log_level": log_level,
                "enable_code_execution": enable_code_execution,
            }
        )

    monkeypatch.setattr(daemon_main, "run_http_server", fake_run_http_server)

    assert daemon_main.main(["serve-http", "--enable-code-execution"]) == 0
    assert calls == [
        {
            "host": "127.0.0.1",
            "port": 8765,
            "log_level": "info",
            "enable_code_execution": True,
        }
    ]
|
||||
|
||||
|
||||
def test_lookml_parse_command_reads_stdin_and_writes_json() -> None:
    """``lookml-parse`` reads LookML files from stdin JSON and prints parsed views."""
    result = run_daemon_command(
        "lookml-parse",
        {
            "files": [
                {
                    "path": "views/orders.view.lkml",
                    "content": """
view: orders {
  sql_table_name: public.orders ;;

  dimension: id {
    primary_key: yes
    type: number
    sql: ${TABLE}.id ;;
  }

  measure: order_count {
    type: count
  }
}
""",
                }
            ],
            "dialect": "postgres",
        },
    )

    assert result.returncode == 0, result.stderr
    parsed = json.loads(result.stdout)
    assert parsed["views"][0]["name"] == "orders"
    assert parsed["views"][0]["table_ref"] == "public.orders"
    assert parsed["views"][0]["measures"][0]["expr"] == "count(*)"
    assert parsed["joins"] == []
    assert parsed["skipped_views"] == []
    assert parsed["warnings"] == []
|
||||
|
||||
|
||||
def test_semantic_generate_sources_command_reads_stdin_and_writes_json() -> None:
    """``semantic-generate-sources`` emits one source with the default measures."""
    result = run_daemon_command(
        "semantic-generate-sources",
        {
            "tables": [
                {
                    "name": "orders",
                    "db": "public",
                    "columns": [
                        {"name": "id", "type": "integer", "primary_key": True},
                        {"name": "amount", "type": "decimal"},
                    ],
                }
            ],
            "links": [],
            "dialect": "postgres",
        },
    )

    assert result.returncode == 0, result.stderr
    parsed = json.loads(result.stdout)
    assert parsed["source_count"] == 1
    assert parsed["sources"][0]["name"] == "orders"
    assert parsed["sources"][0]["table"] == "public.orders"
    assert parsed["sources"][0]["measures"] == [
        {
            "name": "record_count",
            "expr": "count(id)",
            "segments": [],
            "description": "Count of orders records",
        },
        {
            "name": "total_amount",
            "expr": "sum(amount)",
            "segments": [],
            "description": "Sum of amount",
        },
        {
            "name": "avg_amount",
            "expr": "avg(amount)",
            "segments": [],
            "description": "Average of amount",
        },
    ]
|
||||
|
||||
|
||||
def test_database_introspect_command_reads_stdin_and_writes_json(
    monkeypatch, capsys
) -> None:
    """``database-introspect`` parses stdin JSON and prints the snapshot to stdout."""
    from klo_daemon import __main__ as daemon_main
    from klo_daemon.database_introspection import (
        DatabaseIntrospectionResponse,
        LiveDatabaseColumn,
        LiveDatabaseTable,
    )

    def fake_introspect(request):
        # Verify the CLI parsed the stdin payload before returning canned data.
        assert request.connection_id == "warehouse"
        assert request.driver == "postgres"
        assert request.schemas == ["public"]
        return DatabaseIntrospectionResponse(
            connection_id="warehouse",
            extracted_at="2026-04-28T10:00:00+00:00",
            metadata={"driver": "postgres", "schemas": ["public"]},
            tables=[
                LiveDatabaseTable(
                    catalog="warehouse",
                    db="public",
                    name="orders",
                    columns=[
                        LiveDatabaseColumn(
                            name="id",
                            type="integer",
                            nullable=False,
                            primary_key=True,
                        )
                    ],
                )
            ],
        )

    monkeypatch.setattr(daemon_main, "introspect_database_response", fake_introspect)
    monkeypatch.setattr(
        sys,
        "stdin",
        io.StringIO(
            '{"connection_id":"warehouse","driver":"postgres","url":"postgresql://readonly@example.test/warehouse","schemas":["public"]}'
        ),
    )

    assert daemon_main.main(["database-introspect"]) == 0
    captured = capsys.readouterr()
    parsed = json.loads(captured.out)
    assert parsed["connection_id"] == "warehouse"
    assert parsed["metadata"] == {"driver": "postgres", "schemas": ["public"]}
    assert parsed["tables"][0]["name"] == "orders"
    assert captured.err == ""
|
||||
|
||||
|
||||
def test_embedding_compute_command_reads_stdin_and_writes_json(
    monkeypatch, capsys
) -> None:
    """``embedding-compute`` parses stdin JSON and prints the embedding to stdout."""
    from klo_daemon import __main__ as daemon_main
    from klo_daemon.embeddings import ComputeEmbeddingResponse

    def fake_compute(request):
        assert request.text == "hello"
        return ComputeEmbeddingResponse(embedding=[1.0, 2.0, 3.0])

    monkeypatch.setattr(daemon_main, "compute_embedding_response", fake_compute)
    monkeypatch.setattr(sys, "stdin", io.StringIO('{"text": "hello"}'))

    assert daemon_main.main(["embedding-compute"]) == 0
    captured = capsys.readouterr()
    assert json.loads(captured.out) == {"embedding": [1.0, 2.0, 3.0]}
    assert captured.err == ""
|
||||
|
||||
|
||||
def test_embedding_compute_bulk_command_reads_stdin_and_writes_json(
    monkeypatch, capsys
) -> None:
    """``embedding-compute-bulk`` parses stdin JSON and prints all embeddings."""
    from klo_daemon import __main__ as daemon_main
    from klo_daemon.embeddings import ComputeEmbeddingBulkResponse

    def fake_compute(request):
        assert request.texts == ["hello", "world"]
        return ComputeEmbeddingBulkResponse(embeddings=[[1.0, 2.0], [3.0, 4.0]])

    monkeypatch.setattr(daemon_main, "compute_embedding_bulk_response", fake_compute)
    monkeypatch.setattr(sys, "stdin", io.StringIO('{"texts": ["hello", "world"]}'))

    assert daemon_main.main(["embedding-compute-bulk"]) == 0
    captured = capsys.readouterr()
    assert json.loads(captured.out) == {"embeddings": [[1.0, 2.0], [3.0, 4.0]]}
    assert captured.err == ""
|
||||
|
||||
|
||||
def test_code_execute_command_reads_stdin_and_writes_json(monkeypatch, capsys) -> None:
    """``code-execute`` runs without API URL or auth header and prints the response."""
    from klo_daemon import __main__ as daemon_main
    from klo_daemon.code_execution import ExecuteCodeResponse

    calls: list[dict[str, Any]] = []

    def fake_execute(request, *, nest_api_url, auth_header):
        # Record the arguments so the test can check what the CLI passed through.
        calls.append(
            {
                "request": request,
                "nest_api_url": nest_api_url,
                "auth_header": auth_header,
            }
        )
        return ExecuteCodeResponse(
            formatted_result="\n\n=== Result ===\n\n7",
            result=7,
        )

    monkeypatch.setattr(daemon_main, "execute_code_response", fake_execute)
    monkeypatch.setattr(sys, "stdin", io.StringIO('{"code": "result = 7"}'))

    assert daemon_main.main(["code-execute"]) == 0
    captured = capsys.readouterr()
    assert json.loads(captured.out) == {
        "formatted_result": "\n\n=== Result ===\n\n7",
        "result": 7,
        "console_output": None,
        "error": None,
        "message": None,
        "visualizations": None,
    }
    assert captured.err == ""
    assert calls[0]["request"].code == "result = 7"
    assert calls[0]["nest_api_url"] is None
    assert calls[0]["auth_header"] is None
|
||||
210
python/klo-daemon/tests/test_code_execution.py
Normal file
210
python/klo-daemon/tests/test_code_execution.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from klo_daemon.code_execution import (
|
||||
ExecuteCodeRequest,
|
||||
create_scratchpad_helpers,
|
||||
detect_visualizations,
|
||||
dumps_numpy_json,
|
||||
execute_code_response,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class FakeResponse:
    """Minimal stand-in for an HTTP response returned by ``FakeHttpClient``."""

    # Payload handed back by json(); None is treated as an empty object.
    json_payload: dict[str, Any] | None = None
    # Raw response body bytes.
    content: bytes = b""
    # Response headers, if any.
    headers: dict[str, str] | None = None

    def raise_for_status(self) -> None:
        """Never raises: the fake always represents a successful response."""
        return None

    def json(self) -> dict[str, Any]:
        """Return the configured payload, or ``{}`` when none (or an empty one) was set."""
        return self.json_payload or {}
|
||||
|
||||
|
||||
class FakeHttpClient:
    """Recording HTTP client double that returns canned ``FakeResponse`` objects.

    Every ``post``/``get`` call is appended to ``posts``/``gets`` so tests can
    assert on the URLs, headers, timeouts and (decoded) request bodies.
    """

    def __init__(self) -> None:
        self.posts: list[dict[str, Any]] = []
        self.gets: list[dict[str, Any]] = []

    def post(
        self,
        url: str,
        data: bytes,
        headers: dict[str, str],
        timeout: int,
    ) -> FakeResponse:
        """Record the POST (body decoded from JSON) and report a saved file."""
        self.posts.append(
            {
                "url": url,
                # Decode the body so tests can compare plain Python values.
                "data": orjson.loads(data),
                "headers": headers,
                "timeout": timeout,
            }
        )
        return FakeResponse(json_payload={"filename": "saved.json"})

    def get(
        self,
        url: str,
        headers: dict[str, str],
        timeout: int,
    ) -> FakeResponse:
        """Record the GET and return a fixed one-row CSV body."""
        self.gets.append({"url": url, "headers": headers, "timeout": timeout})
        return FakeResponse(
            content=b"value,name\n1.25,alpha\n",
            headers={"content-type": "text/csv; charset=utf-8"},
        )
|
||||
|
||||
|
||||
def test_execute_code_response_captures_console_result_and_strips_ansi() -> None:
    """ANSI escapes are kept in ``console_output`` but removed from ``formatted_result``."""
    response = execute_code_response(
        ExecuteCodeRequest(
            code='print("\\x1b[31mhello\\x1b[0m")\nresult = {"value": 3}',
        ),
        nest_api_url=None,
        auth_header=None,
    )

    assert response.result == {"value": 3}
    assert response.console_output == "\x1b[31mhello\x1b[0m\n"
    assert "=== Console Output ===" in response.formatted_result
    assert "hello" in response.formatted_result
    assert "\x1b" not in response.formatted_result
    assert "=== Result ===" in response.formatted_result
|
||||
|
||||
|
||||
def test_execute_code_response_returns_message_when_result_is_absent() -> None:
    """Code that never sets ``result`` yields a message instead of a result."""
    response = execute_code_response(
        ExecuteCodeRequest(code='print("ran")'),
        nest_api_url=None,
        auth_header=None,
    )

    assert response.result is None
    assert (
        response.message == "Code executed successfully but no result variable was set"
    )
    assert response.console_output == "ran\n"
    assert "=== Message ===" in response.formatted_result
|
||||
|
||||
|
||||
def test_execute_code_response_detects_visualization_records() -> None:
    """A visualization-shaped ``result`` dict is surfaced in ``visualizations``."""
    # The JSON text doubles as a Python dict literal assigned to ``result``.
    response = execute_code_response(
        ExecuteCodeRequest(
            code="result = "
            + json.dumps(
                {
                    "type": "visualization",
                    "vis_type": "bar",
                    "config": {"title": "Revenue"},
                    "data": [{"month": "Jan", "revenue": 10}],
                    "title": "Revenue",
                }
            ),
        ),
        nest_api_url=None,
        auth_header=None,
    )

    assert response.visualizations is not None
    assert len(response.visualizations) == 1
    assert response.visualizations[0].vis_type == "bar"
    assert response.visualizations[0].title == "Revenue"
|
||||
|
||||
|
||||
def test_detect_visualizations_filters_mixed_lists() -> None:
    """Only entries typed ``visualization`` are kept from a mixed list."""
    visualizations = detect_visualizations(
        [
            {"type": "note", "text": "skip"},
            {
                "type": "visualization",
                "vis_type": "table",
                "config": {"title": "Rows"},
                "data": [{"row": 1}],
            },
        ]
    )

    assert visualizations == [
        {
            "type": "visualization",
            "vis_type": "table",
            "config": {"title": "Rows"},
            "data": [{"row": 1}],
        }
    ]
|
||||
|
||||
|
||||
def test_scratchpad_and_visualization_helpers_serialize_numpy_scalars() -> None:
    """Scratchpad/visualization helpers send numpy scalars as plain JSON numbers."""
    client = FakeHttpClient()
    save_df, read_file, save_viz = create_scratchpad_helpers(
        nest_api_url="http://nest",
        auth_header="Bearer token",
        source_id="source_123",
        message_id="message_456",
        http_client=client,
    )

    df = pd.DataFrame({"value": [np.float64(1.25)]})
    assert save_df(df, filename="df.json") == "1 rows saved to saved.json"

    read_df = read_file("input.csv")
    assert read_df.to_dict(orient="records") == [{"value": 1.25, "name": "alpha"}]

    viz_ref = save_viz(
        vis_type="bar",
        config={"title": "Test", "x": "a", "y": np.float64(2.5)},
        data=[{"a": "row1", "b": np.float64(3.75)}],
    )
    # NOTE(review): the expected reference is an empty string here — confirm
    # this is the intended contract of save_viz with the fake client.
    assert viz_ref == ""

    # First POST is the dataframe upload, second POST is the visualization.
    assert (
        client.posts[0]["url"] == "http://nest/private_api/scratchpad/source_123/files"
    )
    assert client.posts[0]["data"]["data"][0]["value"] == 1.25
    assert (
        client.gets[0]["url"]
        == "http://nest/private_api/scratchpad/source_123/files/input.csv?format=raw"
    )
    assert client.posts[1]["url"] == "http://nest/private_api/visualizations/source_123"
    assert client.posts[1]["data"]["config"]["y"] == 2.5
    assert client.posts[1]["data"]["data"][0]["b"] == 3.75
|
||||
|
||||
|
||||
def test_scratchpad_helpers_require_app_context_only_when_called() -> None:
    """Helpers can be created without app context but raise once invoked."""
    # Every context argument is explicitly None; creation itself must succeed.
    missing_context = dict.fromkeys(
        ("nest_api_url", "auth_header", "source_id", "message_id")
    )
    save_df, read_file, save_viz = create_scratchpad_helpers(**missing_context)

    scratchpad_error = "required for scratchpad operations"

    with pytest.raises(ValueError, match=scratchpad_error):
        save_df(pd.DataFrame({"value": [1]}), filename="df.json")

    with pytest.raises(ValueError, match=scratchpad_error):
        read_file("df.csv")

    with pytest.raises(ValueError, match="required for visualization operations"):
        save_viz("bar", {"title": "Chart"}, [{"value": 1}])
|
||||
|
||||
|
||||
def test_dumps_numpy_json_serializes_numpy_values() -> None:
    """A NumPy scalar and a NumPy array round-trip to plain JSON values."""
    payload = {
        "scalar": np.float64(1.5),
        "array": np.array([1, 2, 3]),
    }

    encoded = dumps_numpy_json(payload)

    # Decode with orjson to compare structure rather than raw text.
    decoded = orjson.loads(encoded)
    assert decoded == {"scalar": 1.5, "array": [1, 2, 3]}
|
||||
153
python/klo-daemon/tests/test_database_introspection.py
Normal file
153
python/klo-daemon/tests/test_database_introspection.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from klo_daemon.database_introspection import (
|
||||
DatabaseIntrospectionRequest,
|
||||
DatabaseIntrospectionRows,
|
||||
_statement_timeout_config,
|
||||
introspect_database_response,
|
||||
)
|
||||
|
||||
|
||||
def test_introspect_database_response_maps_postgres_catalog_rows() -> None:
    """Check that raw catalog rows are assembled into table models.

    A stub loader supplies two tables, three columns and one foreign key. The
    assertions pin the response envelope (connection id, timestamp, metadata),
    the table order, and the full dump of the ``orders`` table.
    """

    def catalog_row(table_name: str, **fields: object) -> dict[str, object]:
        # Every stub row lives in the same catalog and schema.
        return {
            "table_catalog": "warehouse",
            "table_schema": "public",
            "table_name": table_name,
            **fields,
        }

    def integer_column(
        table_name: str,
        column_name: str,
        *,
        primary_key: bool,
        comment: str | None,
    ) -> dict[str, object]:
        # All stub columns are non-nullable integers.
        return catalog_row(
            table_name,
            column_name=column_name,
            formatted_type="integer",
            is_nullable=False,
            is_primary_key=primary_key,
            column_comment=comment,
        )

    def fake_load_rows(
        request: DatabaseIntrospectionRequest,
    ) -> DatabaseIntrospectionRows:
        # The loader must see the fields the caller put on the request.
        assert request.connection_id == "warehouse"
        assert request.driver == "postgres"
        assert request.schemas == ["public"]

        tables = [
            catalog_row("customers", table_comment=None),
            catalog_row("orders", table_comment="Orders table"),
        ]
        columns = [
            integer_column("orders", "id", primary_key=True, comment="Order ID"),
            integer_column("orders", "customer_id", primary_key=False, comment=None),
            integer_column("customers", "id", primary_key=True, comment=None),
        ]
        foreign_keys = [
            catalog_row(
                "orders",
                from_column="customer_id",
                to_table="customers",
                to_column="id",
                constraint_name="orders_customer_id_fkey",
            )
        ]
        return DatabaseIntrospectionRows(
            table_rows=tables,
            column_rows=columns,
            foreign_key_rows=foreign_keys,
        )

    request = DatabaseIntrospectionRequest(
        connection_id="warehouse",
        driver="postgres",
        url="postgresql://readonly@example.test/warehouse",
        schemas=["public"],
    )

    response = introspect_database_response(
        request,
        load_rows=fake_load_rows,
        # Fixed clock so the extraction timestamp is predictable.
        now=lambda: "2026-04-28T10:00:00+00:00",
    )

    assert response.connection_id == "warehouse"
    assert response.extracted_at == "2026-04-28T10:00:00+00:00"
    assert response.metadata == {"driver": "postgres", "schemas": ["public"]}
    assert [table.name for table in response.tables] == ["customers", "orders"]

    orders = response.tables[1]
    # exclude_none drops the missing comment on customer_id from the dump.
    expected_orders = {
        "catalog": "warehouse",
        "db": "public",
        "name": "orders",
        "comment": "Orders table",
        "columns": [
            {
                "name": "id",
                "type": "integer",
                "nullable": False,
                "primary_key": True,
                "comment": "Order ID",
            },
            {
                "name": "customer_id",
                "type": "integer",
                "nullable": False,
                "primary_key": False,
            },
        ],
        "foreign_keys": [
            {
                "from_column": "customer_id",
                "to_table": "customers",
                "to_column": "id",
                "constraint_name": "orders_customer_id_fkey",
            }
        ],
    }
    assert orders.model_dump(exclude_none=True) == expected_orders
|
||||
|
||||
|
||||
def test_introspect_database_response_rejects_non_postgres_driver() -> None:
    """A request naming a driver other than postgres raises ``ValueError``."""

    def load_no_rows(
        request: DatabaseIntrospectionRequest,
    ) -> DatabaseIntrospectionRows:
        # Stub loader returning three empty row lists.
        return DatabaseIntrospectionRows([], [], [])

    # The request is built inside the block: the test does not care which
    # layer raises, only that the error surfaces with this message.
    with pytest.raises(ValueError, match='supports only driver "postgres"'):
        introspect_database_response(
            DatabaseIntrospectionRequest(
                connection_id="warehouse",
                driver="snowflake",
                url="snowflake://example",
            ),
            load_rows=load_no_rows,
        )
|
||||
|
||||
|
||||
def test_database_introspection_request_rejects_empty_schema_list() -> None:
    """Constructing a request with an empty ``schemas`` list raises."""
    request_fields = {
        "connection_id": "warehouse",
        "driver": "postgres",
        "url": "postgresql://readonly@example.test/warehouse",
        "schemas": [],
    }

    with pytest.raises(ValueError, match="at least one schema"):
        DatabaseIntrospectionRequest(**request_fields)
|
||||
|
||||
|
||||
def test_statement_timeout_config_uses_parameterized_set_config() -> None:
    """The timeout is passed as a bound parameter, not inlined into the SQL."""
    expected = (
        "SELECT set_config('statement_timeout', %s, true)",
        ("30000ms",),
    )

    config = _statement_timeout_config(30_000)

    assert config == expected
|
||||
107
python/klo-daemon/tests/test_embeddings.py
Normal file
107
python/klo-daemon/tests/test_embeddings.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from klo_daemon.embeddings import (
|
||||
ComputeEmbeddingBulkRequest,
|
||||
ComputeEmbeddingRequest,
|
||||
SentenceTransformersEmbeddingProvider,
|
||||
compute_embedding_bulk_response,
|
||||
compute_embedding_response,
|
||||
)
|
||||
|
||||
|
||||
class FakeEmbeddingProvider:
    """Deterministic embedding provider stub that records every batch."""

    name = "fake"
    dimensions = 3
    max_batch_size = 2

    def __init__(self) -> None:
        # One entry per encode() call, each a copy of the texts received.
        self.calls: list[list[str]] = []

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Return ``[len(text), position, 1.0]`` for each text, in order."""
        self.calls.append(list(texts))
        vectors: list[list[float]] = []
        for position, item in enumerate(texts):
            vectors.append([float(len(item)), float(position), 1.0])
        return vectors
|
||||
|
||||
|
||||
class ArrayLike:
    """Minimal wrapper exposing ``tolist()`` around a plain Python list."""

    def __init__(self, value: list[float] | list[list[float]]) -> None:
        # Stored as given; no copy is made.
        self.value = value

    def tolist(self) -> list[float] | list[list[float]]:
        """Return the wrapped list object itself."""
        wrapped = self.value
        return wrapped
|
||||
|
||||
|
||||
class FakeSentenceTransformerModel:
    """Model stub whose ``encode`` accepts a single string or a list."""

    def __init__(self) -> None:
        # Raw arguments of every encode() call, in call order.
        self.calls: list[str | list[str]] = []

    def encode(self, value: str | list[str]) -> ArrayLike:
        """Record the argument and return a canned ``ArrayLike`` result.

        A string yields one fixed vector; a list yields one
        ``[position, len(text), 0.5]`` row per element.
        """
        self.calls.append(value)

        if not isinstance(value, str):
            rows = []
            for position, text in enumerate(value):
                rows.append([float(position), float(len(text)), 0.5])
            return ArrayLike(rows)

        return ArrayLike([0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
def test_compute_embedding_response_uses_injected_provider() -> None:
    """A single-text request is encoded by the provider that was passed in."""
    stub = FakeEmbeddingProvider()
    request = ComputeEmbeddingRequest(text="hello")

    response = compute_embedding_response(request, provider=stub)

    # The fake encodes "hello" as [len, index, 1.0].
    assert response.embedding == [5.0, 0.0, 1.0]
    # The provider was called once, with the text wrapped in a list.
    assert stub.calls == [["hello"]]
|
||||
|
||||
|
||||
def test_compute_embedding_bulk_response_uses_injected_provider() -> None:
    """A bulk request is encoded in one provider call, preserving order."""
    stub = FakeEmbeddingProvider()
    request = ComputeEmbeddingBulkRequest(texts=["one", "three"])

    response = compute_embedding_bulk_response(request, provider=stub)

    expected_vectors = [[3.0, 0.0, 1.0], [5.0, 1.0, 1.0]]
    assert response.embeddings == expected_vectors
    assert stub.calls == [["one", "three"]]
|
||||
|
||||
|
||||
def test_compute_embedding_bulk_rejects_empty_texts() -> None:
    """A whitespace-only text is rejected before the provider is called."""
    stub = FakeEmbeddingProvider()
    texts = ["valid", " "]

    # The request is constructed inside the block so that an error raised
    # at either construction or computation time is accepted.
    with pytest.raises(ValueError, match="Empty texts found at indices: 1"):
        compute_embedding_bulk_response(
            ComputeEmbeddingBulkRequest(texts=texts),
            provider=stub,
        )

    assert stub.calls == []
|
||||
|
||||
|
||||
def test_compute_embedding_bulk_respects_provider_batch_size() -> None:
    """More texts than ``max_batch_size`` are rejected without encoding."""
    stub = FakeEmbeddingProvider()
    # Three texts against the fake's max_batch_size of 2.
    texts = ["one", "two", "three"]

    with pytest.raises(ValueError, match="Maximum 2 texts allowed per batch"):
        compute_embedding_bulk_response(
            ComputeEmbeddingBulkRequest(texts=texts),
            provider=stub,
        )

    assert stub.calls == []
|
||||
|
||||
|
||||
def test_sentence_transformers_provider_normalizes_single_and_bulk_outputs() -> None:
    """Single and bulk model outputs both come back as a list of vectors."""
    fake_model = FakeSentenceTransformerModel()
    provider = SentenceTransformersEmbeddingProvider(model=fake_model)

    single = provider.encode(["hello"])
    assert single == [[0.1, 0.2, 0.3]]

    bulk = provider.encode(["one", "three"])
    assert bulk == [
        [0.0, 3.0, 0.5],
        [1.0, 5.0, 0.5],
    ]

    # A one-element batch is expected to reach the model as a bare string,
    # a longer batch as the list itself.
    assert fake_model.calls == ["hello", ["one", "three"]]
|
||||
134
python/klo-daemon/tests/test_lookml.py
Normal file
134
python/klo-daemon/tests/test_lookml.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from klo_daemon.lookml import (
|
||||
LookMLFileInput,
|
||||
ParseLookMLRequest,
|
||||
parse_lookml_project,
|
||||
)
|
||||
|
||||
|
||||
# LookML fixture: an `orders` view over public.orders with three dimensions
# (id as primary key, user_id, status) and two measures (order_count, revenue).
ORDER_VIEW = """
view: orders {
sql_table_name: public.orders ;;

dimension: id {
primary_key: yes
type: number
sql: ${TABLE}.id ;;
}

dimension: user_id {
type: number
sql: ${TABLE}.user_id ;;
}

dimension: status {
type: string
sql: ${TABLE}.status ;;
}

measure: order_count {
type: count
}

measure: revenue {
type: sum
sql: ${TABLE}.amount ;;
}
}
"""
|
||||
|
||||
|
||||
# LookML fixture: a `users` view over public.users with a single primary-key
# dimension `id`.
USER_VIEW = """
view: users {
sql_table_name: public.users ;;

dimension: id {
primary_key: yes
type: number
sql: ${TABLE}.id ;;
}
}
"""
|
||||
|
||||
|
||||
# LookML fixture: an `orders` explore that joins `users` many-to-one on
# orders.user_id = users.id.
ORDER_MODEL = """
explore: orders {
join: users {
relationship: many_to_one
sql_on: ${orders.user_id} = ${users.id} ;;
}
}
"""
|
||||
|
||||
|
||||
# LookML fixture: an `order_rollup` view backed by a derived-table SQL query
# (status plus summed amount) that declares only the `status` dimension.
DERIVED_VIEW = """
view: order_rollup {
derived_table: {
sql:
SELECT status, SUM(amount) AS total_amount
FROM public.orders
GROUP BY status ;;
}

dimension: status {
type: string
sql: ${TABLE}.status ;;
}
}
"""
|
||||
|
||||
|
||||
def test_parse_lookml_project_returns_views_and_joins() -> None:
    """Two table views plus a model file yield both views and one join."""
    project_files = [
        LookMLFileInput(path=path, content=content)
        for path, content in (
            ("views/orders.view.lkml", ORDER_VIEW),
            ("views/users.view.lkml", USER_VIEW),
            ("models/ecommerce.model.lkml", ORDER_MODEL),
        )
    ]

    response = parse_lookml_project(
        ParseLookMLRequest(files=project_files, dialect="postgres")
    )

    views = {view.name: view for view in response.views}
    assert sorted(views) == ["orders", "users"]

    # The orders view maps to a physical table keyed by its primary key.
    orders = views["orders"]
    assert orders.source_type == "table"
    assert orders.table_ref == "public.orders"
    assert orders.grain == ["id"]
    assert [measure.name for measure in orders.measures] == [
        "order_count",
        "revenue",
    ]
    assert orders.measures[0].expr == "count(*)"
    assert orders.measures[1].expr == "sum(amount)"

    # The explore's join is reported with its ${...} references resolved.
    join = response.joins[0]
    assert join.source_view == "orders"
    assert join.to == "users"
    assert join.relationship == "many_to_one"
    assert join.on == "orders.user_id = users.id"

    assert response.skipped_views == []
    assert response.warnings == []
|
||||
|
||||
|
||||
def test_parse_lookml_project_extracts_derived_table_columns() -> None:
    """A derived-table view is reported as SQL-backed with its columns."""
    rollup_file = LookMLFileInput(
        path="views/order_rollup.view.lkml", content=DERIVED_VIEW
    )
    request = ParseLookMLRequest(files=[rollup_file], dialect="postgres")

    response = parse_lookml_project(request)

    assert len(response.views) == 1
    rollup = response.views[0]
    assert rollup.name == "order_rollup"
    assert rollup.source_type == "sql"
    # `sql` may be None, so fall back to "" before the substring check.
    assert "SELECT status, SUM(amount) AS total_amount" in (rollup.sql or "")
    # Both selected columns are expected, although the view itself declares
    # only the `status` dimension.
    assert [column.name for column in rollup.columns] == ["status", "total_amount"]
    assert response.skipped_views == []
    assert response.warnings == []
|
||||
6
python/klo-daemon/tests/test_package.py
Normal file
6
python/klo-daemon/tests/test_package.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from klo_daemon import PACKAGE_NAME, VERSION
|
||||
|
||||
|
||||
def test_package_metadata() -> None:
    """The package exposes the expected name and version constants."""
    metadata = (PACKAGE_NAME, VERSION)
    assert metadata == ("klo-daemon", "0.1.0")
|
||||
64
python/klo-daemon/tests/test_semantic_layer.py
Normal file
64
python/klo-daemon/tests/test_semantic_layer.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from klo_daemon.semantic_layer import (
|
||||
SemanticLayerQueryRequest,
|
||||
ValidateSourcesRequest,
|
||||
query_semantic_layer,
|
||||
validate_semantic_layer,
|
||||
)
|
||||
|
||||
|
||||
# Semantic-layer source fixture: the `orders` table at grain `id`, with three
# typed columns, no joins, and two measures.
ORDERS_SOURCE = {
    "name": "orders",
    "table": "public.orders",
    "grain": ["id"],
    "columns": [
        {"name": column_name, "type": column_type}
        for column_name, column_type in (
            ("id", "number"),
            ("status", "string"),
            ("amount", "number"),
        )
    ],
    "joins": [],
    "measures": [
        {"name": measure_name, "expr": measure_expr}
        for measure_name, measure_expr in (
            ("order_count", "count(*)"),
            ("revenue", "sum(amount)"),
        )
    ],
}
|
||||
|
||||
|
||||
def test_query_semantic_layer_generates_sql_and_plan() -> None:
    """One measure grouped by one dimension yields SQL, columns and a plan."""
    query = {
        "measures": ["orders.order_count"],
        "dimensions": ["orders.status"],
        "limit": 25,
    }
    request = SemanticLayerQueryRequest(
        sources=[ORDERS_SOURCE],
        dialect="postgres",
        query=query,
    )

    response = query_semantic_layer(request)

    assert response.dialect == "postgres"

    # The SQL must mention both the physical table and the dimension.
    generated_sql = response.sql
    assert "public.orders" in generated_sql
    assert "orders.status" in generated_sql

    # The dimension column is expected ahead of the measure column.
    assert response.columns[0]["name"] == "orders.status"
    assert response.columns[1]["name"] == "orders.order_count"
    assert response.plan["sources_used"] == ["orders"]
|
||||
|
||||
|
||||
def test_validate_semantic_layer_reports_duplicate_measure_names() -> None:
    """Two measures sharing a name make the source invalid."""
    # Copy the fixture and replace its measures with two identical entries.
    duplicated_measures = [{"name": "revenue", "expr": "sum(amount)"} for _ in range(2)]
    invalid_source = dict(ORDERS_SOURCE, measures=duplicated_measures)
    request = ValidateSourcesRequest(sources=[invalid_source], dialect="postgres")

    response = validate_semantic_layer(request)

    assert response.valid is False
    assert any("Duplicate measure" in error for error in response.errors)
    assert response.warnings == []
|
||||
161
python/klo-daemon/tests/test_source_generation.py
Normal file
161
python/klo-daemon/tests/test_source_generation.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from klo_daemon.source_generation import (
|
||||
ColumnInput,
|
||||
GenerateSourcesRequest,
|
||||
LinkInput,
|
||||
TableInput,
|
||||
generate_sources,
|
||||
generate_sources_response,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_sources_maps_tables_columns_measures_and_joins() -> None:
    """Two linked tables are turned into two sources with mirrored joins.

    The assertions pin, for ``orders``: description, qualified table name,
    grain, column type/role mapping, the forward join, and the generated
    measures; and for ``customers``: the reverse join.
    """
    orders_table = TableInput(
        name="orders",
        db="public",
        comment="Orders table",
        columns=[
            ColumnInput(
                name="id",
                type="integer",
                primary_key=True,
                nullable=False,
                comment="Order ID",
            ),
            ColumnInput(name="customer_id", type="integer"),
            ColumnInput(name="amount", type="decimal", comment="Order amount"),
            ColumnInput(name="created_at", type="timestamp"),
            ColumnInput(name="status", type="varchar"),
        ],
    )
    customers_table = TableInput(
        name="customers",
        db="public",
        columns=[
            ColumnInput(name="id", type="integer", primary_key=True),
            ColumnInput(name="email", type="varchar"),
        ],
    )
    # The relationship type is given in upper case; the expected joins below
    # use the lower-case form.
    customer_link = LinkInput(
        from_table="orders",
        from_column="customer_id",
        to_table="customers",
        to_column="id",
        relationship_type="MANY_TO_ONE",
    )

    response = generate_sources_response(
        GenerateSourcesRequest(
            tables=[orders_table, customers_table],
            links=[customer_link],
        )
    )

    assert response.source_count == 2
    sources = {source["name"]: source for source in response.sources}

    orders = sources["orders"]
    assert orders["description"] == "Orders table"
    assert orders["table"] == "public.orders"
    assert orders["grain"] == ["id"]
    # Columns without a comment carry no "description" key.
    assert orders["columns"] == [
        {
            "name": "id",
            "type": "number",
            "visibility": "public",
            "role": "default",
            "description": "Order ID",
        },
        {
            "name": "customer_id",
            "type": "number",
            "visibility": "public",
            "role": "default",
        },
        {
            "name": "amount",
            "type": "number",
            "visibility": "public",
            "role": "default",
            "description": "Order amount",
        },
        {"name": "created_at", "type": "time", "visibility": "public", "role": "time"},
        {"name": "status", "type": "string", "visibility": "public", "role": "default"},
    ]
    assert orders["joins"] == [
        {
            "to": "customers",
            "on": "customer_id = customers.id",
            "relationship": "many_to_one",
        }
    ]

    order_measures = orders["measures"]
    assert [measure["name"] for measure in order_measures] == [
        "record_count",
        "total_amount",
        "avg_amount",
    ]
    assert order_measures[0]["expr"] == "count(id)"
    assert order_measures[1]["expr"] == "sum(amount)"
    assert order_measures[2]["expr"] == "avg(amount)"

    # The reverse side of the link appears on customers, with the
    # relationship flipped.
    assert sources["customers"]["joins"] == [
        {
            "to": "orders",
            "on": "id = orders.customer_id",
            "relationship": "one_to_many",
        }
    ]
|
||||
|
||||
|
||||
def test_generate_sources_aliases_multiple_joins_to_same_table() -> None:
    """Two links from orders to users produce two distinctly aliased joins."""
    orders_table = TableInput(
        name="orders",
        columns=[
            ColumnInput(name="id", type="integer", primary_key=True),
            ColumnInput(name="buyer_id", type="integer"),
            ColumnInput(name="seller_id", type="integer"),
        ],
    )
    users_table = TableInput(
        name="users",
        columns=[ColumnInput(name="id", type="integer", primary_key=True)],
    )
    # Both foreign keys point at users.id.
    user_links = [
        LinkInput(
            from_table="orders",
            from_column=fk_column,
            to_table="users",
            to_column="id",
            relationship_type="many_to_one",
        )
        for fk_column in ("buyer_id", "seller_id")
    ]

    sources = generate_sources(
        GenerateSourcesRequest(
            tables=[orders_table, users_table],
            links=user_links,
        )
    )

    orders = next(entry for entry in sources if entry["name"] == "orders")
    # Each join is expected to carry an alias of target table + FK column.
    assert orders["joins"] == [
        {
            "to": "users",
            "on": "buyer_id = users.id",
            "relationship": "many_to_one",
            "alias": "users_buyer_id",
        },
        {
            "to": "users",
            "on": "seller_id = users.id",
            "relationship": "many_to_one",
            "alias": "users_seller_id",
        },
    ]
|
||||
Loading…
Add table
Add a link
Reference in a new issue