From 2e28ea6927a3a692e7265decde1a4543194e2ef7 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 10 Aug 2023 12:00:28 -0700 Subject: [PATCH] test --- metagpt/document_store/lancedb_store.py | 10 ++++++- .../document_store/test_lancedb_store.py | 28 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 tests/metagpt/document_store/test_lancedb_store.py diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py index 81ca47f86..cefb70ea3 100644 --- a/metagpt/document_store/lancedb_store.py +++ b/metagpt/document_store/lancedb_store.py @@ -8,6 +8,7 @@ import lancedb import pyarrow as pa import pandas as pd +import shutil, os class LanceStore: @@ -24,7 +25,7 @@ class LanceStore: schema = schema.remove_metadata() schema = schema.remove(len(schema) - 1) - self.table = self.db.create_table(self.name, schema) + self.table = self.db.create_table(self.name, schema=schema) def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): # This assumes query is a vector embedding @@ -85,3 +86,10 @@ class LanceStore: return self.table.delete(f"id = '{_id}'") else: return self.table.delete(f"id = {_id}") + + def drop(self, name): + # This function drops a table, if it exists. + + path = os.path.join(self.db.uri, name + '.lance') + if os.path.exists(path): + shutil.rmtree(path) \ No newline at end of file diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py new file mode 100644 index 000000000..1d83ef30b --- /dev/null +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/9 15:42 +@Author : unkn-wn (Leon Yee) +@File : test_lancedb_store.py +""" +from metagpt.document_store.lancedb_store import LanceStore +import random + +def test_lance_store(): + + # This simply establishes the connection to the database, so we can drop the table if it exists + store = LanceStore('test') + + store.drop('test') + + store.create_table(['vector', 'id', 'meta', 'meta2']) + + store.write(data=[[random.random() for _ in range(100)] for _ in range(2)], + metadatas=[{"source": "google-docs"}, {"source": "notion"}], + ids=["doc1", "doc2"]) + + store.add(data=[random.random() for _ in range(100)], metadatas={"source": "notion"}, ids="doc3") + + result = store.search([random.random() for _ in range(100)], n_results=3) + print(result) + assert(len(result) > 0)