""" SFT (Supervised Fine-Tuning) dataset for the Korean LLM project. Reads JSONL files in three supported formats: 1. Alpaca format {"instruction": "...", "input": "...", "output": "..."} 2. Alpaca format without optional input {"instruction": "...", "output": "..."} 3. Conversation format {"conversations": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]} Chat template applied: <|user|> {instruction or user turn} <|assistant|> {output or assistant turn} Loss masking: ``labels`` is -1 for all prompt tokens so ``nn.CrossEntropyLoss`` (ignore_index=-1) only trains on the assistant responses. """ from __future__ import annotations import json import multiprocessing import time from pathlib import Path from typing import Union import torch from torch.utils.data import Dataset from tokenizers import Tokenizer # HuggingFace tokenizers (fast, Rust-based) # --------------------------------------------------------------------------- # Role tags used in the chat template. # --------------------------------------------------------------------------- _USER_TAG = "<|user|>\n" _ASSISTANT_TAG = "<|assistant|>\n" _EOS_STRING = "" def _build_alpaca_turns( instruction: str, input_text: str, output: str, ) -> tuple[str, str]: """ Convert an Alpaca-format sample into (prompt, response) strings. The *prompt* includes the user tag and instruction (+ optional input). The *response* includes the assistant tag and output, plus EOS. Args: instruction: The task instruction. input_text: Optional additional input context. May be empty. output: The expected assistant response. Returns: Tuple of (prompt_text, response_text). 
""" user_body = instruction if input_text and input_text.strip(): user_body = f"{instruction}\n{input_text.strip()}" prompt = f"{_USER_TAG}{user_body}\n{_ASSISTANT_TAG}" response = f"{output}{_EOS_STRING}" return prompt, response def _build_conversation_turns( conversations: list[dict], ) -> list[tuple[str, str]]: """ Convert a conversation list into a sequence of (prompt, response) pairs. For a multi-turn conversation the prompt for turn *k* is the entire dialogue history up to (but not including) assistant turn *k*. Only user→assistant pairs contribute training samples. Consecutive user messages are merged. Conversations that start with an assistant turn, or that have no assistant turn, are skipped (return empty list). Args: conversations: List of dicts with ``role`` and ``content`` keys. Roles are expected to be ``"user"`` or ``"assistant"``. Returns: List of (prompt_text, response_text) tuples, one per assistant turn. """ pairs: list[tuple[str, str]] = [] history = "" # accumulated dialogue so far pending_user = "" # user content not yet closed by an assistant turn for turn in conversations: role = turn.get("role", "").lower() content = turn.get("content", "") if role == "user": if pending_user: # Two consecutive user turns — concatenate them. pending_user = f"{pending_user}\n{content}" else: pending_user = content elif role == "assistant": if not pending_user: # Assistant turn without a preceding user turn — skip. continue prompt = f"{history}{_USER_TAG}{pending_user}\n{_ASSISTANT_TAG}" response = f"{content}{_EOS_STRING}" pairs.append((prompt, response)) # Update history to include this full exchange (without the EOS # so the model does not treat it as a hard stop mid-context). history = f"{history}{_USER_TAG}{pending_user}\n{_ASSISTANT_TAG}{content}\n" pending_user = "" return pairs # --------------------------------------------------------------------------- # Multiprocessing worker for parallel tokenization. 
# ---------------------------------------------------------------------------

# Per-worker state, populated by _worker_init in each pool process.
_worker_tokenizer: Tokenizer | None = None
_worker_eos_id: int = -1
_worker_max_seq_len: int = 4096


def _worker_init(tokenizer_path: str, eos_string: str, max_seq_len: int) -> None:
    """Initializer for each pool worker — loads its own tokenizer instance.

    Args:
        tokenizer_path: Path to a ``tokenizer.json`` file loadable with
            ``Tokenizer.from_file``.
        eos_string: EOS token string, resolved to an id in the worker's vocab.
        max_seq_len: Maximum sequence length enforced during tokenization.

    Raises:
        ValueError: If ``eos_string`` is not in the tokenizer vocabulary.
    """
    global _worker_tokenizer, _worker_eos_id, _worker_max_seq_len
    _worker_tokenizer = Tokenizer.from_file(tokenizer_path)
    eos_id = _worker_tokenizer.token_to_id(eos_string)
    if eos_id is None:
        raise ValueError(f"EOS token '{eos_string}' not found in worker tokenizer.")
    _worker_eos_id = eos_id
    _worker_max_seq_len = max_seq_len


def _worker_tokenize_batch(
    batch: list[tuple[str, str]],
) -> list[tuple[list[int], list[int]] | None]:
    """
    Tokenize a batch of (prompt, response) pairs in a worker process.

    Returns a list of (prompt_ids, response_ids) as Python lists, or None
    for samples that should be skipped (prompt too long to leave room for
    any response).
    """
    # Read-only access — the original's ``global`` statement was unnecessary.
    tok = _worker_tokenizer
    eos_id = _worker_eos_id
    max_seq_len = _worker_max_seq_len

    results: list[tuple[list[int], list[int]] | None] = []
    for prompt_text, response_text in batch:
        prompt_ids = tok.encode(prompt_text).ids
        response_ids = tok.encode(response_text).ids

        # Skip samples where the prompt alone leaves no room for output.
        if len(prompt_ids) >= max_seq_len - 10:
            results.append(None)
            continue

        # Truncate the response if the combined sequence is too long.
        if len(prompt_ids) + len(response_ids) > max_seq_len:
            allowed_response = max_seq_len - len(prompt_ids)
            if allowed_response <= 0:
                results.append(None)
                continue
            response_ids = response_ids[:allowed_response]
            # Force EOS at end after truncation. (NOTE(review): EOS is only
            # enforced on *truncated* samples — untruncated responses end
            # with whatever the text tokenized to; confirm this is intended.)
            if response_ids[-1] != eos_id:
                response_ids[-1] = eos_id

        results.append((prompt_ids, response_ids))
    return results


class SFTDataset(Dataset):
    """
    Supervised Fine-Tuning dataset built from JSONL files.

    Each JSONL line must conform to one of three schemas described in the
    module docstring.

    After tokenisation a sample is laid out as::

        input_ids: [prompt tokens ...........][response tokens ...]
        labels   : [-1 ...][response ids, shifted one position LEFT][-1]

    ``labels`` is pre-shifted for next-token prediction: over the response
    span ``labels[i] == input_ids[i + 1]``, and every other position
    (prompt tokens, final token) is -1, so
    ``nn.CrossEntropyLoss(ignore_index=-1)`` only penalises the model on
    the assistant response tokens. No padding is applied here —
    ``__getitem__`` returns variable-length tensors and padding is the
    collate function's responsibility (use ``pad_token_id``).

    Args:
        data_path: Path to a single ``.jsonl`` file or a directory that
            contains multiple ``.jsonl`` files (all are loaded).
        tokenizer: A ``tokenizers.Tokenizer`` instance (HuggingFace fast
            tokenizer loaded from ``tokenizer.json``).
        max_seq_len: Maximum sequence length (tokens). Samples exceeding
            this are truncated from the *end of the response*. Default: 4096.
        pad_token_id: Token id a collate function should use for
            right-padding. Default: 0.
        tokenizer_path: Optional path to ``tokenizer.json``. When given,
            tokenization runs in parallel worker processes that each load
            their own tokenizer from this file.
        num_workers: Number of worker processes for parallel tokenization.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        tokenizer: Tokenizer,
        max_seq_len: int = 4096,
        pad_token_id: int = 0,
        tokenizer_path: Union[str, Path, None] = None,
        num_workers: int = 60,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        # Resolve EOS token id from the vocabulary.
        eos_id = tokenizer.token_to_id(_EOS_STRING)
        if eos_id is None:
            raise ValueError(
                f"EOS token '{_EOS_STRING}' not found in the tokenizer vocabulary. "
                "Check that the tokenizer was trained with this special token."
            )
        self.eos_token_id: int = eos_id

        data_path = Path(data_path)

        # ------------------------------------------------------------------
        # Try loading from cache first. The key is built from file metadata
        # only, so on a hit we skip JSONL parsing entirely (fix: the
        # original parsed every JSONL file before consulting the cache).
        # ------------------------------------------------------------------
        cache_path = Path(f"{data_path}.sft_cache.pt")
        cache_key = self._make_cache_key(data_path, max_seq_len, tokenizer)
        cached = self._try_load_cache(cache_path, cache_key)
        if cached is not None:
            self.samples = cached
            return

        # ------------------------------------------------------------------
        # Load raw JSONL samples, tokenise, build (input_ids, labels) pairs.
        # ------------------------------------------------------------------
        raw_samples = self._load_jsonl(data_path)

        t0 = time.time()
        if tokenizer_path is not None:
            self.samples = self._tokenize_parallel(
                raw_samples, str(tokenizer_path), max_seq_len, num_workers
            )
        else:
            self.samples = self._tokenize_sequential(
                raw_samples, tokenizer, max_seq_len
            )
        elapsed = time.time() - t0
        print(f"[SFTDataset] Tokenization took {elapsed:.1f}s")

        # ------------------------------------------------------------------
        # Save cache.
        # ------------------------------------------------------------------
        self._save_cache(cache_path, cache_key)

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_cache_key(
        data_path: Path,
        max_seq_len: int,
        tokenizer: Tokenizer,
    ) -> tuple:
        """Build a cheap cache key from file metadata + settings.

        NOTE(review): the key covers file size/mtime, ``max_seq_len`` and
        vocab size only — changing pad/eos ids or the chat template will
        NOT invalidate an existing cache. Confirm this is acceptable.
        """
        if data_path.is_file():
            stat = data_path.stat()
            file_sig: tuple = (stat.st_size, stat.st_mtime)
        else:
            # Directory: combine stats of all jsonl files.
            parts = []
            for f in sorted(data_path.glob("*.jsonl")):
                s = f.stat()
                parts.append((str(f), s.st_size, s.st_mtime))
            file_sig = tuple(parts)
        return (file_sig, max_seq_len, tokenizer.get_vocab_size())

    def _try_load_cache(
        self,
        cache_path: Path,
        cache_key: tuple,
    ) -> list[tuple[torch.Tensor, torch.Tensor]] | None:
        """Load cached tokenized samples, or None if the cache is invalid."""
        if not cache_path.exists():
            print(f"[SFTDataset] Cache miss — no cache file at {cache_path}")
            return None
        try:
            t0 = time.time()
            # weights_only=False: the cache is a trusted local file written
            # by _save_cache (pickle — never load caches from untrusted
            # sources).
            cache = torch.load(cache_path, map_location="cpu", weights_only=False)
            if cache.get("cache_key") != cache_key:
                print("[SFTDataset] Cache miss — stale cache (key mismatch)")
                return None
            samples = cache["samples"]
            elapsed = time.time() - t0
            print(
                f"[SFTDataset] Cache hit! Loaded {len(samples)} samples "
                f"from {cache_path} in {elapsed:.1f}s"
            )
            return samples
        except Exception as exc:
            # Broad catch is deliberate: a corrupt or incompatible cache
            # must fall back to re-tokenization, never crash startup.
            print(f"[SFTDataset] Cache miss — failed to load: {exc}")
            return None

    def _save_cache(self, cache_path: Path, cache_key: tuple) -> None:
        """Save tokenized samples to the cache file (best-effort)."""
        try:
            t0 = time.time()
            torch.save(
                {"cache_key": cache_key, "samples": self.samples},
                cache_path,
            )
            elapsed = time.time() - t0
            size_mb = cache_path.stat().st_size / (1024 * 1024)
            print(
                f"[SFTDataset] Saved cache ({size_mb:.0f} MB) "
                f"to {cache_path} in {elapsed:.1f}s"
            )
        except Exception as exc:
            # Best-effort: a failed cache write must not abort training.
            print(f"[SFTDataset] WARNING: Failed to save cache: {exc}")

    # ------------------------------------------------------------------
    # Tokenization strategies
    # ------------------------------------------------------------------

    def _make_sample(
        self,
        prompt_ids: list[int],
        response_ids: list[int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build one (input_ids, labels) pair with shifted loss masking.

        ``labels`` is filled with -1 and the response ids are written one
        position to the LEFT of where they sit in ``input_ids``, i.e.
        ``labels[i] == input_ids[i + 1]`` over the response span — the
        layout expected by a loss that compares ``logits[i]`` against
        ``labels[i]`` without shifting internally.
        """
        full_ids = prompt_ids + response_ids
        seq_len = len(full_ids)
        input_ids = torch.tensor(full_ids, dtype=torch.int32)
        labels = torch.full((seq_len,), fill_value=-1, dtype=torch.int32)
        resp_label_start = max(0, len(prompt_ids) - 1)
        resp_label_end = resp_label_start + len(response_ids)
        labels[resp_label_start:resp_label_end] = torch.tensor(
            response_ids, dtype=torch.int32
        )
        return input_ids, labels

    def _print_stats(
        self,
        samples: list,
        raw_count: int,
        skipped_too_long: int,
        truncated_count: int,
        total_tokens: int,
        max_seq_len: int,
    ) -> None:
        """Print the post-tokenization summary (shared by both strategies)."""
        n = len(samples)
        avg_len = (total_tokens / n) if n > 0 else 0.0
        print(
            f"[SFTDataset] Loaded {n} samples "
            f"(raw={raw_count}, "
            f"skipped_too_long={skipped_too_long}, "
            f"truncated={truncated_count})"
        )
        print(
            f"[SFTDataset] avg_seq_len={avg_len:.1f}, "
            f"max_seq_len={max_seq_len}, "
            f"pad_token_id={self.pad_token_id}, "
            f"eos_token_id={self.eos_token_id}"
        )

    def _tokenize_sequential(
        self,
        raw_samples: list[tuple[str, str]],
        tokenizer: Tokenizer,
        max_seq_len: int,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Sequential tokenization (fallback when no tokenizer_path)."""
        samples: list[tuple[torch.Tensor, torch.Tensor]] = []
        total_loaded = 0
        total_tokens = 0
        skipped_too_long = 0
        truncated_count = 0

        for prompt_text, response_text in raw_samples:
            total_loaded += 1
            prompt_ids = tokenizer.encode(prompt_text).ids
            response_ids = tokenizer.encode(response_text).ids

            # Prompt alone leaves no room for a response — drop the sample.
            if len(prompt_ids) >= max_seq_len - 10:
                skipped_too_long += 1
                continue

            if len(prompt_ids) + len(response_ids) > max_seq_len:
                truncated_count += 1
                allowed_response = max_seq_len - len(prompt_ids)
                if allowed_response <= 0:
                    skipped_too_long += 1
                    continue
                response_ids = response_ids[:allowed_response]
                # Force EOS at the end after truncation.
                if response_ids[-1] != self.eos_token_id:
                    response_ids[-1] = self.eos_token_id

            input_ids, labels = self._make_sample(prompt_ids, response_ids)
            total_tokens += input_ids.numel()
            samples.append((input_ids, labels))

        self._print_stats(
            samples, total_loaded, skipped_too_long, truncated_count,
            total_tokens, max_seq_len,
        )
        return samples

    def _tokenize_parallel(
        self,
        raw_samples: list[tuple[str, str]],
        tokenizer_path: str,
        max_seq_len: int,
        num_workers: int,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Parallel tokenization using multiprocessing.Pool.

        Workers each load their own tokenizer (see ``_worker_init``) and
        return raw token-id pairs; tensors are built in the parent process.

        NOTE: ``imap_unordered`` means sample order is not deterministic
        across runs (and therefore neither is the cached order).
        """
        total = len(raw_samples)
        print(
            f"[SFTDataset] Starting parallel tokenization: "
            f"{total} samples, {num_workers} workers"
        )

        # Split raw_samples into chunks for imap_unordered.
        chunk_size = 1000
        chunks = [
            raw_samples[i : i + chunk_size] for i in range(0, total, chunk_size)
        ]

        # Collect tokenized results from workers.
        all_token_pairs: list[tuple[list[int], list[int]] | None] = []
        processed = 0

        # Use 'spawn' context to avoid fork+CUDA issues when called
        # after model is already on GPU (e.g., in DDP training).
        ctx = multiprocessing.get_context("spawn")
        with ctx.Pool(
            processes=num_workers,
            initializer=_worker_init,
            initargs=(tokenizer_path, _EOS_STRING, max_seq_len),
        ) as pool:
            for batch_results in pool.imap_unordered(
                _worker_tokenize_batch, chunks, chunksize=1
            ):
                all_token_pairs.extend(batch_results)
                processed += len(batch_results)
                # Log roughly every 100k processed samples.
                if processed % 100_000 < chunk_size:
                    print(
                        f"[SFTDataset] Tokenized {processed}/{total} "
                        f"({100.0 * processed / total:.1f}%)"
                    )

        # Print final progress if not already printed.
        if processed % 100_000 >= chunk_size:
            print(f"[SFTDataset] Tokenized {processed}/{total} (100.0%)")

        # Convert to tensors and build samples.
        samples: list[tuple[torch.Tensor, torch.Tensor]] = []
        total_tokens = 0
        skipped_too_long = 0
        truncated_count = 0

        for pair in all_token_pairs:
            if pair is None:
                skipped_too_long += 1
                continue

            prompt_ids, response_ids = pair
            # Heuristic: a combined length of exactly max_seq_len almost
            # certainly means the worker truncated the response.
            if len(prompt_ids) + len(response_ids) == max_seq_len:
                truncated_count += 1

            input_ids, labels = self._make_sample(prompt_ids, response_ids)
            total_tokens += input_ids.numel()
            samples.append((input_ids, labels))

        self._print_stats(
            samples, total, skipped_too_long, truncated_count,
            total_tokens, max_seq_len,
        )
        return samples

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _load_jsonl(self, path: Path) -> list[tuple[str, str]]:
        """
        Discover and parse JSONL files, returning (prompt, response) pairs.

        If ``path`` is a file, load that file only. If it is a directory,
        load all ``*.jsonl`` files found directly inside it (non-recursive).

        Args:
            path: File or directory path.

        Returns:
            List of (prompt_text, response_text) tuples.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If no ``.jsonl`` files are found under a directory
                path.
        """
        if not path.exists():
            raise FileNotFoundError(f"Data path not found: {path}")

        if path.is_dir():
            jsonl_files = sorted(path.glob("*.jsonl"))
            if not jsonl_files:
                raise ValueError(f"No .jsonl files found in directory: {path}")
        else:
            jsonl_files = [path]

        pairs: list[tuple[str, str]] = []
        for jsonl_file in jsonl_files:
            pairs.extend(self._parse_jsonl_file(jsonl_file))
        return pairs

    def _parse_jsonl_file(self, path: Path) -> list[tuple[str, str]]:
        """
        Parse a single JSONL file into (prompt, response) pairs.

        Lines that are empty, whitespace-only, or fail JSON parsing are
        skipped with a warning. Lines whose schema cannot be recognised
        are also skipped.

        Args:
            path: Path to a ``.jsonl`` file.

        Returns:
            List of (prompt_text, response_text) tuples extracted from the
            file.
        """
        pairs: list[tuple[str, str]] = []
        with path.open("r", encoding="utf-8") as fh:
            for lineno, line in enumerate(fh, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as exc:
                    print(
                        f"[SFTDataset] WARNING: JSON parse error in "
                        f"{path}:{lineno} — {exc}"
                    )
                    continue

                # ---- Conversation format ------------------------------------
                # Support both "conversations" and "messages" keys.
                conv_list = obj.get("conversations") or obj.get("messages")
                if isinstance(conv_list, list) and conv_list:
                    turn_pairs = _build_conversation_turns(conv_list)
                    if not turn_pairs:
                        print(
                            f"[SFTDataset] WARNING: No valid user→assistant "
                            f"pairs in {path}:{lineno}, skipping."
                        )
                    pairs.extend(turn_pairs)

                # ---- Alpaca / Alpaca-no-input format -----------------------
                elif "instruction" in obj and "output" in obj:
                    prompt, response = _build_alpaca_turns(
                        instruction=obj["instruction"],
                        input_text=obj.get("input", ""),
                        output=obj["output"],
                    )
                    pairs.append((prompt, response))

                else:
                    print(
                        f"[SFTDataset] WARNING: Unrecognised schema at "
                        f"{path}:{lineno}, skipping."
                    )
        return pairs

    # ------------------------------------------------------------------
    # Dataset interface
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        """Return the number of valid samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return a single training sample.

        Args:
            idx: Sample index.

        Returns:
            Tuple ``(input_ids, labels)`` where both tensors have shape
            ``[seq_len]`` (variable per sample) and dtype ``torch.long``.
            Use a collate function to pad batches dynamically.

            - ``input_ids``: Full token sequence (prompt + response), NO
              padding (raw length).
            - ``labels``: Response token ids (pre-shifted one position
              left — see class docstring), ``-1`` everywhere else. Use
              ``ignore_index=-1`` in your loss function.
        """
        input_ids, labels = self.samples[idx]
        return input_ids.long(), labels.long()