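# Unsloth-Finetune-Template/synthetic-data.py
#
# Repository web-view header (not part of the script itself):
#   179 lines, 5.2 KiB, Python
#   view links: Raw / Normal View / History
#   blame timestamp: 2026-06-02 15:45:59 +02:00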
import os
import re
from datasets import Dataset, concatenate_datasets, load_dataset
from llama_cpp import Llama
# 1. CONFIGURATION
# Placeholder paths: point these at the real model / dataset before running.
GGUF_MODEL_PATH = "./path/to/model.gguf"
INPUT_PARQUET_PATH = "./path/to/input.parquet"
OUTPUT_PARQUET_PATH = "./path/to/output.parquet"
# Number of generation attempts. Responses that cannot be parsed or validated
# are skipped, so fewer rows than this may end up in the output.
NEW_ROWS_COUNT = 100

# Check that the required input files exist before doing any expensive work.
# The model is checked first, so only the first missing file is reported.
for _description, _required_path in (
    ("GGUF model", GGUF_MODEL_PATH),
    ("Input Parquet", INPUT_PARQUET_PATH),
):
    if not os.path.exists(_required_path):
        print(f"❌ Error: {_description} file not found at {_required_path}")
        # Raise SystemExit with a non-zero status instead of calling the
        # interactive `exit()` helper: `exit()` reports success (status 0) and
        # only exists when the `site` module has been loaded.
        raise SystemExit(1)
# 2. LOAD GGUF MODEL
# n_gpu_layers=-1 requests that every layer be offloaded to the GPU.
# NOTE(review): the messages below mention Vulkan, but nothing in this script
# selects a backend - presumably that depends on how llama-cpp-python was
# built; confirm for the target machine.
print("Loading llama.cpp model...")
try:
    model = Llama(
        model_path=GGUF_MODEL_PATH,
        n_ctx=8192,
        n_gpu_layers=-1,  # ALL layers to GPU
        verbose=False,  # No logging
        n_batch=512,
        logits_all=False,
        use_mmap=True,
        use_mlock=False,
    )
except Exception as e:
    # Top-level boundary of the script: report any load failure and stop.
    print(f"❌ Error loading model: {e}")
    # Non-zero status so callers can tell that loading failed; the interactive
    # `exit()` helper would have reported success.
    raise SystemExit(1)
else:
    print("✅ llama.cpp model loaded with Vulkan GPU.")
# 3. LOAD EXISTING DATASET
print("Loading existing dataset from INPUT...")
try:
    original_ds = load_dataset(
        "parquet", data_files=[INPUT_PARQUET_PATH], split="train"
    )
    print(f"Original Columns: {original_ds.column_names}")
    print(f"Original Dataset Shape: {original_ds.shape}")
except Exception as e:
    # Top-level boundary of the script: report any load failure and stop.
    print(f"❌ Error loading dataset: {e}")
    # Non-zero status so callers can tell that loading failed; the interactive
    # `exit()` helper would have reported success.
    raise SystemExit(1)

# The generated rows are validated against the labels already present in the
# input, so a dataset without a "label" column cannot be extended.
if "label" not in original_ds.column_names:
    print('❌ Error: input dataset has no "label" column')
    raise SystemExit(1)

# Distinct label values of the input dataset (order is not significant).
existing_labels = list(set(original_ds["label"]))
# 4. GENERATE SYNTHETIC DATA - STRUCTURED OUTPUT
# Label used when a response contains no usable "Label:" line.
# NOTE(review): this fallback is not checked against existing_labels - confirm
# that "unbiased" is a valid label for the dataset being extended.
_DEFAULT_LABEL = "unbiased"

# Field markers expected in the model output, matched case-insensitively.
# Each pattern is applied to a single line, so it captures the rest of that line.
_QUESTION_RE = re.compile(r"Question:\s*(.+)", re.IGNORECASE)
_ANSWER_RE = re.compile(r"Answer:\s*(.+)", re.IGNORECASE)
_LABEL_RE = re.compile(r"Label:\s*(.+)", re.IGNORECASE)
# Inline fenced snippets that are removed from the question and answer text.
_FENCE_RE = re.compile(r"```.*?```")

# The same prompt is sent for every row; variety comes from sampling.
_PROMPT_MESSAGES = [
    {
        "role": "system",
        "content": "You are a data generator. Output ONLY the format below, nothing else.",
    },
    {
        "role": "user",
        "content": """YOUR PROMPT GOES HERE""",
    },
]

# Case-insensitive lookup from a lower-cased label to the spelling used in the
# dataset. If two dataset labels differ only by case, the first one seen wins.
_label_by_lower = {}
for _known_label in existing_labels:
    if isinstance(_known_label, str):
        _label_by_lower.setdefault(_known_label.lower(), _known_label)


def _extract_fields(text):
    """Pull the question, answer and label out of one model response.

    The response is scanned line by line for "Question:", "Answer:" and
    "Label:" markers (case-insensitive). Only the text on the marker's own
    line is captured. If a marker occurs several times, the last occurrence
    that has text after it wins. A line containing both "Question:" and
    "Answer:" is treated as an answer line.

    Returns:
        A (question, answer, label) tuple; an element is None when the
        corresponding marker was not found.
    """
    question = answer = label = None
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        lowered = line.lower()
        if "question:" in lowered and "answer:" not in lowered:
            match = _QUESTION_RE.search(line)
            if match:
                question = match.group(1).strip()
        elif "answer:" in lowered:
            match = _ANSWER_RE.search(line)
            if match:
                answer = match.group(1).strip()
        elif "label:" in lowered:
            match = _LABEL_RE.search(line)
            if match:
                label = match.group(1).strip()
    return question, answer, label


def _strip_fences(value):
    """Remove ```...``` snippets from *value*; None and "" pass through."""
    if not value:
        return value
    return _FENCE_RE.sub("", value).strip()


def _normalize_label(raw_label):
    """Lower-case a raw label and drop list brackets and surrounding quotes.

    Brackets are removed before the quotes are stripped, so a label written
    as ["biased"] normalizes to biased.
    """
    unbracketed = raw_label.lower().replace("[", "").replace("]", "")
    return unbracketed.strip().strip("\"'").strip()


print(f"Generating {NEW_ROWS_COUNT} synthetic records...")
synthetic_data = []
for i in range(NEW_ROWS_COUNT):
    # Generation is best-effort: a failure on one row must not stop the run.
    try:
        # Generate with sampling parameters
        response = model.create_chat_completion(
            messages=_PROMPT_MESSAGES,
            max_tokens=200,
            temperature=1.0,
            top_p=0.95,
            top_k=20,
            min_p=0.0,
        )
        content = response["choices"][0]["message"]["content"]
    except Exception as e:
        print(f"❌ Row {i + 1}: Error: {e}")
        continue

    # A missing content field is handled as an empty (incomplete) response.
    generated_text = (content or "").strip()
    question, answer, label = _extract_fields(generated_text)

    # Clean up before validating, so that text consisting only of a fenced
    # snippet is rejected instead of being stored as an empty string.
    question = _strip_fences(question)
    answer = _strip_fences(answer)

    # VALIDATION
    if not (question and answer):
        print(f"⚠️ Row {i + 1}: Incomplete output. Skipping.")
        # Echo the raw response to help with prompt debugging.
        print(generated_text)
        continue

    if not label:
        label = _DEFAULT_LABEL
    else:
        label = _normalize_label(label)
        if label not in existing_labels:
            # Fall back to a case-insensitive match and keep the dataset's
            # own spelling of the label.
            canonical_label = _label_by_lower.get(label)
            if canonical_label is None:
                print(f"⚠️ Row {i + 1}: Invalid label '{label}'. Skipping.")
                continue
            label = canonical_label

    parsed_row = {"question": question, "answer": answer, "label": label}

    # PRINT PARSED DATA IN TERMINAL
    print(f"✅ ROW {i + 1} PARSED:")
    print(f" Question: {question}")
    print(f" Answer: {answer}")
    print(f" Label: {label}")
    print()
    synthetic_data.append(parsed_row)
# 5. SAVE TO PARQUET
# Append the accepted synthetic rows to the input dataset and write the result
# to OUTPUT_PARQUET_PATH. Nothing is written when no row was accepted.
if synthetic_data:
    print(f"Adding {len(synthetic_data)} synthetic records...")
    synthetic_ds = Dataset.from_list(synthetic_data)

    # original_ds already holds the rows of INPUT_PARQUET_PATH, so it is
    # reused here instead of reading the same file a second time.
    base_ds = original_ds
    print(f"Existing: {len(base_ds)} rows")

    # NOTE(review): the synthetic rows only carry "question", "answer" and
    # "label"; presumably the input dataset must have matching columns for
    # the concatenation to succeed - confirm against the real input schema.
    combined_ds = concatenate_datasets([base_ds, synthetic_ds])
    print(f"Combined: {len(combined_ds)} rows")

    combined_ds.to_parquet(OUTPUT_PARQUET_PATH)
    print(f"✅ Saved to {OUTPUT_PARQUET_PATH}")
else:
    print("❌ No valid records generated.")