Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
commit
f6b876fbe7
332 changed files with 97258 additions and 0 deletions
165
tests/test_learn_retrieval_policy.py
Normal file
165
tests/test_learn_retrieval_policy.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""Tests for LEARN-02 retrieval-policy RL + LEARN-05 meta-learning.
|
||||
|
||||
LEARN-02: implicit user feedback (used/corrected/re_asked) updates the D-13
|
||||
score weights (W_COSINE / W_AAAK / W_DEGREE / W_AGE).
|
||||
|
||||
LEARN-05: ε-greedy bandit over strategies picks best strategy per query type.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- feedback shape
|
||||
|
||||
|
||||
def test_retrieval_feedback_dataclass():
|
||||
from iai_mcp.learn import RetrievalFeedback
|
||||
|
||||
fb = RetrievalFeedback(
|
||||
query_type="fact_lookup",
|
||||
hit_ids=[uuid4(), uuid4()],
|
||||
used_ids=[],
|
||||
corrected=False,
|
||||
re_asked=False,
|
||||
)
|
||||
assert fb.query_type == "fact_lookup"
|
||||
assert len(fb.hit_ids) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- update_retrieval_weights
|
||||
|
||||
|
||||
def test_retrieval_feedback_used_boosts_weights():
|
||||
"""Higher use-rate -> W_COSINE goes up."""
|
||||
from iai_mcp.learn import RetrievalFeedback, update_retrieval_weights
|
||||
|
||||
ids = [uuid4() for _ in range(3)]
|
||||
fb = RetrievalFeedback(
|
||||
query_type="lookup",
|
||||
hit_ids=ids,
|
||||
used_ids=ids[:3], # used all hits
|
||||
corrected=False,
|
||||
re_asked=False,
|
||||
)
|
||||
before = {"W_COSINE": 1.0, "W_AAAK": 0.3, "W_DEGREE": 0.1, "W_AGE": 0.05}
|
||||
after = update_retrieval_weights(fb, before)
|
||||
assert after["W_COSINE"] > before["W_COSINE"]
|
||||
|
||||
|
||||
def test_retrieval_feedback_corrected_reduces_weights():
|
||||
from iai_mcp.learn import RetrievalFeedback, update_retrieval_weights
|
||||
|
||||
ids = [uuid4() for _ in range(3)]
|
||||
fb = RetrievalFeedback(
|
||||
query_type="lookup",
|
||||
hit_ids=ids,
|
||||
used_ids=[],
|
||||
corrected=True,
|
||||
re_asked=False,
|
||||
)
|
||||
before = {"W_COSINE": 1.0, "W_AAAK": 0.3, "W_DEGREE": 0.1, "W_AGE": 0.05}
|
||||
after = update_retrieval_weights(fb, before)
|
||||
assert after["W_COSINE"] < before["W_COSINE"]
|
||||
|
||||
|
||||
def test_retrieval_feedback_re_asked_reduces_weights():
|
||||
from iai_mcp.learn import RetrievalFeedback, update_retrieval_weights
|
||||
|
||||
ids = [uuid4() for _ in range(3)]
|
||||
fb = RetrievalFeedback(
|
||||
query_type="lookup",
|
||||
hit_ids=ids,
|
||||
used_ids=[],
|
||||
corrected=False,
|
||||
re_asked=True,
|
||||
)
|
||||
before = {"W_COSINE": 1.0, "W_AAAK": 0.3, "W_DEGREE": 0.1, "W_AGE": 0.05}
|
||||
after = update_retrieval_weights(fb, before)
|
||||
assert after["W_COSINE"] < before["W_COSINE"]
|
||||
|
||||
|
||||
def test_retrieval_weights_bounded():
|
||||
"""After many updates, weights stay in [0, 5]."""
|
||||
from iai_mcp.learn import MAX_WEIGHT, MIN_WEIGHT, RetrievalFeedback, update_retrieval_weights
|
||||
|
||||
ids = [uuid4() for _ in range(3)]
|
||||
# 1000 "used" feedbacks (continually boost)
|
||||
weights = {"W_COSINE": 1.0, "W_AAAK": 0.3, "W_DEGREE": 0.1, "W_AGE": 0.05}
|
||||
for _ in range(1000):
|
||||
fb = RetrievalFeedback(
|
||||
query_type="x", hit_ids=ids, used_ids=ids,
|
||||
corrected=False, re_asked=False,
|
||||
)
|
||||
weights = update_retrieval_weights(fb, weights)
|
||||
assert weights["W_COSINE"] <= MAX_WEIGHT
|
||||
assert weights["W_COSINE"] >= MIN_WEIGHT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- epsilon-greedy strategy
|
||||
|
||||
|
||||
def test_pick_retrieval_strategy_returns_string():
|
||||
from iai_mcp.learn import pick_retrieval_strategy
|
||||
|
||||
random.seed(42)
|
||||
s = pick_retrieval_strategy("fact_lookup", history={})
|
||||
assert isinstance(s, str)
|
||||
|
||||
|
||||
def test_pick_retrieval_strategy_epsilon_greedy():
|
||||
"""Over 200 calls, mostly picks the highest-mean strategy."""
|
||||
from iai_mcp.learn import pick_retrieval_strategy
|
||||
|
||||
random.seed(7)
|
||||
history = {
|
||||
"fact_lookup": {
|
||||
"pipeline_default": {"mean": 0.9, "n": 10},
|
||||
"greedy_2hop": {"mean": 0.1, "n": 10},
|
||||
"rich_club_first": {"mean": 0.2, "n": 10},
|
||||
}
|
||||
}
|
||||
picks = {"pipeline_default": 0, "greedy_2hop": 0, "rich_club_first": 0}
|
||||
for _ in range(200):
|
||||
s = pick_retrieval_strategy("fact_lookup", history)
|
||||
picks[s] = picks.get(s, 0) + 1
|
||||
# The best strategy (pipeline_default) should dominate at >= 60%.
|
||||
assert picks["pipeline_default"] > 120
|
||||
|
||||
|
||||
def test_pick_retrieval_strategy_no_history():
|
||||
"""Fresh query_type with no history -> returns a strategy anyway."""
|
||||
from iai_mcp.learn import pick_retrieval_strategy
|
||||
|
||||
random.seed(42)
|
||||
s = pick_retrieval_strategy("unseen", history={})
|
||||
assert isinstance(s, str)
|
||||
assert s in ("pipeline_default", "greedy_2hop", "rich_club_first")
|
||||
|
||||
|
||||
def test_pick_retrieval_strategy_custom_strategies():
|
||||
"""Caller can pass custom strategy list."""
|
||||
from iai_mcp.learn import pick_retrieval_strategy
|
||||
|
||||
random.seed(1)
|
||||
s = pick_retrieval_strategy("x", history={}, strategies=["a", "b", "c"])
|
||||
assert s in ("a", "b", "c")
|
||||
|
||||
|
||||
def test_retrieval_policy_per_query_type():
|
||||
"""Different query_types accumulate separate weights."""
|
||||
from iai_mcp.learn import RetrievalFeedback, update_retrieval_weights
|
||||
|
||||
ids = [uuid4()]
|
||||
w1 = {"W_COSINE": 1.0, "W_AAAK": 0.3, "W_DEGREE": 0.1, "W_AGE": 0.05}
|
||||
w2 = {"W_COSINE": 1.0, "W_AAAK": 0.3, "W_DEGREE": 0.1, "W_AGE": 0.05}
|
||||
# Query type A: user uses everything
|
||||
fb_a = RetrievalFeedback("A", ids, ids, False, False)
|
||||
w1 = update_retrieval_weights(fb_a, w1)
|
||||
# Query type B: user corrects
|
||||
fb_b = RetrievalFeedback("B", ids, [], True, False)
|
||||
w2 = update_retrieval_weights(fb_b, w2)
|
||||
assert w1["W_COSINE"] > w2["W_COSINE"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue