From 3934b86af4ea93a1199367cc1277afa6215f1610 Mon Sep 17 00:00:00 2001 From: Leon Date: Wed, 9 Aug 2023 16:53:01 -0700 Subject: [PATCH 1/3] lancedb base implementation --- metagpt/document_store/lancedb_store.py | 87 +++++++++++++++++++++++++ requirements.txt | 1 + 2 files changed, 88 insertions(+) create mode 100644 metagpt/document_store/lancedb_store.py diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py new file mode 100644 index 000000000..81ca47f86 --- /dev/null +++ b/metagpt/document_store/lancedb_store.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/9 15:42 +@Author : unkn-wn (Leon Yee) +@File : lancedb_store.py +""" +import lancedb +import pyarrow as pa +import pandas as pd + + +class LanceStore: + def __init__(self, name): + db = lancedb.connect('./data/lancedb') + self.db = db + self.name = name + self.table = None + + def create_table(self, columns: list): + # Create table given the columns as a list. + df = pd.DataFrame(columns=columns) + schema = pa.Schema.from_pandas(df) + schema = schema.remove_metadata() + schema = schema.remove(len(schema) - 1) + + self.table = self.db.create_table(self.name, schema) + + def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): + # This assumes query is a vector embedding + # kwargs can be used for optional filtering + # .select - only searches the specified columns + # .where - SQL syntax filtering for metadata (e.g. where("price > 100")) + # .metric - specifies the distance metric to use + # .nprobes - values will yield better recall (more likely to find vectors if they exist) at the expense of latency. + results = self.table \ + .search(query) \ + .limit(n_results) \ + .select(kwargs.get('select')) \ + .where(kwargs.get('where')) \ + .metric(metric) \ + .nprobes(nprobes) \ + .to_df() + return results + + def persist(self): + raise NotImplementedError + + def write(self, data, metadatas, ids): + # This function is similar to add(), but it's for more generalized updates + # "data" is the list of embeddings + # Inserts into table by expanding metadatas into a dataframe: [{'vector', 'id', 'meta', 'meta2'}, ...] + if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") + + documents = [] + for i in range(len(data)): + row = { + 'vector': data[i], + 'id': ids[i] + } + row.update(metadatas[i]) + documents.append(row) + + return self.table.add(documents) + + def add(self, data, metadata, _id): + # This function is for adding individual documents + # It assumes you're passing in a single vector embedding, metadata, and id + if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") + + row = { + 'vector': data, + 'id': _id + } + row.update(metadata) + + return self.table.add([row]) + + def delete(self, _id): + # This function deletes a row by id. + # LanceDB delete syntax uses SQL syntax, so you can use "in" or "=" + if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") + + if isinstance(_id, str): + return self.table.delete(f"id = '{_id}'") + else: + return self.table.delete(f"id = {_id}") diff --git a/requirements.txt b/requirements.txt index 452e2d092..f9068dd7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ faiss_cpu==1.7.4 fire==0.4.0 # godot==0.1.1 # google_api_python_client==2.93.0 +lancedb==0.1.16 langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 From 2e28ea6927a3a692e7265decde1a4543194e2ef7 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 10 Aug 2023 12:00:28 -0700 Subject: [PATCH 2/3] test --- metagpt/document_store/lancedb_store.py | 10 ++++++- .../document_store/test_lancedb_store.py | 28 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 tests/metagpt/document_store/test_lancedb_store.py diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py index 81ca47f86..cefb70ea3 100644 --- a/metagpt/document_store/lancedb_store.py +++ b/metagpt/document_store/lancedb_store.py @@ -8,6 +8,7 @@ import lancedb import pyarrow as pa import pandas as pd +import shutil, os class LanceStore: @@ -24,7 +25,7 @@ class LanceStore: schema = schema.remove_metadata() schema = schema.remove(len(schema) - 1) - self.table = self.db.create_table(self.name, schema) + self.table = self.db.create_table(self.name, schema=schema) def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): # This assumes query is a vector embedding @@ -85,3 +86,10 @@ class LanceStore: return self.table.delete(f"id = '{_id}'") else: return self.table.delete(f"id = {_id}") + + def drop(self, name): + # This function drops a table, if it exists. + + path = os.path.join(self.db.uri, name + '.lance') + if os.path.exists(path): + shutil.rmtree(path) \ No newline at end of file diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py new file mode 100644 index 000000000..1d83ef30b --- /dev/null +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/9 15:42 +@Author : unkn-wn (Leon Yee) +@File : test_lancedb_store.py +""" +from metagpt.document_store.lancedb_store import LanceStore +import random + +def test_lance_store(): + + # This simply establishes the connection to the database, so we can drop the table if it exists + store = LanceStore('test') + + store.drop('test') + + store.create_table(['vector', 'id', 'meta', 'meta2']) + + store.write(data=[[random.random() for _ in range(100)] for _ in range(2)], + metadatas=[{"source": "google-docs"}, {"source": "notion"}], + ids=["doc1", "doc2"]) + + store.add(data=[random.random() for _ in range(100)], metadatas={"source": "notion"}, ids="doc3") + + result = store.search([random.random() for _ in range(100)], n_results=3) + print(result) + assert(len(result) > 0) From 4a48955a5d005b4731cbd4155a863e11c29893c7 Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 10 Aug 2023 13:38:39 -0700 Subject: [PATCH 3/3] finished unit tests, changed to have dynamic types --- metagpt/document_store/lancedb_store.py | 27 ++++++++----------- .../document_store/test_lancedb_store.py | 13 +++++---- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py index cefb70ea3..b366fa650 100644 --- a/metagpt/document_store/lancedb_store.py +++ b/metagpt/document_store/lancedb_store.py @@ -6,8 +6,6 @@ @File : lancedb_store.py """ import lancedb -import pyarrow as pa -import pandas as pd import shutil, os @@ -18,15 +16,6 @@ class LanceStore: self.name = name self.table = None - def create_table(self, columns: list): - # Create table given the columns as a list. - df = pd.DataFrame(columns=columns) - schema = pa.Schema.from_pandas(df) - schema = schema.remove_metadata() - schema = schema.remove(len(schema) - 1) - - self.table = self.db.create_table(self.name, schema=schema) - def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): # This assumes query is a vector embedding # kwargs can be used for optional filtering @@ -34,6 +23,8 @@ class LanceStore: # .where - SQL syntax filtering for metadata (e.g. where("price > 100")) # .metric - specifies the distance metric to use # .nprobes - values will yield better recall (more likely to find vectors if they exist) at the expense of latency. + if self.table == None: raise Exception("Table not created yet, please add data first.") + results = self.table \ .search(query) \ .limit(n_results) \ @@ -51,7 +42,6 @@ class LanceStore: # This function is similar to add(), but it's for more generalized updates # "data" is the list of embeddings # Inserts into table by expanding metadatas into a dataframe: [{'vector', 'id', 'meta', 'meta2'}, ...] - if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") documents = [] for i in range(len(data)): @@ -62,12 +52,14 @@ class LanceStore: row.update(metadatas[i]) documents.append(row) - return self.table.add(documents) + if self.table != None: + self.table.add(documents) + else: + self.table = self.db.create_table(self.name, documents) def add(self, data, metadata, _id): # This function is for adding individual documents # It assumes you're passing in a single vector embedding, metadata, and id - if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") row = { 'vector': data, @@ -75,12 +67,15 @@ class LanceStore: } row.update(metadata) - return self.table.add([row]) + if self.table != None: + self.table.add([row]) + else: + self.table = self.db.create_table(self.name, [row]) def delete(self, _id): # This function deletes a row by id. # LanceDB delete syntax uses SQL syntax, so you can use "in" or "=" - if self.table == None: raise Exception("Table not created yet, please use create_table([columns]) first") + if self.table == None: raise Exception("Table not created yet, please add data first") if isinstance(_id, str): return self.table.delete(f"id = '{_id}'") diff --git a/tests/metagpt/document_store/test_lancedb_store.py b/tests/metagpt/document_store/test_lancedb_store.py index 1d83ef30b..9c2f9fb42 100644 --- a/tests/metagpt/document_store/test_lancedb_store.py +++ b/tests/metagpt/document_store/test_lancedb_store.py @@ -6,8 +6,10 @@ @File : test_lancedb_store.py """ from metagpt.document_store.lancedb_store import LanceStore +import pytest import random +@pytest def test_lance_store(): # This simply establishes the connection to the database, so we can drop the table if it exists @@ -15,14 +17,15 @@ def test_lance_store(): store.drop('test') - store.create_table(['vector', 'id', 'meta', 'meta2']) - store.write(data=[[random.random() for _ in range(100)] for _ in range(2)], metadatas=[{"source": "google-docs"}, {"source": "notion"}], ids=["doc1", "doc2"]) - store.add(data=[random.random() for _ in range(100)], metadatas={"source": "notion"}, ids="doc3") + store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3") result = store.search([random.random() for _ in range(100)], n_results=3) - print(result) - assert(len(result) > 0) + assert(len(result) == 3) + + store.delete("doc2") + result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine') + assert(len(result) == 1) \ No newline at end of file