Merge pull request #207 from hezhaozhao-git/qdrant

add qdrant store
This commit is contained in:
stellaHSR 2023-08-14 22:56:18 +08:00 committed by GitHub
commit 53a96ff119
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 209 additions and 2 deletions

View file

@ -15,7 +15,7 @@ class BaseStore(ABC):
"""FIXME: consider add_index, set_index and think 颗粒度"""
@abstractmethod
def search(self, query, *args, **kwargs):
def search(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod

View file

@ -0,0 +1,129 @@
from dataclasses import dataclass
from typing import List
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, PointStruct, VectorParams
from metagpt.document_store.base_store import BaseStore
@dataclass
class QdrantConnection:
"""
Args:
url: qdrant url
host: qdrant host
port: qdrant port
memory: qdrant service use memory mode
api_key: qdrant cloud api_key
"""
url: str = None
host: str = None
port: int = None
memory: bool = False
api_key: str = None
class QdrantStore(BaseStore):
def __init__(self, connect: QdrantConnection):
if connect.memory:
self.client = QdrantClient(":memory:")
elif connect.url:
self.client = QdrantClient(url=connect.url, api_key=connect.api_key)
elif connect.host and connect.port:
self.client = QdrantClient(
host=connect.host, port=connect.port, api_key=connect.api_key
)
else:
raise Exception("please check QdrantConnection.")
def create_collection(
self,
collection_name: str,
vectors_config: VectorParams,
force_recreate=False,
**kwargs,
):
"""
create a collection
Args:
collection_name: collection name
vectors_config: VectorParams object,detail in https://github.com/qdrant/qdrant-client
force_recreate: default is False, if True, will delete exists collection,then create it
**kwargs:
Returns:
"""
try:
self.client.get_collection(collection_name)
if force_recreate:
res = self.client.recreate_collection(
collection_name, vectors_config=vectors_config, **kwargs
)
return res
return True
except: # noqa: E722
return self.client.recreate_collection(
collection_name, vectors_config=vectors_config, **kwargs
)
def has_collection(self, collection_name: str):
try:
self.client.get_collection(collection_name)
return True
except: # noqa: E722
return False
def delete_collection(self, collection_name: str, timeout=60):
res = self.client.delete_collection(collection_name, timeout=timeout)
if not res:
raise Exception(f"Delete collection {collection_name} failed.")
def add(self, collection_name: str, points: List[PointStruct]):
"""
add some vector data to qdrant
Args:
collection_name: collection name
points: list of PointStruct object, about PointStruct detail in https://github.com/qdrant/qdrant-client
Returns: NoneX
"""
# self.client.upload_records()
self.client.upsert(
collection_name,
points,
)
def search(
self,
collection_name: str,
query: List[float],
query_filter: Filter = None,
k=10,
return_vector=False,
):
"""
vector search
Args:
collection_name: qdrant collection name
query: input vector
query_filter: Filter object, detail in https://github.com/qdrant/qdrant-client
k: return the most similar k pieces of data
return_vector: whether return vector
Returns: list of dict
"""
hits = self.client.search(
collection_name=collection_name,
query_vector=query,
query_filter=query_filter,
limit=k,
with_vectors=return_vector,
)
return [hit.__dict__ for hit in hits]
def write(self, *args, **kwargs):
pass

View file

@ -17,7 +17,7 @@ numpy==1.24.3
openai==0.27.8
openpyxl
pandas==1.4.1
pydantic==1.10.7
pydantic==1.10.8
#pygame==2.1.3
#pymilvus==2.2.8
pytest==7.2.2
@ -36,3 +36,4 @@ anthropic==0.3.6
typing-inspect==0.8.0
typing_extensions==4.5.0
libcst==1.0.1
qdrant-client==1.4.0

View file

@ -0,0 +1,77 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/6/11 21:08
@Author : hezhaozhao
@File : test_qdrant_store.py
"""
import random
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
PointStruct,
Range,
VectorParams,
)
from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore
seed_value = 42
random.seed(seed_value)
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
points = [
PointStruct(
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
)
for idx, vector in enumerate(vectors)
]
def test_milvus_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
qdrant_store = QdrantStore(qdrant_connection)
qdrant_store.create_collection("Book", vectors_config, force_recreate=True)
assert qdrant_store.has_collection("Book") is True
qdrant_store.delete_collection("Book")
assert qdrant_store.has_collection("Book") is False
qdrant_store.create_collection("Book", vectors_config)
assert qdrant_store.has_collection("Book") is True
qdrant_store.add("Book", points)
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[1]["score"] == 7
assert results[1]["score"] == 0.9961650411397226
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert results[1]["score"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
),
)
assert results[0]["id"] == 8
assert results[0]["score"] == 0.9100373450784073
assert results[1]["id"] == 9
assert results[1]["score"] == 0.7127610621127889
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
),
return_vector=True,
)
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]