finished unit tests, changed to have dynamic types

This commit is contained in:
Leon 2023-08-10 13:38:39 -07:00
parent 46ada5a7f9
commit 4a48955a5d
2 changed files with 19 additions and 21 deletions

View file

@ -6,8 +6,6 @@
@File : lancedb_store.py
"""
import lancedb
import pyarrow as pa
import pandas as pd
import shutil, os
@ -18,15 +16,6 @@ class LanceStore:
self.name = name
self.table = None
def create_table(self, columns: list):
    """Create an empty table whose fields are named after ``columns``.

    The schema is derived from an empty pandas DataFrame, so it carries
    only the column names given here. The created table is stored on
    ``self.table``.

    Args:
        columns: Column names for the new table, in order.
    """
    frame = pd.DataFrame(columns=columns)
    # preserve_index=False keeps the pandas index out of the schema no matter
    # which index type the empty frame has. The earlier approach removed the
    # last field on the assumption that it was the index column; when the
    # index is a RangeIndex (kept only as metadata) that dropped a real column.
    schema = pa.Schema.from_pandas(frame, preserve_index=False)
    # Strip the pandas-specific metadata so only plain field definitions remain.
    schema = schema.remove_metadata()
    self.table = self.db.create_table(self.name, schema=schema)
def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs):
# This assumes query is a vector embedding
# kwargs can be used for optional filtering
@ -34,6 +23,8 @@ class LanceStore:
# .where - SQL syntax filtering for metadata (e.g. where("price > 100"))
# .metric - specifies the distance metric to use
# .nprobes - higher values yield better recall (more likely to find vectors if they exist) at the expense of latency.
if self.table == None: raise Exception("Table not created yet, please add data first.")
results = self.table \
.search(query) \
.limit(n_results) \
@ -51,7 +42,6 @@ class LanceStore:
# This function is similar to add(), but it's for more generalized updates
# "data" is the list of embeddings
# Inserts into table by expanding metadatas into a dataframe: [{'vector', 'id', 'meta', 'meta2'}, ...]
if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first")
documents = []
for i in range(len(data)):
@ -62,12 +52,14 @@ class LanceStore:
row.update(metadatas[i])
documents.append(row)
return self.table.add(documents)
if self.table != None:
self.table.add(documents)
else:
self.table = self.db.create_table(self.name, documents)
def add(self, data, metadata, _id):
# This function is for adding individual documents
# It assumes you're passing in a single vector embedding, metadata, and id
if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first")
row = {
'vector': data,
@ -75,12 +67,15 @@ class LanceStore:
}
row.update(metadata)
return self.table.add([row])
if self.table != None:
self.table.add([row])
else:
self.table = self.db.create_table(self.name, [row])
def delete(self, _id):
# This function deletes a row by id.
# LanceDB delete syntax uses SQL syntax, so you can use "in" or "="
if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first")
if self.table == None: raise Exception("Table not created yet, please add data first")
if isinstance(_id, str):
return self.table.delete(f"id = '{_id}'")

View file

@ -6,8 +6,10 @@
@File : test_lancedb_store.py
"""
from metagpt.document_store.lancedb_store import LanceStore
import pytest
import random
@pytest
def test_lance_store():
# This simply establishes the connection to the database, so we can drop the table if it exists
@ -15,14 +17,15 @@ def test_lance_store():
store.drop('test')
store.create_table(['vector', 'id', 'meta', 'meta2'])
store.write(data=[[random.random() for _ in range(100)] for _ in range(2)],
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
ids=["doc1", "doc2"])
store.add(data=[random.random() for _ in range(100)], metadatas={"source": "notion"}, ids="doc3")
store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3")
result = store.search([random.random() for _ in range(100)], n_results=3)
print(result)
assert(len(result) > 0)
assert(len(result) == 3)
store.delete("doc2")
result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine')
assert(len(result) == 1)