rm sd, fix qdrant

This commit is contained in:
yzlin 2024-01-05 16:43:41 +08:00
parent a3dc6aa7e3
commit bd4a35fd94
3 changed files with 58 additions and 49 deletions

View file

@ -29,6 +29,16 @@ points = [
]
def assert_almost_equal(actual, expected):
delta = 1e-10
if isinstance(expected, list):
assert len(actual) == len(expected)
for ac, exp in zip(actual, expected):
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
else:
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
@ -42,30 +52,30 @@ def test_qdrant_store():
qdrant_store.add("Book", points)
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert_almost_equal(results[0]["score"], 0.999106722578389)
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert_almost_equal(results[1]["score"], 0.9961650411397226)
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert_almost_equal(results[0]["score"], 0.999106722578389)
assert_almost_equal(results[0]["vector"], [0.7363563179969788, 0.6765939593315125])
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
assert_almost_equal(results[1]["score"], 0.9961650411397226)
assert_almost_equal(results[1]["vector"], [0.7662628889083862, 0.6425272226333618])
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
)
assert results[0]["id"] == 8
assert results[0]["score"] == 0.9100373450784073
assert_almost_equal(results[0]["score"], 0.9100373450784073)
assert results[1]["id"] == 9
assert results[1]["score"] == 0.7127610621127889
assert_almost_equal(results[1]["score"], 0.7127610621127889)
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
return_vector=True,
)
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]
assert_almost_equal(results[0]["vector"], [0.35037919878959656, 0.9366079568862915])
assert_almost_equal(results[1]["vector"], [0.9999677538871765, 0.00802854634821415])

View file

@ -1,26 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/22 02:40
# @Author : stellahong (stellahong@deepwisdom.ai)
#
import os
from metagpt.config import CONFIG
from metagpt.tools.sd_engine import SDEngine
def test_sd_engine_init():
sd_engine = SDEngine()
assert sd_engine.payload["seed"] == -1
def test_sd_engine_generate_prompt():
sd_engine = SDEngine()
sd_engine.construct_payload(prompt="test")
assert sd_engine.payload["prompt"] == "test"
async def test_sd_engine_run_t2i():
sd_engine = SDEngine()
await sd_engine.run_t2i(prompts=["test"])
img_path = CONFIG.workspace_path / "resources" / "SD_Output" / "output_0.png"
assert os.path.exists(img_path)