diff --git a/sqlite-vec.c b/sqlite-vec.c index 81d3c29..f2e02f0 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -1814,6 +1814,9 @@ enum Vec0TokenType { TOKEN_TYPE_RBRACKET, TOKEN_TYPE_PLUS, TOKEN_TYPE_EQ, + TOKEN_TYPE_LPAREN, + TOKEN_TYPE_RPAREN, + TOKEN_TYPE_COMMA, }; struct Vec0Token { enum Vec0TokenType token_type; @@ -1864,6 +1867,24 @@ int vec0_token_next(char *start, char *end, struct Vec0Token *out) { out->end = ptr; out->token_type = TOKEN_TYPE_EQ; return VEC0_TOKEN_RESULT_SOME; + } else if (curr == '(') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_LPAREN; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == ')') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_RPAREN; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == ',') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_COMMA; + return VEC0_TOKEN_RESULT_SOME; } else if (is_alpha(curr)) { char *start = ptr; while (ptr < end && (is_alpha(*ptr) || is_digit(*ptr) || *ptr == '_')) { diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index 1e62da2..a540849 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -23,6 +23,9 @@ enum Vec0TokenType { TOKEN_TYPE_RBRACKET = 3, TOKEN_TYPE_PLUS = 4, TOKEN_TYPE_EQ = 5, + TOKEN_TYPE_LPAREN = 6, + TOKEN_TYPE_RPAREN = 7, + TOKEN_TYPE_COMMA = 8, }; #define VEC0_TOKEN_RESULT_EOF 1 diff --git a/tests/test-unit.c b/tests/test-unit.c index 2373e1c..269a990 100644 --- a/tests/test-unit.c +++ b/tests/test-unit.c @@ -108,6 +108,24 @@ void test_vec0_token_next() { assert(token.token_type == TOKEN_TYPE_DIGIT); assert(token.end - token.start == 2); + // Left paren + input = "("; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_LPAREN); + + // Right paren + input = ")"; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_RPAREN); + + // Comma + input = ","; + rc = vec0_token_next(input, input + 1, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_COMMA); + printf(" All vec0_token_next tests passed.\n"); } @@ -229,6 +247,60 @@ void test_vec0_scanner() { assert(rc == VEC0_TOKEN_RESULT_EOF); } + // Scan "diskann(k=v, k2=v2)" + { + const char *input = "diskann(k=v, k2=v2)"; + vec0_scanner_init(&scanner, input, (int)strlen(input)); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "diskann", 7) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_LPAREN); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "k", 1) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_EQ); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "v", 1) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_COMMA); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "k2", 2) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_EQ); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_IDENTIFIER); + assert(strncmp(token.start, "v2", 2) == 0); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_SOME); + assert(token.token_type == TOKEN_TYPE_RPAREN); + + rc = vec0_scanner_next(&scanner, &token); + assert(rc == VEC0_TOKEN_RESULT_EOF); + } + printf(" All vec0_scanner tests passed.\n"); }