| """ |
| SFT (Supervised Fine-Tuning) dataset for the Korean LLM project. |
| |
| Reads JSONL files in three supported formats: |
| |
| 1. Alpaca format |
| {"instruction": "...", "input": "...", "output": "..."} |
| |
| 2. Alpaca format without optional input |
| {"instruction": "...", "output": "..."} |
| |
| 3. Conversation format |
| {"conversations": [{"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."}]} |
| |
| Chat template applied: |
| |
| <|user|> |
| {instruction or user turn} |
| <|assistant|> |
| {output or assistant turn}</s> |
| |
| Loss masking: ``labels`` is -1 for all prompt tokens so |
| ``nn.CrossEntropyLoss`` (ignore_index=-1) only trains on |
| the assistant responses. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import multiprocessing |
| import time |
| from pathlib import Path |
| from typing import Union |
|
|
| import torch |
| from torch.utils.data import Dataset |
| from tokenizers import Tokenizer |
|
|
|
|
| |
| |
| |
| _USER_TAG = "<|user|>\n" |
| _ASSISTANT_TAG = "<|assistant|>\n" |
| _EOS_STRING = "</s>" |
|
|
|
|
| def _build_alpaca_turns( |
| instruction: str, |
| input_text: str, |
| output: str, |
| ) -> tuple[str, str]: |
| """ |
| Convert an Alpaca-format sample into (prompt, response) strings. |
| |
| The *prompt* includes the user tag and instruction (+ optional input). |
| The *response* includes the assistant tag and output, plus EOS. |
| |
| Args: |
| instruction: The task instruction. |
| input_text: Optional additional input context. May be empty. |
| output: The expected assistant response. |
| |
| Returns: |
| Tuple of (prompt_text, response_text). |
| """ |
| user_body = instruction |
| if input_text and input_text.strip(): |
| user_body = f"{instruction}\n{input_text.strip()}" |
|
|
| prompt = f"{_USER_TAG}{user_body}\n{_ASSISTANT_TAG}" |
| response = f"{output}{_EOS_STRING}" |
| return prompt, response |
|
|
|
|
| def _build_conversation_turns( |
| conversations: list[dict], |
| ) -> list[tuple[str, str]]: |
| """ |
| Convert a conversation list into a sequence of (prompt, response) pairs. |
| |
| For a multi-turn conversation the prompt for turn *k* is the entire |
| dialogue history up to (but not including) assistant turn *k*. |
| |
| Only user→assistant pairs contribute training samples. Consecutive |
| user messages are merged. Conversations that start with an assistant |
| turn, or that have no assistant turn, are skipped (return empty list). |
| |
| Args: |
| conversations: List of dicts with ``role`` and ``content`` keys. |
| Roles are expected to be ``"user"`` or ``"assistant"``. |
| |
| Returns: |
| List of (prompt_text, response_text) tuples, one per assistant turn. |
| """ |
| pairs: list[tuple[str, str]] = [] |
| history = "" |
| pending_user = "" |
|
|
| for turn in conversations: |
| role = turn.get("role", "").lower() |
| content = turn.get("content", "") |
|
|
| if role == "user": |
| if pending_user: |
| |
| pending_user = f"{pending_user}\n{content}" |
| else: |
| pending_user = content |
|
|
| elif role == "assistant": |
| if not pending_user: |
| |
| continue |
| prompt = f"{history}{_USER_TAG}{pending_user}\n{_ASSISTANT_TAG}" |
| response = f"{content}{_EOS_STRING}" |
| pairs.append((prompt, response)) |
| |
| |
| history = f"{history}{_USER_TAG}{pending_user}\n{_ASSISTANT_TAG}{content}\n" |
| pending_user = "" |
|
|
| return pairs |
|
|
|
|
| |
| |
| |
# Per-process state for multiprocessing tokenization. Each pool worker
# fills these in via _worker_init; module globals persist for the
# lifetime of the worker process, so the tokenizer is loaded once per
# worker instead of being pickled per task.
_worker_tokenizer: Tokenizer | None = None
_worker_eos_id: int = -1
_worker_max_seq_len: int = 4096
|
|
|
|
def _worker_init(tokenizer_path: str, eos_string: str, max_seq_len: int) -> None:
    """
    Pool-worker initializer: load a private tokenizer and cache settings.

    Raises:
        ValueError: If ``eos_string`` is not in the tokenizer vocabulary.
    """
    global _worker_tokenizer, _worker_eos_id, _worker_max_seq_len
    tok = Tokenizer.from_file(tokenizer_path)
    resolved = tok.token_to_id(eos_string)
    if resolved is None:
        raise ValueError(f"EOS token '{eos_string}' not found in worker tokenizer.")
    _worker_tokenizer = tok
    _worker_eos_id = resolved
    _worker_max_seq_len = max_seq_len
|
|
|
|
def _worker_tokenize_batch(
    batch: list[tuple[str, str]],
) -> list[tuple[list[int], list[int]] | None]:
    """
    Tokenize a batch of (prompt, response) pairs in a pool worker.

    Uses the tokenizer loaded by ``_worker_init``. Each entry yields a
    (prompt_ids, response_ids) pair of plain Python lists, or ``None``
    when the prompt leaves no room for a response. Over-long samples are
    truncated from the end of the response and forced to end with EOS.
    """
    tok = _worker_tokenizer
    eos_id = _worker_eos_id
    limit = _worker_max_seq_len

    out: list[tuple[list[int], list[int]] | None] = []
    for prompt_text, response_text in batch:
        prompt_ids = tok.encode(prompt_text).ids
        response_ids = tok.encode(response_text).ids

        # Require at least a small budget (~10 tokens) for the response.
        if len(prompt_ids) >= limit - 10:
            out.append(None)
            continue

        if len(prompt_ids) + len(response_ids) > limit:
            budget = limit - len(prompt_ids)
            if budget <= 0:
                out.append(None)
                continue
            # Truncate the response tail but keep a terminal EOS so the
            # sample still ends the turn properly.
            response_ids = response_ids[:budget]
            if response_ids[-1] != eos_id:
                response_ids[-1] = eos_id

        out.append((prompt_ids, response_ids))

    return out
|
|
|
|
class SFTDataset(Dataset):
    """
    Supervised Fine-Tuning dataset built from JSONL files.

    Each JSONL line must conform to one of three schemas described in the
    module docstring. After tokenisation the sample is laid out as::

        [prompt tokens ...] [response tokens ...] [pad tokens ...]
        |<---- labels=-1 ---->| |<-- labels=token_id -->| |<- labels=-1 ->|

    The ``labels`` tensor uses -1 as the ignore value so that
    ``nn.CrossEntropyLoss(ignore_index=-1)`` only penalises the model on
    the assistant response tokens.

    Tokenized samples are cached next to the data at
    ``<data_path>.sft_cache.pt`` and reused when the data files,
    ``max_seq_len`` and the tokenizer vocab size are unchanged
    (see ``_make_cache_key``).

    Args:
        data_path: Path to a single ``.jsonl`` file or a directory that
            contains multiple ``.jsonl`` files (all are loaded).
        tokenizer: A ``tokenizers.Tokenizer`` instance (HuggingFace fast
            tokenizer loaded from ``tokenizer.json``).
        max_seq_len: Maximum sequence length (tokens). Samples exceeding
            this are truncated from the *end of the response*.
            Default: 4096.
        pad_token_id: Token id used for right-padding. Default: 0.
        tokenizer_path: Optional path to a ``tokenizer.json`` file. When
            given, tokenization runs in a multiprocessing pool — each
            worker loads its own tokenizer from this path. When ``None``,
            tokenization runs sequentially in this process using
            ``tokenizer``.
        num_workers: Pool size for parallel tokenization; only used when
            ``tokenizer_path`` is provided. Default: 60.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        tokenizer: Tokenizer,
        max_seq_len: int = 4096,
        pad_token_id: int = 0,
        tokenizer_path: Union[str, Path, None] = None,
        num_workers: int = 60,
    ) -> None:
        super().__init__()

        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id

        # Resolve the EOS id up front and fail fast if the tokenizer was
        # not trained with the expected special token.
        eos_id = tokenizer.token_to_id(_EOS_STRING)
        if eos_id is None:
            raise ValueError(
                f"EOS token '{_EOS_STRING}' not found in the tokenizer vocabulary. "
                "Check that the tokenizer was trained with this special token."
            )
        self.eos_token_id: int = eos_id

        # Parse the JSONL file(s) into (prompt, response) text pairs.
        data_path = Path(data_path)
        raw_samples = self._load_jsonl(data_path)

        # Try the on-disk cache first — a hit skips tokenization entirely.
        cache_path = Path(f"{data_path}.sft_cache.pt")
        cache_key = self._make_cache_key(data_path, max_seq_len, tokenizer)
        cached = self._try_load_cache(cache_path, cache_key)
        if cached is not None:
            self.samples = cached
            return

        t0 = time.time()

        # The parallel path needs a tokenizer *file* each worker can load
        # for itself (the Tokenizer object is not shared across processes).
        if tokenizer_path is not None:
            self.samples = self._tokenize_parallel(
                raw_samples, str(tokenizer_path), max_seq_len, num_workers,
            )
        else:
            self.samples = self._tokenize_sequential(
                raw_samples, tokenizer, max_seq_len,
            )

        elapsed = time.time() - t0
        print(f"[SFTDataset] Tokenization took {elapsed:.1f}s")

        self._save_cache(cache_path, cache_key)

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    @staticmethod
    def _make_cache_key(
        data_path: Path, max_seq_len: int, tokenizer: Tokenizer,
    ) -> tuple:
        """Build a cheap cache key from file metadata + settings.

        Uses (size, mtime) per data file rather than content hashing, so
        a touched-but-unchanged file still invalidates the cache. The
        vocab size is a cheap proxy for "same tokenizer"; two different
        tokenizers with equal vocab size would NOT invalidate the cache.
        """
        if data_path.is_file():
            stat = data_path.stat()
            file_sig = (stat.st_size, stat.st_mtime)
        else:
            # Directory input: sign every .jsonl file directly inside it
            # (mirrors the non-recursive discovery in _load_jsonl).
            parts = []
            for f in sorted(data_path.glob("*.jsonl")):
                s = f.stat()
                parts.append((str(f), s.st_size, s.st_mtime))
            file_sig = tuple(parts)
        return (file_sig, max_seq_len, tokenizer.get_vocab_size())

    def _try_load_cache(
        self, cache_path: Path, cache_key: tuple,
    ) -> list[tuple[torch.Tensor, torch.Tensor]] | None:
        """Load cached tokenized samples if cache is valid.

        Returns None on any failure (missing file, stale key, load
        error) so the caller falls back to re-tokenizing.
        """
        if not cache_path.exists():
            print(f"[SFTDataset] Cache miss — no cache file at {cache_path}")
            return None
        try:
            t0 = time.time()
            # weights_only=False is required to load the dict payload;
            # this is only safe because the cache file is written by this
            # class itself — never point it at an untrusted file.
            cache = torch.load(cache_path, map_location="cpu", weights_only=False)
            if cache.get("cache_key") != cache_key:
                print(f"[SFTDataset] Cache miss — stale cache (key mismatch)")
                return None
            samples = cache["samples"]
            elapsed = time.time() - t0
            print(
                f"[SFTDataset] Cache hit! Loaded {len(samples)} samples "
                f"from {cache_path} in {elapsed:.1f}s"
            )
            return samples
        except Exception as exc:
            # Broad catch is deliberate: any corrupt/unreadable cache
            # should degrade to a re-tokenize, not crash the run.
            print(f"[SFTDataset] Cache miss — failed to load: {exc}")
            return None

    def _save_cache(self, cache_path: Path, cache_key: tuple) -> None:
        """Save tokenized samples to cache file.

        Best-effort: a failed save only costs re-tokenization next run.
        """
        try:
            t0 = time.time()
            torch.save(
                {"cache_key": cache_key, "samples": self.samples},
                cache_path,
            )
            elapsed = time.time() - t0
            size_mb = cache_path.stat().st_size / (1024 * 1024)
            print(
                f"[SFTDataset] Saved cache ({size_mb:.0f} MB) "
                f"to {cache_path} in {elapsed:.1f}s"
            )
        except Exception as exc:
            print(f"[SFTDataset] WARNING: Failed to save cache: {exc}")

    # ------------------------------------------------------------------
    # Tokenization
    # ------------------------------------------------------------------

    def _tokenize_sequential(
        self,
        raw_samples: list[tuple[str, str]],
        tokenizer: Tokenizer,
        max_seq_len: int,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Original sequential tokenization (fallback when no tokenizer_path)."""
        samples: list[tuple[torch.Tensor, torch.Tensor]] = []
        total_loaded = 0
        total_tokens = 0
        skipped_too_long = 0
        truncated_count = 0

        for prompt_text, response_text in raw_samples:
            total_loaded += 1

            prompt_ids = tokenizer.encode(prompt_text).ids
            response_ids = tokenizer.encode(response_text).ids

            # Skip samples whose prompt leaves fewer than ~10 tokens of
            # room for a response within max_seq_len.
            if len(prompt_ids) >= max_seq_len - 10:
                skipped_too_long += 1
                continue

            full_ids = prompt_ids + response_ids

            if len(full_ids) > max_seq_len:
                truncated_count += 1
                allowed_response = max_seq_len - len(prompt_ids)
                if allowed_response <= 0:
                    skipped_too_long += 1
                    continue
                # Truncate the response tail but keep a terminal EOS so
                # the sample still ends the turn properly.
                response_ids = response_ids[:allowed_response]
                if response_ids[-1] != self.eos_token_id:
                    response_ids[-1] = self.eos_token_id
                full_ids = prompt_ids + response_ids

            seq_len = len(full_ids)
            total_tokens += seq_len

            input_ids = torch.tensor(full_ids, dtype=torch.int32)
            labels = torch.full((seq_len,), fill_value=-1, dtype=torch.int32)
            resp_start = len(prompt_ids)
            # Labels are written one position LEFT of the tokens they
            # supervise: labels[t] holds input_ids[t + 1] over the
            # response span (the final token, EOS, has no target).
            # NOTE(review): this pre-shifted layout assumes the training
            # loop pairs logits[t] with labels[t] WITHOUT shifting again —
            # confirm against the loss computation.
            resp_label_start = max(0, resp_start - 1)
            resp_label_end = resp_label_start + len(response_ids)
            labels[resp_label_start:resp_label_end] = torch.tensor(
                response_ids, dtype=torch.int32
            )
            samples.append((input_ids, labels))

        n = len(samples)
        avg_len = (total_tokens / n) if n > 0 else 0.0
        print(
            f"[SFTDataset] Loaded {n} samples "
            f"(raw={total_loaded}, "
            f"skipped_too_long={skipped_too_long}, "
            f"truncated={truncated_count})"
        )
        print(
            f"[SFTDataset] avg_seq_len={avg_len:.1f}, "
            f"max_seq_len={max_seq_len}, "
            f"pad_token_id={self.pad_token_id}, "
            f"eos_token_id={self.eos_token_id}"
        )
        return samples

    def _tokenize_parallel(
        self,
        raw_samples: list[tuple[str, str]],
        tokenizer_path: str,
        max_seq_len: int,
        num_workers: int,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Parallel tokenization using multiprocessing.Pool.

        Workers return plain-int token id lists (see
        ``_worker_tokenize_batch``); tensor construction and label
        masking happen here in the parent process.
        """
        total = len(raw_samples)
        print(
            f"[SFTDataset] Starting parallel tokenization: "
            f"{total} samples, {num_workers} workers"
        )

        # Split work into fixed-size chunks so each pool task carries a
        # meaningful amount of work relative to IPC overhead.
        chunk_size = 1000
        chunks = []
        for i in range(0, total, chunk_size):
            chunks.append(raw_samples[i : i + chunk_size])

        all_token_pairs: list[tuple[list[int], list[int]] | None] = []
        processed = 0

        # "spawn" starts clean worker processes; each loads its own
        # tokenizer in _worker_init (presumably chosen over "fork" to
        # avoid inheriting torch/tokenizer state — TODO confirm).
        ctx = multiprocessing.get_context("spawn")
        with ctx.Pool(
            processes=num_workers,
            initializer=_worker_init,
            initargs=(tokenizer_path, _EOS_STRING, max_seq_len),
        ) as pool:
            # NOTE: imap_unordered yields chunks in completion order, so
            # the final sample order is nondeterministic across runs and
            # differs from the sequential path's file order.
            for batch_results in pool.imap_unordered(
                _worker_tokenize_batch, chunks, chunksize=1,
            ):
                all_token_pairs.extend(batch_results)
                processed += len(batch_results)
                # Progress roughly every 100k samples: true whenever the
                # running count just crossed a 100k boundary.
                if processed % 100_000 < chunk_size:
                    print(
                        f"[SFTDataset] Tokenized {processed}/{total} "
                        f"({100.0 * processed / total:.1f}%)"
                    )

        # Final progress line, unless the loop already printed at 100%.
        if processed % 100_000 >= chunk_size:
            print(f"[SFTDataset] Tokenized {processed}/{total} (100.0%)")

        samples: list[tuple[torch.Tensor, torch.Tensor]] = []
        total_tokens = 0
        skipped_too_long = 0
        truncated_count = 0

        for pair in all_token_pairs:
            if pair is None:
                # Worker marked the sample as unusable (prompt too long).
                skipped_too_long += 1
                continue

            prompt_ids, response_ids = pair
            full_ids = prompt_ids + response_ids

            # Workers don't report whether they truncated, so infer it:
            # a sample at exactly max_seq_len is counted as truncated.
            # This may overcount samples that naturally hit that length.
            if len(full_ids) == max_seq_len:
                truncated_count += 1

            seq_len = len(full_ids)
            total_tokens += seq_len

            input_ids = torch.tensor(full_ids, dtype=torch.int32)
            labels = torch.full((seq_len,), fill_value=-1, dtype=torch.int32)
            resp_start = len(prompt_ids)
            # Same pre-shifted label layout as _tokenize_sequential:
            # labels[t] holds input_ids[t + 1] over the response span.
            # NOTE(review): assumes the loss does NOT shift again — confirm.
            resp_label_start = max(0, resp_start - 1)
            resp_label_end = resp_label_start + len(response_ids)
            labels[resp_label_start:resp_label_end] = torch.tensor(
                response_ids, dtype=torch.int32
            )
            samples.append((input_ids, labels))

        n = len(samples)
        avg_len = (total_tokens / n) if n > 0 else 0.0
        print(
            f"[SFTDataset] Loaded {n} samples "
            f"(raw={total}, "
            f"skipped_too_long={skipped_too_long}, "
            f"truncated={truncated_count})"
        )
        print(
            f"[SFTDataset] avg_seq_len={avg_len:.1f}, "
            f"max_seq_len={max_seq_len}, "
            f"pad_token_id={self.pad_token_id}, "
            f"eos_token_id={self.eos_token_id}"
        )
        return samples

    # ------------------------------------------------------------------
    # JSONL parsing
    # ------------------------------------------------------------------

    def _load_jsonl(self, path: Path) -> list[tuple[str, str]]:
        """
        Discover and parse JSONL files, returning (prompt, response) pairs.

        If ``path`` is a file, load that file only. If it is a directory,
        load all ``*.jsonl`` files found directly inside it (non-recursive).

        Args:
            path: File or directory path.

        Returns:
            List of (prompt_text, response_text) tuples.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If no ``.jsonl`` files are found under a
                directory path.
        """
        if not path.exists():
            raise FileNotFoundError(f"Data path not found: {path}")

        if path.is_dir():
            jsonl_files = sorted(path.glob("*.jsonl"))
            if not jsonl_files:
                raise ValueError(f"No .jsonl files found in directory: {path}")
        else:
            jsonl_files = [path]

        pairs: list[tuple[str, str]] = []
        for jsonl_file in jsonl_files:
            pairs.extend(self._parse_jsonl_file(jsonl_file))

        return pairs

    def _parse_jsonl_file(self, path: Path) -> list[tuple[str, str]]:
        """
        Parse a single JSONL file into (prompt, response) pairs.

        Lines that are empty, whitespace-only, or fail JSON parsing are
        silently skipped with a warning. Lines whose schema cannot be
        recognised are also skipped.

        Args:
            path: Path to a ``.jsonl`` file.

        Returns:
            List of (prompt_text, response_text) tuples extracted from
            the file.
        """
        pairs: list[tuple[str, str]] = []

        with path.open("r", encoding="utf-8") as fh:
            for lineno, line in enumerate(fh, start=1):
                line = line.strip()
                if not line:
                    continue

                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as exc:
                    print(
                        f"[SFTDataset] WARNING: JSON parse error in "
                        f"{path}:{lineno} — {exc}"
                    )
                    continue

                # Conversation schema first (accepts either key name);
                # a multi-turn line can yield several training pairs.
                conv_list = obj.get("conversations") or obj.get("messages")
                if conv_list and isinstance(conv_list, list):
                    turn_pairs = _build_conversation_turns(conv_list)
                    if not turn_pairs:
                        print(
                            f"[SFTDataset] WARNING: No valid user→assistant "
                            f"pairs in {path}:{lineno}, skipping."
                        )
                    pairs.extend(turn_pairs)

                # Alpaca schema (with or without the optional "input").
                elif "instruction" in obj and "output" in obj:
                    prompt, response = _build_alpaca_turns(
                        instruction=obj["instruction"],
                        input_text=obj.get("input", ""),
                        output=obj["output"],
                    )
                    pairs.append((prompt, response))

                else:
                    print(
                        f"[SFTDataset] WARNING: Unrecognised schema at "
                        f"{path}:{lineno}, skipping."
                    )

        return pairs

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        """Return the number of valid samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return a single training sample.

        Args:
            idx: Sample index.

        Returns:
            Tuple ``(input_ids, labels)`` where both tensors have shape
            ``[seq_len]`` (variable per sample) and dtype ``torch.long``.
            Use a collate function to pad batches dynamically.

            - ``input_ids``: Full token sequence (prompt + response),
              NO padding (raw length).
            - ``labels``: Response token ids at response positions,
              ``-1`` everywhere else (prompt tokens).
              Use ``ignore_index=-1`` in your loss function.
        """
        # Stored as int32 to halve cache size; widen to int64 here since
        # embedding/loss ops expect torch.long indices.
        input_ids, labels = self.samples[idx]
        return input_ids.long(), labels.long()
|
|