diff --git a/sqlite-vec.c b/sqlite-vec.c index cf8e267..4d91c9e 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -557,6 +557,103 @@ static int int8_vec_from_value(sqlite3_value *value, i8 **vector, *cleanup = vector_cleanup_noop; return SQLITE_OK; } + + if (value_type == SQLITE_TEXT) { + const char *source = (const char *)sqlite3_value_text(value); + int source_len = sqlite3_value_bytes(value); + int i = 0; + + struct Array x; + int rc = array_init(&x, sizeof(i8), ceil(source_len / 2.0)); + if (rc != SQLITE_OK) { + return rc; + } + + // advance leading whitespace to first '[' + while (i < source_len) { + if (jsonIsspace(source[i])) { + i++; + continue; + } + if (source[i] == '[') { + break; + } + + *pzErr = sqlite3_mprintf( + "JSON array parsing error: Input does not start with '['"); + array_cleanup(&x); + return SQLITE_ERROR; + } + if (source[i] != '[') { + *pzErr = sqlite3_mprintf( + "JSON array parsing error: Input does not start with '['"); + array_cleanup(&x); + return SQLITE_ERROR; + } + int offset = i + 1; + + while (offset < source_len) { + char *ptr = (char *)&source[offset]; + char *endptr; + + errno = 0; + long result = strtol(ptr, &endptr, 10); + if ((errno != 0 && result == 0) + || (errno == ERANGE && + (result == LONG_MAX || result == LONG_MIN)) + ) { + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("JSON parsing error"); + return SQLITE_ERROR; + } + + if (endptr == ptr) { + if (*ptr != ']') { + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("JSON parsing error"); + return SQLITE_ERROR; + } + goto done; + } + + if (result < INT8_MIN || result > INT8_MAX) { + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("JSON parsing error: value out of range for int8"); + return SQLITE_ERROR; + } + + i8 res = (i8)result; + array_append(&x, (const void *)&res); + + offset += (endptr - ptr); + while (offset < source_len) { + if (jsonIsspace(source[offset])) { + offset++; + continue; + } + if (source[offset] == ',') { + offset++; + continue; + } + if (source[offset] == ']') + goto done; + break; + } + } + + done: + + if (x.length > 0) { + *vector = (i8 *)x.z; + *dimensions = x.length; + *cleanup = (vector_cleanup)sqlite3_free; + return SQLITE_OK; + } + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("zero-length vectors are not supported."); + return SQLITE_ERROR; + } + *pzErr = sqlite3_mprintf("Unknown type for int8 vector."); return SQLITE_ERROR; } diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 2acd75a..52f4c3d 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -211,6 +211,8 @@ def test_vec_int8(): vec_int8 = lambda *args: db.execute("select vec_int8(?)", args).fetchone()[0] assert vec_int8(b"\x00") == _int8([0]) assert vec_int8(b"\x00\x0f") == _int8([0, 15]) + assert vec_int8("[0]") == _int8([0]) + assert vec_int8("[1, 2, 3]") == _int8([1, 2, 3]) if SUPPORTS_SUBTYPE: assert db.execute("select subtype(vec_int8(?))", [b"\x00"]).fetchone()[0] == 225