mirror of
https://github.com/SakanaAI/doc-to-lora.git
synced 2026-04-25 00:06:20 +02:00
Doc-to-LoRA release
This commit is contained in:
commit
1abe8ae16d
92 changed files with 22131 additions and 0 deletions
41
examples/python_api.py
Normal file
41
examples/python_api.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# caveat: this interface only supports non-batched inputs
# for batched inference please see `src/ctx_to_lora/modeling/hypernet.py`
import torch

from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel

# --- model loading ---
checkpoint_path = "trained_d2l/gemma_demo/checkpoint-80000/pytorch_model.bin"
# SECURITY NOTE: weights_only=False deserializes arbitrary pickled objects,
# which can execute code on load — only use checkpoints from a trusted source.
state_dict = torch.load(checkpoint_path, weights_only=False)
model = ModulatedPretrainedModel.from_state_dict(
    state_dict, train=False, use_sequence_packing=False
)
model.reset()  # start from the base (non-internalized) weights
tokenizer = get_tokenizer(model.base_model.name_or_path)

# --- prepare data ---
# Context manager closes the handle deterministically; explicit encoding
# avoids depending on the platform's default text codec.
with open("data/sakana_wiki.txt", "r", encoding="utf-8") as f:
    doc = f.read()
chat = [{"role": "user", "content": "Tell me about Sakana AI."}]
chat_ids = tokenizer.apply_chat_template(
    chat,
    add_special_tokens=False,
    return_attention_mask=False,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)


# calls after internalization will be influenced by internalized info
model.internalize(doc)

outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))


# remove internalized info
# model.reset()

# without internalized info, the model will hallucinate
# outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
# print(tokenizer.decode(outputs[0]))
|
||||
Loading…
Add table
Add a link
Reference in a new issue