diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py new file mode 100644 index 000000000..b366fa650 --- /dev/null +++ b/metagpt/document_store/lancedb_store.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/9 15:42 +@Author : unkn-wn (Leon Yee) +@File : lancedb_store.py +""" +import lancedb +import shutil, os + + +class LanceStore: + def __init__(self, name): + db = lancedb.connect('./data/lancedb') + self.db = db + self.name = name + self.table = None + + def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): + # This assumes query is a vector embedding + # kwargs can be used for optional filtering + # .select - only searches the specified columns + # .where - SQL syntax filtering for metadata (e.g. where("price > 100")) + # .metric - specifies the distance metric to use + # .nprobes - values will yield better recall (more likely to find vectors if they exist) at the expense of latency. + if self.table == None: raise Exception("Table not created yet, please add data first.") + + results = self.table \ + .search(query) \ + .limit(n_results) \ + .select(kwargs.get('select')) \ + .where(kwargs.get('where')) \ + .metric(metric) \ + .nprobes(nprobes) \ + .to_df() + return results + + def persist(self): + raise NotImplementedError + + def write(self, data, metadatas, ids): + # This function is similar to add(), but it's for more generalized updates + # "data" is the list of embeddings + # Inserts into table by expanding metadatas into a dataframe: [{'vector', 'id', 'meta', 'meta2'}, ...] + + documents = [] + for i in range(len(data)): + row = { + 'vector': data[i], + 'id': ids[i] + } + row.update(metadatas[i]) + documents.append(row) + + if self.table != None: + self.table.add(documents) + else: + self.table = self.db.create_table(self.name, documents) + + def add(self, data, metadata, _id): + # This function is for adding individual documents + # It assumes you're passing in a single vector embedding, metadata, and id + + row = { + 'vector': data, + 'id': _id + } + row.update(metadata) + + if self.table != None: + self.table.add([row]) + else: + self.table = self.db.create_table(self.name, [row]) + + def delete(self, _id): + # This function deletes a row by id. + # LanceDB delete syntax uses SQL syntax, so you can use "in" or "=" + if self.table == None: raise Exception("Table not created yet, please add data first") + + if isinstance(_id, str): + return self.table.delete(f"id = '{_id}'") + else: + return self.table.delete(f"id = {_id}") + + def drop(self, name): + # This function drops a table, if it exists. + + path = os.path.join(self.db.uri, name + '.lance') + if os.path.exists(path): + shutil.rmtree(path) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index efc2ea3e7..741ae74df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ faiss_cpu==1.7.4 fire==0.4.0 # godot==0.1.1 # google_api_python_client==2.93.0 +lancedb==0.1.16 langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py new file mode 100644 index 000000000..9c2f9fb42 --- /dev/null +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/9 15:42 +@Author : unkn-wn (Leon Yee) +@File : test_lancedb_store.py +""" +from metagpt.document_store.lancedb_store import LanceStore +import pytest +import random + +@pytest +def test_lance_store(): + + # This simply establishes the connection to the database, so we can drop the table if it exists + store = LanceStore('test') + + store.drop('test') + + store.write(data=[[random.random() for _ in range(100)] for _ in range(2)], + metadatas=[{"source": "google-docs"}, {"source": "notion"}], + ids=["doc1", "doc2"]) + + store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3") + + result = store.search([random.random() for _ in range(100)], n_results=3) + assert(len(result) == 3) + + store.delete("doc2") + result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine') + assert(len(result) == 1) \ No newline at end of file