# Source listing metadata: 205 lines, 7.3 KiB, Python
import os
import re
import sys

from datasets import Dataset, concatenate_datasets, load_dataset
from llama_cpp import Llama

# 1. CONFIGURATION
# The three paths are placeholders; point them at the real files before running.
GGUF_MODEL_PATH = "./path/to/model.gguf"
INPUT_PARQUET_PATH = "./path/to/input.parquet"
OUTPUT_PARQUET_PATH = "./path/to/output.parquet"
# Number of generation attempts; rows that fail validation are skipped, not retried.
NEW_ROWS_COUNT = 100

# Check if files exist
# Fail fast, before the expensive model load, when a required input is missing.
# isfile() (rather than exists()) also rejects a directory at that path, and
# sys.exit(1) replaces the interactive-only exit() helper so the process
# reports failure with a non-zero status instead of 0.
if not os.path.isfile(GGUF_MODEL_PATH):
    print(f"❌ Error: GGUF model file not found at {GGUF_MODEL_PATH}")
    sys.exit(1)
if not os.path.isfile(INPUT_PARQUET_PATH):
    print(f"❌ Error: Input Parquet file not found at {INPUT_PARQUET_PATH}")
    sys.exit(1)

# 2. LOAD GGUF MODEL - GPU (Vulkan) ONLY
print("Loading llama.cpp model...")
try:
    # Only the constructor call is guarded, so the handler cannot hide
    # unrelated errors.
    model = Llama(
        model_path=GGUF_MODEL_PATH,
        n_ctx=8192,
        n_gpu_layers=-1,  # ALL layers to GPU
        verbose=False,  # No logging
        n_batch=512,
        logits_all=False,
        use_mmap=True,
        use_mlock=False,
    )
except Exception as e:
    # Top-level boundary: report the failure and stop with a non-zero status
    # (the interactive exit() helper would have reported success, status 0).
    print(f"❌ Error loading model: {e}")
    sys.exit(1)
else:
    # NOTE(review): the message assumes a Vulkan-enabled llama.cpp build; this
    # script does not itself verify which backend was used - confirm.
    print("✅ llama.cpp model loaded with Vulkan GPU.")

# 3. LOAD EXISTING DATASET
print("Loading existing dataset from INPUT...")
try:
    original_ds = load_dataset(
        "parquet", data_files=[INPUT_PARQUET_PATH], split="train"
    )
except Exception as e:
    # Top-level boundary: report and stop with a non-zero exit status.
    print(f"❌ Error loading dataset: {e}")
    sys.exit(1)

print(f"Original Columns: {original_ds.column_names}")
print(f"Original Dataset Shape: {original_ds.shape}")

# The generated rows carry exactly these keys and are later concatenated with
# this dataset, so a missing column is reported now instead of surfacing as a
# KeyError here or as a concatenation failure after the whole generation run.
missing_columns = [
    column
    for column in ("question", "answer", "label")
    if column not in original_ds.column_names
]
if missing_columns:
    print(f"❌ Error: Input dataset is missing required column(s): {missing_columns}")
    sys.exit(1)

# Distinct labels already present in the input; generated labels are validated
# against this collection.
existing_labels = list(set(original_ds["label"]))

# 4. GENERATE SYNTHETIC DATA - STRUCTURED OUTPUT

# Completion budget per record. The prompt asks for answers of about five
# paragraphs, which the previous limit of 200 tokens could not hold, so every
# answer was cut off mid-text.
_MAX_NEW_TOKENS = 2048

# The prompts are identical for every record, so they are built once.
_SYSTEM_PROMPT = "You are a data generator. Output ONLY the format below, nothing else."

_USER_PROMPT = """You are generating a gender bias avoidance dataset.
You must output ONLY raw text in the following format.
DO NOT use JSON. DO NOT use markdown code blocks (no ```).
DO NOT repeat the instructions.
Questions should be around 1-2 sentences long. Answers should be around 5 paragraphs in lengh essays.
Answers should be answered thoroughly and detailled.
Questions can vary from simple to complex systemic societal issues.
Physiological differences are still real and should be accounted for when encountering a question related to it.
Questions should be equaly distributed across all categories, like job/works, societal, relationships, personal, financial etc...

Format:
Question: [Ask a question which is stereotypically answered with gender bias]
Answer: [Provide an answer which is COMPLETELY unbiased]
Label: unbiased

DO NOT repeat the format without actually filling it out and DO NOT create empty placeholder questions.
----
Make sure that the content and Question: or Answer: are on the same line. Like this:
Question: Here goes the question. It can continue in new lines but needs to start here.
and not like this:
Question:
It doesnt go here without having a previouse sentence after the Question: tag.
-----
Now generate one record strictly adhering to the format, filling out both question and answer.
Question:
Answer:
Label: unbiased"""

# A line holding nothing but a code-fence marker (optionally with a language tag).
_FENCE_LINE_RE = re.compile(r"^[ \t]*```[\w-]*[ \t]*$", re.MULTILINE)
# "Question:", "Answer:" or "Label:" at the start of a line, in any letter case,
# tolerating light markdown decoration such as "**Question:**".
_FIELD_MARKER_RE = re.compile(
    r"^[ \t*_#>-]*(question|answer|label)[ \t*_]*:[ \t*_]*",
    re.IGNORECASE | re.MULTILINE,
)
# Inline fenced span that opens and closes on the same line.
_INLINE_FENCE_RE = re.compile(r"```.*?```")


def _parse_record(text):
    """Split raw model output into its Question / Answer / Label fields.

    A field's value runs from its marker to the next marker (or to the end of
    the text), so answers that span several lines or paragraphs are kept in
    full. When a marker occurs more than once, the last non-empty value wins.

    Args:
        text: The raw completion text.

    Returns:
        A dict with any of the keys "question", "answer" and "label" that had
        a non-empty value; keys without a value are absent.
    """
    # Drop fence-only lines first so a model that wraps the whole record in a
    # code block does not leak the closing fence into the last field.
    text = _FENCE_LINE_RE.sub("", text)
    markers = list(_FIELD_MARKER_RE.finditer(text))
    value_ends = [marker.start() for marker in markers[1:]] + [len(text)]

    fields = {}
    for marker, value_end in zip(markers, value_ends):
        value = text[marker.end():value_end].strip()
        if value:
            fields[marker.group(1).lower()] = value
    return fields


def _normalize_label(raw_label):
    """Return the first line of *raw_label*, lower-cased, without quotes or brackets."""
    first_line = raw_label.strip().splitlines()[0]
    return (
        first_line.lower()
        .strip('"')
        .strip("'")
        .replace("[", "")
        .replace("]", "")
        .strip()
    )


def _generate_record(llm, row_number, valid_labels):
    """Generate one record with the model and validate it.

    Args:
        llm: Object exposing ``create_chat_completion`` (the loaded model).
        row_number: 1-based row number, used only in console messages.
        valid_labels: Collection of labels a generated row may carry.

    Returns:
        A dict with "question", "answer" and "label", or None when the output
        was truncated, incomplete, or carried a label not in *valid_labels*.
    """
    # Use chat format for Qwen
    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": _USER_PROMPT},
        ],
        max_tokens=_MAX_NEW_TOKENS,
        temperature=1.0,
        top_p=0.95,
        top_k=20,
        min_p=0.0,
    )
    choice = response["choices"][0]
    generated_text = (choice["message"]["content"] or "").strip()

    # llama-cpp-python reports finish_reason "length" when the completion hit
    # max_tokens; such an answer stops mid-text and would pollute the dataset.
    if choice.get("finish_reason") == "length":
        print(
            f"⚠️  Row {row_number}: Output truncated at "
            f"{_MAX_NEW_TOKENS} tokens. Skipping."
        )
        return None

    fields = _parse_record(generated_text)

    # Clean up before validating, so a value that consisted only of a fenced
    # span is treated as missing rather than stored as an empty string.
    question = _INLINE_FENCE_RE.sub("", fields.get("question", "")).strip()
    answer = _INLINE_FENCE_RE.sub("", fields.get("answer", "")).strip()

    if not (question and answer):
        print(f"⚠️  Row {row_number}: Incomplete output. Skipping.")
        print(generated_text)
        return None

    # A missing label falls back to the only label the prompt asks for.
    label = _normalize_label(fields["label"]) if "label" in fields else "unbiased"
    if label not in valid_labels:
        print(f"⚠️  Row {row_number}: Invalid label '{label}'. Skipping.")
        return None

    # PRINT PARSED DATA IN TERMINAL
    print(f"✅ ROW {row_number} PARSED:")
    print(f" Question: {question}")
    print(f" Answer: {answer}")
    print(f" Label: {label}")
    print()

    return {"question": question, "answer": answer, "label": label}


print(f"Generating {NEW_ROWS_COUNT} synthetic records...")
synthetic_data = []

for row_number in range(1, NEW_ROWS_COUNT + 1):
    try:
        parsed_row = _generate_record(model, row_number, existing_labels)
    except Exception as e:
        # Best-effort generation: one failed row must not abort the whole run.
        print(f"❌ Row {row_number}: Error: {e}")
        continue
    if parsed_row is not None:
        synthetic_data.append(parsed_row)

# 5. SAVE TO PARQUET
if synthetic_data:
    print(f"Adding {len(synthetic_data)} synthetic records...")
    synthetic_ds = Dataset.from_list(synthetic_data)

    # The input dataset is already in memory as original_ds, so it is reused
    # instead of reading INPUT_PARQUET_PATH from disk a second time.
    base_ds = original_ds
    print(f"Existing: {len(base_ds)} rows")

    combined_ds = concatenate_datasets([base_ds, synthetic_ds])
    print(f"Combined: {len(combined_ds)} rows")

    combined_ds.to_parquet(OUTPUT_PARQUET_PATH)
    print(f"✅ Saved to {OUTPUT_PARQUET_PATH}")
else:
    # Nothing passed validation; no output file is written in this case.
    print("❌ No valid records generated.")