fix: parallelization effort, progress bar w/ bit-identical file write compared to sequential path

This commit is contained in:
Apunkt 2026-05-21 09:14:20 +02:00
parent 35e49044f1
commit 8ca5371a98
No known key found for this signature in database
9 changed files with 407 additions and 387 deletions

View file

@ -22,11 +22,9 @@
import os
import sys
import sqlite3
import pathlib
import concurrent.futures
import threading
import time
import queue
from collections import deque
from threading import Thread
from importlib import resources
from tqdm import tqdm
@ -43,26 +41,32 @@ except ImportError:
import sfo
def _decrypt_sector_worker(disc_key, sector_data, sector_number):
"""Standalone worker for parallel sector decryption."""
# NOTE(review): this span reads as two overlapping definitions (a per-sector
# decrypt worker and an IV helper) sharing one IV-building loop — presumably
# a flattened diff view; confirm against the real file.
def _make_iv(sector_number):
"""Build a 16-byte IV from a sector number (big-endian)."""
iv = bytearray(16)
num = sector_number
for j in range(16):
# Filled from index 15 downwards, so the number's lowest byte ends up last.
iv[15 - j] = num & 0xFF
num >>= 8
cipher = AES.new(disc_key, AES.MODE_CBC, bytes(iv))
return (sector_number, cipher.decrypt(sector_data))
return bytes(iv)
def _encrypt_sector_worker(disc_key, sector_data, sector_number):
    """Encrypt one sector; usable as a standalone parallel worker.

    The CBC IV is the sector number written as 16 bytes, most-significant
    byte first.

    Returns:
        A ``(sector_number, encrypted_data)`` tuple.
    """
    # Keep only the low 128 bits so the value always fits in 16 bytes.
    iv = (sector_number & ((1 << 128) - 1)).to_bytes(16, 'big')
    encrypted = AES.new(disc_key, AES.MODE_CBC, iv).encrypt(sector_data)
    return (sector_number, encrypted)
def _process_sector_chunk_mp(args):
    """Multiprocessing worker: (de/en)crypt a contiguous run of sectors.

    Module-level so it can be pickled for ``ProcessPoolExecutor``. The work
    item is a single tuple ``(disc_key, sector_blob, sector_start_number,
    encrypt_mode)`` where ``sector_blob`` holds one or more sectors back to
    back. Passing one large blob per task, rather than many small objects,
    keeps pickling overhead low — this is what makes multiprocessing
    worthwhile here.

    Returns:
        The processed sectors concatenated into one ``bytes`` object.
    """
    disc_key, sector_blob, sector_start_number, encrypt_mode = args
    sector_size = core.SECTOR
    pieces = []
    for index, offset in enumerate(range(0, len(sector_blob), sector_size)):
        # Every sector gets a fresh CBC cipher keyed by its own sector number.
        cipher = AES.new(disc_key, AES.MODE_CBC, _make_iv(sector_start_number + index))
        transform = cipher.encrypt if encrypt_mode else cipher.decrypt
        pieces.append(transform(sector_blob[offset:offset + sector_size]))
    return b''.join(pieces)
class ISO:
@ -184,189 +188,134 @@ class ISO:
if not args.output:
args.output = f'{game_title} [{param["TITLE_ID"]}].iso'
except Exception:
except (UnicodeDecodeError, KeyError, IndexError, ValueError):
core.warning('Failed reading SFO', args)
self.disc_key = self.get_key_from_args(game_title, args)
if args.verbose and not args.quiet:
self.print_info()
def _make_iv(self, sector_number):
"""Build a 16-byte IV from a sector number (little-endian)."""
iv = bytearray(16)
num = sector_number
for j in range(16):
iv[15 - j] = num & 0xFF
num >>= 8
return bytes(iv)
def _process_region_pipeline(self, input_path, region, num_workers, encrypt_mode, args):
    """Process one encrypted region with a reader thread feeding worker threads.

    A reader thread reads the region sector by sector and puts
    ``(index, data)`` items on a bounded queue. ``num_workers`` worker
    threads take items off it, (de/en)crypt each sector using its absolute
    sector number, and store the result at its index so output order matches
    input order.

    A failure in any thread is re-raised here once all threads have stopped:
    the reader always queues the stop markers, and workers keep draining the
    queue after a failure, so no thread is left blocked.

    Args:
        input_path: Path of the ISO to read from.
        region: Mapping with byte offsets ``'start'`` and ``'end'``.
        num_workers: Number of worker threads to start.
        encrypt_mode: True to encrypt, False to decrypt.
        args: Parsed command-line arguments (not used by this method).

    Returns:
        The processed region as a single ``bytes`` object.

    Raises:
        Exception: The first error raised by the reader or a worker.
    """
    num_sectors = (region['end'] - region['start']) // core.SECTOR
    first_sector = region['start'] // core.SECTOR
    process = _encrypt_sector_worker if encrypt_mode else _decrypt_sector_worker

    sector_queue = queue.Queue(maxsize=min(64, num_sectors))
    results = [None] * num_sectors
    results_lock = threading.Lock()
    errors = []

    def reader():
        try:
            with open(input_path, 'rb') as f:
                f.seek(region['start'])
                for i in range(num_sectors):
                    sector_queue.put((i, f.read(core.SECTOR)))
        except Exception as exc:
            errors.append(exc)
        finally:
            # Always release the workers, even when reading failed.
            for _ in range(num_workers):
                sector_queue.put(None)

    def worker():
        while True:
            item = sector_queue.get()
            if item is None:
                break
            if errors:
                # Something already failed: keep draining so the reader
                # never blocks on a full queue, but skip the work.
                continue
            idx, sector_data = item
            try:
                _, processed = process(self.disc_key, sector_data, first_sector + idx)
            except Exception as exc:
                errors.append(exc)
                continue
            with results_lock:
                results[idx] = processed

    reader_thread = threading.Thread(target=reader, daemon=True)
    reader_thread.start()

    workers = [threading.Thread(target=worker, daemon=True) for _ in range(num_workers)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    reader_thread.join()

    if errors:
        raise errors[0]
    return b''.join(results)
def decrypt(self, args):
"""Decrypt self using args from argparse."""
# NOTE(review): this span appears to hold two bodies for ``decrypt`` — a long
# inline implementation followed by a one-line delegation to ``_process_iso``
# (the last line) — presumably a flattened diff view; confirm which is live.
core.vprint(f'Decrypting with disc key: {self.disc_key.hex()}', args)
# Worker count: ``args.threads`` when positive, else the CPU count (1 if unknown).
num_workers = args.threads if args.threads and args.threads > 0 else os.cpu_count() or 1
if num_workers > 1:
core.vprint(f'Using {num_workers} threads for parallel decryption', args)
# Fall back to ``<game_id>.iso`` when no output path was given.
if not args.output:
output_name = f'{self.game_id}.iso'
else:
output_name = args.output
core.vprint(f'Decrypted .iso is output to: {output_name}', args)
# Progress is counted in sectors of ``core.SECTOR`` bytes.
total_sectors = self.size // core.SECTOR
with open(args.iso, 'rb') as input_iso, open(output_name, 'wb') as output_iso:
pbar = tqdm(total=total_sectors, file=sys.stdout, disable=args.quiet, leave=True)
for region in self.regions:
region_sectors = (region['end'] - region['start']) // core.SECTOR
if not region['enc']:
# Unencrypted region — copy sequentially
input_iso.seek(region['start'])
for _ in range(region_sectors):
data = input_iso.read(core.SECTOR)
# An empty read ends this region's copy loop only; later regions are still visited.
if not data:
core.warning('Trying to read past the end of the file', args)
break
output_iso.write(data)
pbar.update(1)
else:
# Encrypted region — pipeline: reader thread + worker threads
if num_workers > 1:
processed = self._process_region_pipeline(
args.iso, region, num_workers, encrypt_mode=False, args=args
)
else:
# Sequential fallback
input_iso.seek(region['start'])
processed = bytearray()
for i in range(region_sectors):
# The IV is derived from the absolute sector number.
sector_num = region['start'] // core.SECTOR + i
iv = self._make_iv(sector_num)
cipher = AES.new(self.disc_key, AES.MODE_CBC, iv)
processed.extend(cipher.decrypt(input_iso.read(core.SECTOR)))
processed = bytes(processed)
# The whole region is held in memory and written in one call.
output_iso.write(processed)
pbar.update(region_sectors)
pbar.close()
core.vprint('Decryption complete!', args)
# Delegates to the shared driver with ``encrypt_mode=False``.
self._process_iso(args, encrypt_mode=False)
def encrypt(self, args):
"""Encrypt self using args from argparse."""
# Delegates to the shared driver with ``encrypt_mode=True``.
self._process_iso(args, encrypt_mode=True)
# NOTE(review): the message below is also emitted inside ``_process_iso`` —
# presumably this line is a leftover from an earlier inline body; confirm.
core.vprint(f'Re-encrypting with disc key: {self.disc_key.hex()}', args)
def _process_iso(self, args, encrypt_mode):
"""Shared driver for decrypt/encrypt. ``encrypt_mode`` selects direction."""
# NOTE(review): several statements in this span occur twice in slightly
# different forms (worker count, output name, progress bar) — presumably two
# revisions shown together by a diff view; confirm which lines are live.
num_workers = args.threads if args.threads and args.threads > 0 else os.cpu_count() or 1
if num_workers > 1:
core.vprint(f'Using {num_workers} threads for parallel re-encryption', args)
core.vprint(f'{"Re-encrypting" if encrypt_mode else "Decrypting"} with disc key: {self.disc_key.hex()}', args)
if not args.output:
output_name = f'{self.game_id}_e.iso'
else:
# Worker count: ``args.threads`` when positive, else the CPU count (1 if unknown).
num_workers = args.threads if args.threads and args.threads > 0 else (os.cpu_count() or 1)
# An explicit output path wins; otherwise the name is derived from the game id,
# with an ``_e`` suffix when encrypting.
if args.output:
output_name = args.output
else:
output_name = f'{self.game_id}_e.iso' if encrypt_mode else f'{self.game_id}.iso'
core.vprint(f'Re-encrypted .iso is output to: {output_name}', args)
core.vprint(f'{"Re-encrypted" if encrypt_mode else "Decrypted"} .iso is output to: {output_name}', args)
with open(args.iso, 'rb') as input_iso, open(output_name, 'wb') as output_iso:
total_sectors = self.size // core.SECTOR
pbar = tqdm(total=total_sectors, file=sys.stdout, disable=args.quiet, leave=True)
# NOTE(review): 2048 is hard-coded here while the neighbouring line uses
# ``core.SECTOR`` — presumably the same value; confirm and prefer the constant.
pbar = tqdm(total=(self.size // 2048), file=sys.stdout, disable=args.quiet, leave=True)
# More than one worker selects the process-pool path; otherwise the single-process path.
if num_workers > 1:
core.vprint(f'Using {num_workers} processes for parallel {"re-encryption" if encrypt_mode else "decryption"}', args)
self._process_parallel(args, output_name, encrypt_mode, num_workers, pbar)
else:
self._process_sequential(args, output_name, encrypt_mode, pbar)
pbar.close()
core.vprint(f'{"Re-encryption" if encrypt_mode else "Decryption"} complete!', args)
def _process_sequential(self, args, output_name, encrypt_mode, pbar):
"""Single-process path: read, optionally (de/en)crypt, and write per sector."""
with open(args.iso, 'rb') as input_iso, open(output_name, 'wb') as output_file:
for region in self.regions:
region_sectors = (region['end'] - region['start']) // core.SECTOR
# Absolute sector number of the region's first sector; used to derive each IV.
base_sector = region['start'] // core.SECTOR
input_iso.seek(region['start'])
# NOTE(review): the next two lines duplicate the ``region['enc']`` test made
# further down — presumably leftovers shown by a diff view; confirm.
if not region['enc']:
# Unencrypted region — copy sequentially
for i in range(region_sectors):
data = input_iso.read(core.SECTOR)
# An empty read ends this region's loop only; later regions are still visited.
if not data:
core.warning('Trying to read past the end of the file', args)
break
# Only encrypted regions are transformed; a fresh CBC cipher is built per
# sector from that sector's IV.
if region['enc']:
cipher = AES.new(self.disc_key, AES.MODE_CBC, _make_iv(base_sector + i))
data = cipher.encrypt(data) if encrypt_mode else cipher.decrypt(data)
output_file.write(data)
pbar.update(1)
def _process_parallel(self, args, output_name, encrypt_mode, num_workers, pbar):
"""Multi-process path: one process pool for the whole operation.
The output file is pre-sized so every region can be written to its
absolute offset; this also clears any stale file from a previous run.
"""
# Pre-create the output at its final size and write to absolute offsets.
# The end offset of the last region is used as the output size.
total_bytes = self.regions[-1]['end']
with open(output_name, 'wb') as f:
f.truncate(total_bytes)
# One executor is shared by all regions; the output is reopened in 'r+b'
# so the size set above is kept.
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor, \
open(args.iso, 'rb') as input_iso, \
open(output_name, 'r+b') as out:
for region in self.regions:
if region['enc']:
self._process_region_parallel(input_iso, out, region, executor, num_workers, encrypt_mode, pbar)
else:
# Unencrypted region — copy straight through to its offset.
region_sectors = (region['end'] - region['start']) // core.SECTOR
input_iso.seek(region['start'])
out.seek(region['start'])
for _ in range(region_sectors):
data = input_iso.read(core.SECTOR)
# An empty read ends this region's copy loop only.
if not data:
core.warning('Trying to read past the end of the file', args)
break
# NOTE(review): ``output_iso`` is not defined in this method (the handle
# opened above is ``out``) — presumably a leftover line; confirm.
output_iso.write(data)
out.write(data)
pbar.update(1)
# NOTE(review): the remaining lines repeat the thread-pipeline / sequential
# handling of encrypted regions and write to ``output_iso``, which this
# method does not define — presumably a previous implementation shown by a
# diff view; confirm.
else:
# Encrypted region — pipeline: reader thread + worker threads
if num_workers > 1:
processed = self._process_region_pipeline(
args.iso, region, num_workers, encrypt_mode=True, args=args
)
else:
# Sequential fallback
input_iso.seek(region['start'])
processed = bytearray()
for i in range(region_sectors):
sector_num = region['start'] // core.SECTOR + i
iv = self._make_iv(sector_num)
cipher = AES.new(self.disc_key, AES.MODE_CBC, iv)
processed.extend(cipher.encrypt(input_iso.read(core.SECTOR)))
processed = bytes(processed)
output_iso.write(processed)
pbar.update(region_sectors)
def _process_region_parallel(self, input_iso, out, region, executor, num_workers, encrypt_mode, pbar):
"""(De/en)crypt one encrypted region across worker processes.
if not args.quiet:
pbar.close()
Reads the region in bounded-size chunks and keeps only a small window of
chunks in flight, so peak memory stays bounded even for huge ISOs.
Results are written back to their absolute offsets as they complete.
"""
num_sectors = (region['end'] - region['start']) // core.SECTOR
# Absolute sector number of the region's first sector.
base_sector = region['start'] // core.SECTOR
# NOTE(review): ``args`` is not a parameter of this method — the next line is
# presumably a leftover shown by a diff view; confirm.
core.vprint('Re-encryption complete!', args)
# Sectors per task: large enough to amortise IPC, small enough to keep
# every worker busy and bound peak memory.
sectors_per_chunk = max(1, min(512, (num_sectors // (num_workers * 4)) or 1))
# Upper bound on chunks that are submitted but not yet written.
max_in_flight = num_workers * 4
input_iso.seek(region['start'])
next_sector = base_sector
sectors_left = num_sectors
# FIFO of (start_sector, sector_count, future), in submission order.
pending = deque()
def submit_more():
# Read and submit chunks until the window is full or the region is exhausted.
nonlocal next_sector, sectors_left
while sectors_left > 0 and len(pending) < max_in_flight:
count = min(sectors_per_chunk, sectors_left)
# NOTE(review): the read result is submitted without a length check, unlike
# the per-sector copy loops that warn on an empty read — confirm how a
# truncated input should be handled here.
blob = input_iso.read(count * core.SECTOR)
future = executor.submit(_process_sector_chunk_mp,
(self.disc_key, blob, next_sector, encrypt_mode))
pending.append((next_sector, count, future))
next_sector += count
sectors_left -= count
submit_more()
# Results are consumed in submission order; each chunk is written at its own
# absolute offset, then the window is refilled.
while pending:
start_sector, count, future = pending.popleft()
out.seek(start_sector * core.SECTOR)
out.write(future.result())
pbar.update(count)
submit_more()
def get_key_from_args(self, game_title, args):
# key provided with -d / --decryption-key
@ -392,24 +341,17 @@ class ISO:
# No key or .ird specified. Let's first check if keys.db is packaged with this release
core.vprint('Checking for bundled redump keys', args)
try:
db_path = resources.files(__name__).joinpath('data', 'keys.db')
if hasattr(db_path, 'read_bytes'):
# importlib.resources.abc.Traversable - write to temp file for sqlite3
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.db') as tmp:
tmp.write(db_path.read_bytes())
db = sqlite3.connect(tmp.name)
else:
db = sqlite3.connect(str(db_path))
except (FileNotFoundError, AttributeError):
db = sqlite3.connect((pathlib.Path(__file__).resolve() / 'data/') / 'keys.db')
db_path = resources.files(__name__).joinpath('data', 'keys.db')
db = sqlite3.connect(str(db_path))
c = db.cursor()
# UPDATE: 2024 - New database now has game/title ids. See if we have that.
core.vprint('Searching using TITLE_ID', args)
keys = c.execute('SELECT name, key FROM games WHERE title_id = ?', [self.game_id.replace('-','')]).fetchall()
try:
keys = c.execute('SELECT name, key FROM games WHERE title_id = ?', [self.game_id.replace('-', '')]).fetchall()
except sqlite3.OperationalError:
core.error('keys.db not found or invalid. Build it with: python3 tools/keys2db.py')
if len(keys) == 1:
core.vprint(f'Found potential redump key: "{keys[0][0]}"', args)
return keys[0][1]
@ -424,10 +366,15 @@ class ISO:
# If not, see if we can filter it out based on name and size
core.vprint('Trying to find redump key based on size, game title, and country', args)
if not game_title:
raise ValueError
core.error('Could not determine game title from PARAM.SFO. Specify a decryption key with -d or provide an IRD file with -k.')
try:
country = core.serial_country(self.game_id)
except ValueError:
core.error(f'Unknown country code in game ID "{self.game_id}". Specify a decryption key with -d or provide an IRD file with -k.')
keys = c.execute('SELECT name, key FROM games WHERE lower(name) LIKE ? AND size = ?', [
'%' + '%'.join(game_title.lower().split(' ')) + '%' + core.serial_country(self.game_id).lower() + '%', str(self.size)]).fetchall()
'%' + '%'.join(game_title.lower().split(' ')) + '%' + country.lower() + '%', str(self.size)]).fetchall()
if keys:
core.vprint(f'Found potential redump key: "{keys[0][0]}"', args)
return keys[0][1]
@ -439,35 +386,29 @@ class ISO:
# Okay, searching has failed us, but maaaybe the checksum works?
core.vprint('Trying to find redump key based on CRC32', args)
crc32 = None
crc32_continue = [True]
cancel = threading.Event()
crc_done = threading.Event()
if args.checksum_timeout > 0:
def timeout(allow_execution):
time.sleep(float(args.checksum_timeout))
if crc32 is None:
def timeout():
# Abort the CRC32 calculation if it hasn't finished in time.
if not crc_done.wait(timeout=float(args.checksum_timeout)):
core.vprint(f'could not calculate CRC32 before {args.checksum_timeout}-second timeout', args)
allow_execution[0] = False
crc_thread = Thread(target=timeout, args=(crc32_continue,), daemon=True)
cancel.set()
crc_thread = Thread(target=timeout, daemon=True)
crc_thread.start()
crc32 = core.crc32(args.iso, crc32_continue)
if crc32 is None:
calculated_crc = core.crc32(args.iso, cancel)
crc_done.set()
if calculated_crc is None:
raise TimeoutError
keys = c.execute('SELECT name, key FROM games WHERE crc32=?', [crc32.lower()]).fetchall()
keys = c.execute('SELECT name, key FROM games WHERE crc32=?', [calculated_crc.lower()]).fetchall()
if len(keys) == 1:
core.vprint(f'Found potential redump key: "{keys[0][0]}" (CRC32={crc32.lower()})', args)
core.vprint(f'Found potential redump key: "{keys[0][0]}" (CRC32={calculated_crc.lower()})', args)
return keys[0][1]
# Fallback to downloading an IRD from the internet (currently disabled)
# try:
# core.warning('No IRD file specified, finding required file', args)
# args.ird = core.ird_by_game_id(self.game_id) # Download ird
# return get_key_from_ird(args.ird)
# except:
# core.vprint('Could not download IRD file', args)
raise ValueError
core.error('could not find disc key')
def print_info(self):
# TODO: This could probably have been a __str__? Who cares?