"""Generate synthetic question/answer/label rows with a local GGUF model.

The script loads a llama.cpp model, asks it NEW_ROWS_COUNT times for a record
in the form ``Question: ... / Answer: ... / Label: ...``, parses and validates
each reply, and writes the input Parquet dataset plus the accepted synthetic
rows to OUTPUT_PARQUET_PATH.

Every fatal setup error terminates the process with exit status 1.
"""

import os
import re
import sys

from datasets import Dataset, concatenate_datasets, load_dataset
from llama_cpp import Llama

# 1. CONFIGURATION
GGUF_MODEL_PATH = "./path/to/model.gguf"
INPUT_PARQUET_PATH = "./path/to/input.parquet"
OUTPUT_PARQUET_PATH = "./path/to/output.parquet"
NEW_ROWS_COUNT = 100
# Label used when the model returns a question and an answer but no label line.
DEFAULT_LABEL = "unbiased"

# One extraction pattern per field. Each pattern is applied to a single,
# already-stripped line, so capturing "the rest of the line" is sufficient.
_FIELD_PATTERNS = {
    "question": re.compile(r"Question:\s*(.+)", re.IGNORECASE),
    "answer": re.compile(r"Answer:\s*(.+)", re.IGNORECASE),
    "label": re.compile(r"Label:\s*(.+)", re.IGNORECASE),
}
# Inline ```...``` spans that are removed from questions and answers.
_CODE_FENCE = re.compile(r"```.*?```")


def _extract_fields(text):
    """Return the question, answer and label found in a model reply.

    The reply is scanned line by line. A line mentioning ``Answer:`` is
    treated as the answer line (even if it also mentions ``Question:``),
    otherwise a line mentioning ``Question:`` is the question line, otherwise
    a line mentioning ``Label:`` is the label line. Markers are matched
    case-insensitively, and when a field occurs several times the last
    occurrence wins.

    Args:
        text: Raw text generated by the model.

    Returns:
        A dict with the keys ``"question"``, ``"answer"`` and ``"label"``;
        a value is ``None`` when the field was not found.
    """
    fields = {"question": None, "answer": None, "label": None}
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        lowered = line.lower()
        if "answer:" in lowered:
            field = "answer"
        elif "question:" in lowered:
            field = "question"
        elif "label:" in lowered:
            field = "label"
        else:
            continue
        match = _FIELD_PATTERNS[field].search(line)
        if match:
            fields[field] = match.group(1).strip()
    return fields


def _strip_code_fences(value):
    """Remove inline ```...``` spans from *value* and trim the result."""
    return _CODE_FENCE.sub("", value).strip()


def _normalize_label(raw_label):
    """Lower-case a label and drop brackets, surrounding quotes and spaces."""
    without_brackets = raw_label.lower().replace("[", "").replace("]", "")
    return without_brackets.strip(" \t\"'")


# Check if files exist
if not os.path.exists(GGUF_MODEL_PATH):
    print(f"❌ Error: GGUF model file not found at {GGUF_MODEL_PATH}")
    sys.exit(1)
if not os.path.exists(INPUT_PARQUET_PATH):
    print(f"❌ Error: Input Parquet file not found at {INPUT_PARQUET_PATH}")
    sys.exit(1)

# 2. LOAD GGUF MODEL - GPU (Vulkan) ONLY
# NOTE(review): n_gpu_layers=-1 only requests full GPU offload; which backend
# is used presumably depends on how llama-cpp-python was built - confirm.
print("Loading llama.cpp model...")
try:
    model = Llama(
        model_path=GGUF_MODEL_PATH,
        n_ctx=8192,
        n_gpu_layers=-1,  # ALL layers to GPU
        verbose=False,  # No logging
        n_batch=512,
        logits_all=False,
        use_mmap=True,
        use_mlock=False,
    )
    print("✅ llama.cpp model loaded with Vulkan GPU.")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    sys.exit(1)

# 3. LOAD EXISTING DATASET
print("Loading existing dataset from INPUT...")
try:
    original_ds = load_dataset(
        "parquet", data_files=[INPUT_PARQUET_PATH], split="train"
    )
    print(f"Original Columns: {original_ds.column_names}")
    print(f"Original Dataset Shape: {original_ds.shape}")
except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    sys.exit(1)

# Fail with a readable message instead of a KeyError traceback.
if "label" not in original_ds.column_names:
    print("❌ Error: Input dataset has no 'label' column")
    sys.exit(1)

existing_labels = list(set(original_ds["label"]))
# Set copy of existing_labels for constant-time membership tests in the loop.
allowed_labels = frozenset(existing_labels)

# 4. GENERATE SYNTHETIC DATA - STRUCTURED OUTPUT
print(f"Generating {NEW_ROWS_COUNT} synthetic records...")
synthetic_data = []

# The prompt is identical for every row, so it is built once.
messages = [
    {
        "role": "system",
        "content": "You are a data generator. Output ONLY the format below, nothing else.",
    },
    {
        "role": "user",
        "content": """YOUR PROMPT GOES HERE""",
    },
]

for row_number in range(1, NEW_ROWS_COUNT + 1):
    # Best effort per row: a failed generation is reported and skipped so a
    # single bad reply does not abort the whole run.
    try:
        # Generate with sampling parameters
        response = model.create_chat_completion(
            messages=messages,
            max_tokens=200,
            temperature=1.0,
            top_p=0.95,
            top_k=20,
            min_p=0.0,
        )
        # Get response text (treat a missing content value as empty text)
        content = response["choices"][0]["message"]["content"]
        generated_text = (content or "").strip()
    except Exception as e:
        print(f"❌ Row {row_number}: Error: {e}")
        continue

    # DIRECTLY PARSE TO STRUCTURED FORMAT
    fields = _extract_fields(generated_text)

    # Clean up before validating, so that a field consisting only of a code
    # fence is rejected instead of being stored as an empty string.
    question = _strip_code_fences(fields["question"] or "")
    answer = _strip_code_fences(fields["answer"] or "")

    # VALIDATION
    if not (question and answer):
        print(f"⚠️ Row {row_number}: Incomplete output. Skipping.")
        print(generated_text)
        continue

    raw_label = fields["label"]
    label = _normalize_label(raw_label) if raw_label else DEFAULT_LABEL
    # The default label is validated too, so the output never gains a label
    # that is absent from the input dataset.
    if label not in allowed_labels:
        print(f"⚠️ Row {row_number}: Invalid label '{label}'. Skipping.")
        continue

    parsed_row = {"question": question, "answer": answer, "label": label}

    # PRINT PARSED DATA IN TERMINAL
    print(f"✅ ROW {row_number} PARSED:")
    print(f" Question: {question}")
    print(f" Answer: {answer}")
    print(f" Label: {label}")
    print()
    synthetic_data.append(parsed_row)

# 5. SAVE TO PARQUET
if synthetic_data:
    print(f"Adding {len(synthetic_data)} synthetic records...")
    synthetic_ds = Dataset.from_list(synthetic_data)
    # The input dataset is already in memory; reading the file again is
    # unnecessary.
    base_ds = original_ds
    print(f"Existing: {len(base_ds)} rows")
    combined_ds = concatenate_datasets([base_ds, synthetic_ds])
    print(f"Combined: {len(combined_ds)} rows")
    combined_ds.to_parquet(OUTPUT_PARQUET_PATH)
    print(f"✅ Saved to {OUTPUT_PARQUET_PATH}")
else:
    print("❌ No valid records generated.")