# frankenstallm/data/sft_dataset.py
# Origin: commit b3d361d (pathcosmos) — "feat: Add data pipeline scripts
# + phase reports (Tier 3 - reproducibility)"
"""
SFT (Supervised Fine-Tuning) dataset for the Korean LLM project.
Reads JSONL files in three supported formats:
1. Alpaca format
{"instruction": "...", "input": "...", "output": "..."}
2. Alpaca format without optional input
{"instruction": "...", "output": "..."}
3. Conversation format
{"conversations": [{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}]}
Chat template applied:
<|user|>
{instruction or user turn}
<|assistant|>
{output or assistant turn}</s>
Loss masking: ``labels`` is -1 for all prompt tokens so
``nn.CrossEntropyLoss`` (ignore_index=-1) only trains on
the assistant responses.
"""
from __future__ import annotations
import json
import multiprocessing
import time
from pathlib import Path
from typing import Union
import torch
from torch.utils.data import Dataset
from tokenizers import Tokenizer # HuggingFace tokenizers (fast, Rust-based)
# ---------------------------------------------------------------------------
# Role tags used in the chat template.
# These strings must exist verbatim in the tokenizer's vocabulary/training
# data for the template to round-trip cleanly.
# ---------------------------------------------------------------------------
_USER_TAG = "<|user|>\n"  # opens a user turn
_ASSISTANT_TAG = "<|assistant|>\n"  # opens an assistant turn
_EOS_STRING = "</s>"  # appended after every assistant response
def _build_alpaca_turns(
    instruction: str,
    input_text: str,
    output: str,
) -> tuple[str, str]:
    """
    Render one Alpaca-format record as a (prompt, response) string pair.

    The *prompt* carries the user tag plus the instruction (and the
    stripped optional input on a second line). The *response* carries the
    raw output followed by the EOS marker; the assistant tag itself ends
    the prompt so loss masking can split exactly at the tag boundary.

    Args:
        instruction: The task instruction.
        input_text: Optional extra context; empty/whitespace is ignored.
        output: The expected assistant response.

    Returns:
        Tuple of (prompt_text, response_text).
    """
    extra = input_text.strip() if input_text else ""
    body = f"{instruction}\n{extra}" if extra else instruction
    prompt_text = f"{_USER_TAG}{body}\n{_ASSISTANT_TAG}"
    response_text = f"{output}{_EOS_STRING}"
    return prompt_text, response_text
def _build_conversation_turns(
    conversations: list[dict],
) -> list[tuple[str, str]]:
    """
    Expand a conversation into one (prompt, response) pair per assistant turn.

    For turn *k* the prompt is the full rendered dialogue history up to
    (but excluding) assistant turn *k*. Only user→assistant exchanges
    produce samples; consecutive user messages are merged with a newline,
    assistant turns with no preceding user turn are dropped, and a
    conversation with no assistant turn yields an empty list.

    Args:
        conversations: List of dicts with ``role`` and ``content`` keys.
            Roles are expected to be ``"user"`` or ``"assistant"``.

    Returns:
        List of (prompt_text, response_text) tuples, one per assistant turn.
    """
    pairs: list[tuple[str, str]] = []
    history = ""        # rendered dialogue so far
    buffered_user = ""  # user content awaiting an assistant reply
    for message in conversations:
        role = message.get("role", "").lower()
        content = message.get("content", "")
        if role == "user":
            # Merge consecutive user turns with a newline.
            buffered_user = (
                f"{buffered_user}\n{content}" if buffered_user else content
            )
        elif role == "assistant":
            if not buffered_user:
                # Assistant turn without a preceding user turn — skip.
                continue
            prompt = f"{history}{_USER_TAG}{buffered_user}\n{_ASSISTANT_TAG}"
            pairs.append((prompt, f"{content}{_EOS_STRING}"))
            # Fold the completed exchange into the history WITHOUT the EOS
            # so the model does not treat it as a hard stop mid-context.
            history = f"{prompt}{content}\n"
            buffered_user = ""
    return pairs
# ---------------------------------------------------------------------------
# Multiprocessing worker for parallel tokenization.
# Per-process state, populated by ``_worker_init`` in each pool worker.
# Module-level globals are the conventional way to hand state to
# ``multiprocessing.Pool`` worker functions.
# ---------------------------------------------------------------------------
_worker_tokenizer: Tokenizer | None = None  # this worker's private tokenizer
_worker_eos_id: int = -1  # resolved EOS token id (-1 until initialized)
_worker_max_seq_len: int = 4096  # truncation limit for prompt+response
def _worker_init(tokenizer_path: str, eos_string: str, max_seq_len: int) -> None:
    """
    Pool-worker initializer: load a private tokenizer copy per process.

    Populates the module-level worker globals so ``_worker_tokenize_batch``
    can run without pickling the tokenizer across process boundaries.

    Raises:
        ValueError: If ``eos_string`` is absent from the loaded vocabulary.
    """
    global _worker_tokenizer, _worker_eos_id, _worker_max_seq_len
    _worker_tokenizer = Tokenizer.from_file(tokenizer_path)
    resolved_eos = _worker_tokenizer.token_to_id(eos_string)
    if resolved_eos is None:
        raise ValueError(f"EOS token '{eos_string}' not found in worker tokenizer.")
    _worker_eos_id = resolved_eos
    _worker_max_seq_len = max_seq_len
def _worker_tokenize_batch(
batch: list[tuple[str, str]],
) -> list[tuple[list[int], list[int]] | None]:
"""
Tokenize a batch of (prompt, response) pairs in a worker process.
Returns a list of (prompt_ids, response_ids) as Python lists, or None
for samples that should be skipped.
"""
global _worker_tokenizer, _worker_eos_id, _worker_max_seq_len
tok = _worker_tokenizer
eos_id = _worker_eos_id
max_seq_len = _worker_max_seq_len
results = []
for prompt_text, response_text in batch:
prompt_ids = tok.encode(prompt_text).ids
response_ids = tok.encode(response_text).ids
# Skip samples where the prompt alone leaves no room for output.
if len(prompt_ids) >= max_seq_len - 10:
results.append(None)
continue
full_len = len(prompt_ids) + len(response_ids)
# Truncate response if combined sequence is too long.
if full_len > max_seq_len:
allowed_response = max_seq_len - len(prompt_ids)
if allowed_response <= 0:
results.append(None)
continue
response_ids = response_ids[:allowed_response]
# Force EOS at end after truncation.
if response_ids[-1] != eos_id:
response_ids[-1] = eos_id
results.append((prompt_ids, response_ids))
return results
class SFTDataset(Dataset):
    """
    Supervised Fine-Tuning dataset built from JSONL files.

    Each JSONL line must conform to one of the three schemas described in
    the module docstring. After tokenisation a sample is laid out as::

        [prompt tokens ...][response tokens ...]
        |<-- labels=-1 --->|<- shifted resp ids ->|

    ``labels`` is pre-shifted by one position: the value stored at index
    ``i`` is the token expected at index ``i + 1``, so the label span
    begins on the *last prompt token* and the final response token (EOS)
    is predicted from the position before it. -1 marks ignored positions
    so ``nn.CrossEntropyLoss(ignore_index=-1)`` only penalises the model
    on assistant-response predictions. No padding is applied here —
    ``__getitem__`` returns variable-length tensors for a collate
    function to pad.

    NOTE(review): because labels are already shifted here, the training
    loop must compare ``logits[t]`` against ``labels[t]`` directly,
    WITHOUT an additional shift — confirm against the trainer.

    Args:
        data_path: Path to a single ``.jsonl`` file or a directory that
            contains multiple ``.jsonl`` files (all are loaded).
        tokenizer: A ``tokenizers.Tokenizer`` instance (HuggingFace fast
            tokenizer loaded from ``tokenizer.json``).
        max_seq_len: Maximum sequence length (tokens). Samples exceeding
            this are truncated from the *end of the response*.
            Default: 4096.
        pad_token_id: Token id recorded for collate-time right-padding;
            this class itself never pads. Default: 0.
        tokenizer_path: Optional path to a ``tokenizer.json`` file. When
            given, tokenization runs in a multiprocessing pool and each
            worker loads its own tokenizer from this file.
        num_workers: Worker-process count for parallel tokenization.
            Default: 60.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        tokenizer: Tokenizer,
        max_seq_len: int = 4096,
        pad_token_id: int = 0,
        tokenizer_path: Union[str, Path, None] = None,
        num_workers: int = 60,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        # Resolve the EOS token id up front; fail fast on a bad tokenizer.
        eos_id = tokenizer.token_to_id(_EOS_STRING)
        if eos_id is None:
            raise ValueError(
                f"EOS token '{_EOS_STRING}' not found in the tokenizer vocabulary. "
                "Check that the tokenizer was trained with this special token."
            )
        self.eos_token_id: int = eos_id

        data_path = Path(data_path)

        # ------------------------------------------------------------------
        # Try the cache FIRST. The key is derived from file metadata only,
        # so on a hit we skip both JSONL parsing and tokenization entirely
        # (the original loaded all JSONL before consulting the cache).
        # ------------------------------------------------------------------
        cache_path = Path(f"{data_path}.sft_cache.pt")
        cache_key = self._make_cache_key(data_path, max_seq_len, tokenizer)
        cached = self._try_load_cache(cache_path, cache_key)
        if cached is not None:
            self.samples = cached
            return

        # ------------------------------------------------------------------
        # Cache miss: load raw JSONL samples, then tokenize.
        # ------------------------------------------------------------------
        raw_samples = self._load_jsonl(data_path)
        t0 = time.time()
        if tokenizer_path is not None:
            self.samples = self._tokenize_parallel(
                raw_samples, str(tokenizer_path), max_seq_len, num_workers,
            )
        else:
            self.samples = self._tokenize_sequential(
                raw_samples, tokenizer, max_seq_len,
            )
        elapsed = time.time() - t0
        print(f"[SFTDataset] Tokenization took {elapsed:.1f}s")

        # Persist the tokenized samples for the next run (best-effort).
        self._save_cache(cache_path, cache_key)

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_cache_key(
        data_path: Path, max_seq_len: int, tokenizer: Tokenizer,
    ) -> tuple:
        """
        Build a cheap cache key from file metadata + settings.

        Uses (size, mtime) per file instead of content hashing, so key
        computation is O(1) per file but invalidates on any rewrite.
        """
        if data_path.is_file():
            stat = data_path.stat()
            file_sig = (stat.st_size, stat.st_mtime)
        else:
            # Directory: combine stats of all jsonl files, sorted so the
            # key is deterministic across runs.
            parts = []
            for f in sorted(data_path.glob("*.jsonl")):
                s = f.stat()
                parts.append((str(f), s.st_size, s.st_mtime))
            file_sig = tuple(parts)
        return (file_sig, max_seq_len, tokenizer.get_vocab_size())

    def _try_load_cache(
        self, cache_path: Path, cache_key: tuple,
    ) -> list[tuple[torch.Tensor, torch.Tensor]] | None:
        """
        Load cached tokenized samples, or return None on any miss.

        A miss is: no cache file, a stale key, or any load failure
        (corrupt file) — all are non-fatal and fall back to tokenizing.
        """
        if not cache_path.exists():
            print(f"[SFTDataset] Cache miss — no cache file at {cache_path}")
            return None
        try:
            t0 = time.time()
            # weights_only=False because the payload is a dict holding
            # arbitrary Python objects; only safe because the cache is
            # produced locally by _save_cache (trusted input).
            cache = torch.load(cache_path, map_location="cpu", weights_only=False)
            if cache.get("cache_key") != cache_key:
                print("[SFTDataset] Cache miss — stale cache (key mismatch)")
                return None
            samples = cache["samples"]
            elapsed = time.time() - t0
            print(
                f"[SFTDataset] Cache hit! Loaded {len(samples)} samples "
                f"from {cache_path} in {elapsed:.1f}s"
            )
            return samples
        except Exception as exc:
            print(f"[SFTDataset] Cache miss — failed to load: {exc}")
            return None

    def _save_cache(self, cache_path: Path, cache_key: tuple) -> None:
        """Save tokenized samples to the cache file (best-effort; a
        failure only costs re-tokenization on the next run)."""
        try:
            t0 = time.time()
            torch.save(
                {"cache_key": cache_key, "samples": self.samples},
                cache_path,
            )
            elapsed = time.time() - t0
            size_mb = cache_path.stat().st_size / (1024 * 1024)
            print(
                f"[SFTDataset] Saved cache ({size_mb:.0f} MB) "
                f"to {cache_path} in {elapsed:.1f}s"
            )
        except Exception as exc:
            print(f"[SFTDataset] WARNING: Failed to save cache: {exc}")

    # ------------------------------------------------------------------
    # Tokenization strategies
    # ------------------------------------------------------------------
    def _tokenize_sequential(
        self,
        raw_samples: list[tuple[str, str]],
        tokenizer: Tokenizer,
        max_seq_len: int,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        In-process tokenization, one sample at a time (fallback path
        used when no ``tokenizer_path`` is supplied).
        """
        samples: list[tuple[torch.Tensor, torch.Tensor]] = []
        total_loaded = 0
        total_tokens = 0
        skipped_too_long = 0
        truncated_count = 0
        for prompt_text, response_text in raw_samples:
            total_loaded += 1
            prompt_ids = tokenizer.encode(prompt_text).ids
            response_ids = tokenizer.encode(response_text).ids
            # Skip samples where the prompt alone leaves <=10 tokens of
            # room for a response.
            if len(prompt_ids) >= max_seq_len - 10:
                skipped_too_long += 1
                continue
            full_ids = prompt_ids + response_ids
            if len(full_ids) > max_seq_len:
                truncated_count += 1
                allowed_response = max_seq_len - len(prompt_ids)
                if allowed_response <= 0:
                    # Defensive: unreachable given the >=10-token guard
                    # above; kept for safety.
                    skipped_too_long += 1
                    continue
                response_ids = response_ids[:allowed_response]
                # Force EOS at the end after truncation.
                if response_ids[-1] != self.eos_token_id:
                    response_ids[-1] = self.eos_token_id
                full_ids = prompt_ids + response_ids
            seq_len = len(full_ids)
            total_tokens += seq_len
            input_ids = torch.tensor(full_ids, dtype=torch.int32)
            labels = torch.full((seq_len,), fill_value=-1, dtype=torch.int32)
            # Pre-shifted next-token labels: the span starts one position
            # BEFORE the response, so labels[i] == input_ids[i + 1].
            resp_start = len(prompt_ids)
            resp_label_start = max(0, resp_start - 1)
            resp_label_end = resp_label_start + len(response_ids)
            labels[resp_label_start:resp_label_end] = torch.tensor(
                response_ids, dtype=torch.int32
            )
            samples.append((input_ids, labels))
        n = len(samples)
        avg_len = (total_tokens / n) if n > 0 else 0.0
        print(
            f"[SFTDataset] Loaded {n} samples "
            f"(raw={total_loaded}, "
            f"skipped_too_long={skipped_too_long}, "
            f"truncated={truncated_count})"
        )
        print(
            f"[SFTDataset] avg_seq_len={avg_len:.1f}, "
            f"max_seq_len={max_seq_len}, "
            f"pad_token_id={self.pad_token_id}, "
            f"eos_token_id={self.eos_token_id}"
        )
        return samples

    def _tokenize_parallel(
        self,
        raw_samples: list[tuple[str, str]],
        tokenizer_path: str,
        max_seq_len: int,
        num_workers: int,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Parallel tokenization using multiprocessing.Pool; each worker
        loads its own tokenizer from ``tokenizer_path`` (``_worker_init``).

        NOTE(review): ``imap_unordered`` yields chunks in completion
        order, so the final sample ordering (and the saved cache) is
        nondeterministic across runs — presumably acceptable because
        training shuffles; confirm.
        """
        total = len(raw_samples)
        print(
            f"[SFTDataset] Starting parallel tokenization: "
            f"{total} samples, {num_workers} workers"
        )
        # Split raw_samples into fixed-size chunks for imap_unordered.
        chunk_size = 1000
        chunks = []
        for i in range(0, total, chunk_size):
            chunks.append(raw_samples[i : i + chunk_size])
        # Collect tokenized results from workers.
        all_token_pairs: list[tuple[list[int], list[int]] | None] = []
        processed = 0
        # Use 'spawn' context to avoid fork+CUDA issues when called
        # after model is already on GPU (e.g., in DDP training).
        ctx = multiprocessing.get_context("spawn")
        with ctx.Pool(
            processes=num_workers,
            initializer=_worker_init,
            initargs=(tokenizer_path, _EOS_STRING, max_seq_len),
        ) as pool:
            for batch_results in pool.imap_unordered(
                _worker_tokenize_batch, chunks, chunksize=1,
            ):
                all_token_pairs.extend(batch_results)
                processed += len(batch_results)
                # Progress roughly every 100k samples (chunk granularity).
                if processed % 100_000 < chunk_size:
                    print(
                        f"[SFTDataset] Tokenized {processed}/{total} "
                        f"({100.0 * processed / total:.1f}%)"
                    )
        # Print final progress if not already printed.
        if processed % 100_000 >= chunk_size:
            print(f"[SFTDataset] Tokenized {processed}/{total} (100.0%)")
        # Convert worker output to tensors and build samples.
        samples: list[tuple[torch.Tensor, torch.Tensor]] = []
        total_tokens = 0
        skipped_too_long = 0
        truncated_count = 0
        for pair in all_token_pairs:
            if pair is None:
                skipped_too_long += 1
                continue
            prompt_ids, response_ids = pair
            full_ids = prompt_ids + response_ids
            # Heuristic truncation count: a combined length of exactly
            # max_seq_len most likely means the worker truncated.
            if len(full_ids) == max_seq_len:
                truncated_count += 1
            seq_len = len(full_ids)
            total_tokens += seq_len
            input_ids = torch.tensor(full_ids, dtype=torch.int32)
            labels = torch.full((seq_len,), fill_value=-1, dtype=torch.int32)
            # Pre-shifted next-token labels (labels[i] == input_ids[i+1]),
            # matching _tokenize_sequential.
            resp_start = len(prompt_ids)
            resp_label_start = max(0, resp_start - 1)
            resp_label_end = resp_label_start + len(response_ids)
            labels[resp_label_start:resp_label_end] = torch.tensor(
                response_ids, dtype=torch.int32
            )
            samples.append((input_ids, labels))
        n = len(samples)
        avg_len = (total_tokens / n) if n > 0 else 0.0
        print(
            f"[SFTDataset] Loaded {n} samples "
            f"(raw={total}, "
            f"skipped_too_long={skipped_too_long}, "
            f"truncated={truncated_count})"
        )
        print(
            f"[SFTDataset] avg_seq_len={avg_len:.1f}, "
            f"max_seq_len={max_seq_len}, "
            f"pad_token_id={self.pad_token_id}, "
            f"eos_token_id={self.eos_token_id}"
        )
        return samples

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _load_jsonl(self, path: Path) -> list[tuple[str, str]]:
        """
        Discover and parse JSONL files, returning (prompt, response) pairs.

        If ``path`` is a file, load that file only. If it is a directory,
        load all ``*.jsonl`` files found directly inside it (non-recursive).

        Args:
            path: File or directory path.

        Returns:
            List of (prompt_text, response_text) tuples.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If no ``.jsonl`` files are found under a
                directory path.
        """
        if not path.exists():
            raise FileNotFoundError(f"Data path not found: {path}")
        if path.is_dir():
            jsonl_files = sorted(path.glob("*.jsonl"))
            if not jsonl_files:
                raise ValueError(f"No .jsonl files found in directory: {path}")
        else:
            jsonl_files = [path]
        pairs: list[tuple[str, str]] = []
        for jsonl_file in jsonl_files:
            pairs.extend(self._parse_jsonl_file(jsonl_file))
        return pairs

    def _parse_jsonl_file(self, path: Path) -> list[tuple[str, str]]:
        """
        Parse a single JSONL file into (prompt, response) pairs.

        Lines that are empty, whitespace-only, or fail JSON parsing are
        skipped with a warning. Lines whose schema cannot be recognised
        are also skipped.

        Args:
            path: Path to a ``.jsonl`` file.

        Returns:
            List of (prompt_text, response_text) tuples extracted from
            the file.
        """
        pairs: list[tuple[str, str]] = []
        with path.open("r", encoding="utf-8") as fh:
            for lineno, line in enumerate(fh, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as exc:
                    # BUGFIX: the original f-string had no separator
                    # between the line number and the exception text.
                    print(
                        f"[SFTDataset] WARNING: JSON parse error in "
                        f"{path}:{lineno}: {exc}"
                    )
                    continue
                # ---- Conversation format ------------------------------------
                # Support both "conversations" and "messages" keys
                conv_list = obj.get("conversations") or obj.get("messages")
                if conv_list and isinstance(conv_list, list):
                    turn_pairs = _build_conversation_turns(conv_list)
                    if not turn_pairs:
                        print(
                            f"[SFTDataset] WARNING: No valid user→assistant "
                            f"pairs in {path}:{lineno}, skipping."
                        )
                    pairs.extend(turn_pairs)
                # ---- Alpaca / Alpaca-no-input format -----------------------
                elif "instruction" in obj and "output" in obj:
                    prompt, response = _build_alpaca_turns(
                        instruction=obj["instruction"],
                        input_text=obj.get("input", ""),
                        output=obj["output"],
                    )
                    pairs.append((prompt, response))
                else:
                    print(
                        f"[SFTDataset] WARNING: Unrecognised schema at "
                        f"{path}:{lineno}, skipping."
                    )
        return pairs

    # ------------------------------------------------------------------
    # Dataset interface
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the number of valid samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return a single training sample.

        Args:
            idx: Sample index.

        Returns:
            Tuple ``(input_ids, labels)`` where both tensors have shape
            ``[seq_len]`` (variable per sample) and dtype ``torch.long``.
            Use a collate function to pad batches dynamically.

            - ``input_ids``: Full token sequence (prompt + response),
              NO padding (raw length).
            - ``labels``: Pre-shifted response token ids (label at index
              ``i`` is the token at ``i + 1``), ``-1`` everywhere else.
              Use ``ignore_index=-1`` in your loss function.
        """
        input_ids, labels = self.samples[idx]
        return input_ids.long(), labels.long()