From 4a48955a5d005b4731cbd4155a863e11c29893c7 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 10 Aug 2023 13:38:39 -0700 Subject: [PATCH] finished unit tests, changed to have dynamic types --- metagpt/document_store/lancedb_store.py | 27 ++++++++----------- .../document_store/test_lancedb_store.py | 13 +++++---- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py index cefb70ea3..b366fa650 100644 --- a/metagpt/document_store/lancedb_store.py +++ b/metagpt/document_store/lancedb_store.py @@ -6,8 +6,6 @@ @File : lancedb_store.py """ import lancedb -import pyarrow as pa -import pandas as pd import shutil, os @@ -18,15 +16,6 @@ class LanceStore: self.name = name self.table = None - def create_table(self, columns: list): - # Create table given the columns as a list. - df = pd.DataFrame(columns=columns) - schema = pa.Schema.from_pandas(df) - schema = schema.remove_metadata() - schema = schema.remove(len(schema) - 1) - - self.table = self.db.create_table(self.name, schema=schema) - def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): # This assumes query is a vector embedding # kwargs can be used for optional filtering @@ -34,6 +23,8 @@ class LanceStore: # .where - SQL syntax filtering for metadata (e.g. where("price > 100")) # .metric - specifies the distance metric to use # .nprobes - values will yield better recall (more likely to find vectors if they exist) at the expense of latency. + if self.table == None: raise Exception("Table not created yet, please add data first.") + results = self.table \ .search(query) \ .limit(n_results) \ @@ -51,7 +42,6 @@ class LanceStore: # This function is similar to add(), but it's for more generalized updates # "data" is the list of embeddings # Inserts into table by expanding metadatas into a dataframe: [{'vector', 'id', 'meta', 'meta2'}, ...] - if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") documents = [] for i in range(len(data)): @@ -62,12 +52,14 @@ class LanceStore: row.update(metadatas[i]) documents.append(row) - return self.table.add(documents) + if self.table != None: + self.table.add(documents) + else: + self.table = self.db.create_table(self.name, documents) def add(self, data, metadata, _id): # This function is for adding individual documents # It assumes you're passing in a single vector embedding, metadata, and id - if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") row = { 'vector': data, @@ -75,12 +67,15 @@ class LanceStore: } row.update(metadata) - return self.table.add([row]) + if self.table != None: + self.table.add([row]) + else: + self.table = self.db.create_table(self.name, [row]) def delete(self, _id): # This function deletes a row by id. # LanceDB delete syntax uses SQL syntax, so you can use "in" or "=" - if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") + if self.table == None: raise Exception("Table not created yet, please add data first") if isinstance(_id, str): return self.table.delete(f"id = '{_id}'") diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py index 1d83ef30b..9c2f9fb42 100644 --- a/tests/metagpt/document_store/test_lancedb_store.py +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -6,8 +6,10 @@ @File : test_lancedb_store.py """ from metagpt.document_store.lancedb_store import LanceStore +import pytest import random +@pytest def test_lance_store(): # This simply establishes the connection to the database, so we can drop the table if it exists @@ -15,14 +17,15 @@ def test_lance_store(): store.drop('test') - store.create_table(['vector', 'id', 'meta', 'meta2']) - store.write(data=[[random.random() for _ in range(100)] for _ in range(2)], metadatas=[{"source": "google-docs"}, {"source": "notion"}], ids=["doc1", "doc2"]) - store.add(data=[random.random() for _ in range(100)], metadatas={"source": "notion"}, ids="doc3") + store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3") result = store.search([random.random() for _ in range(100)], n_results=3) - print(result) - assert(len(result) > 0) + assert(len(result) == 3) + + store.delete("doc2") + result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine') + assert(len(result) == 1) \ No newline at end of file