# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict, List, Optional, Union

import numpy as np
import sentencepiece
import torch


class SentencePieceTokenizer:
    """
    SentencePieceTokenizer https://github.com/google/sentencepiece

    Args:
        model_path: path to the SentencePiece tokenizer model.
        special_tokens: either a list of special tokens or a dictionary mapping token names to token values.
        legacy: when set to True, the previous behavior of the SentencePiece wrapper is restored,
            including the possibility to add special tokens inside the wrapper.
        tokenizer: wraps an existing tokenizer.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        special_tokens: Optional[Union[Dict[str, str], List[str]]] = None,
        legacy: bool = False,
        tokenizer: Optional[sentencepiece.SentencePieceProcessor] = None,
    ):
        model_path_provided = model_path is not None
        tokenizer_provided = tokenizer is not None
        if not (model_path_provided ^ tokenizer_provided):
            raise ValueError("Exactly one of the arguments 'model_path', 'tokenizer' should be provided")

        if tokenizer_provided:
            self.tokenizer = tokenizer
        else:
            if not model_path or not os.path.exists(model_path):
                raise ValueError(f"model_path: {model_path} is invalid")
            self.tokenizer = sentencepiece.SentencePieceProcessor()
            self.tokenizer.Load(model_path)

        self.original_vocab_size = self.tokenizer.get_piece_size()
        self.vocab_size = self.tokenizer.get_piece_size()
        self.legacy = legacy
        self.special_token_to_id = {}
        self.id_to_special_token = {}
        if special_tokens:
            if not self.legacy:
                raise ValueError(
                    "Special tokens must be None when legacy is set to False. Provide special tokens at train time."
                )
            self.add_special_tokens(special_tokens)
        # Whitespace-sensitivity probe: True when tokenizing "x y" differs from
        # tokenizing "x" and "y" separately.
        self.space_sensitive = self.text_to_tokens('x y') != self.text_to_tokens('x') + self.text_to_tokens('y')

    def text_to_tokens(self, text):
        if self.legacy:
            tokens = []
            idx = 0

            while 1:
                # Find the earliest occurrence of any registered special token
                # in the remaining text.
                indices = {}

                for token in self.special_token_to_id:
                    try:
                        indices[token] = text[idx:].index(token)
                    except ValueError:
                        continue

                if len(indices) == 0:
                    break

                next_token = min(indices, key=indices.get)
                next_idx = idx + indices[next_token]

                # Tokenize the span before the special token, then append the
                # special token itself as a single piece.
                tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx]))
                tokens.append(next_token)
                idx = next_idx + len(next_token)

            tokens.extend(self.tokenizer.encode_as_pieces(text[idx:]))
            return tokens

        return self.tokenizer.encode_as_pieces(text)

    def encode(self, text):
        if self.legacy:
            ids = []
            idx = 0

            while 1:
                # Same scan as text_to_tokens, but emitting ids instead of pieces.
                indices = {}

                for token in self.special_token_to_id:
                    try:
                        indices[token] = text[idx:].index(token)
                    except ValueError:
                        continue

                if len(indices) == 0:
                    break

                next_token = min(indices, key=indices.get)
                next_idx = idx + indices[next_token]

                ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
                ids.append(self.special_token_to_id[next_token])
                idx = next_idx + len(next_token)

            ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
            return ids

        return self.tokenizer.encode_as_ids(text)

    def tokens_to_text(self, tokens):
        if isinstance(tokens, np.ndarray):
            tokens = tokens.tolist()

        return self.tokenizer.decode_pieces(tokens)

    def batch_decode(self, ids):
        if isinstance(ids, np.ndarray) or torch.is_tensor(ids):
            ids = ids.tolist()

        if self.legacy:
            text = ""
            last_i = 0

            for i, id in enumerate(ids):
                if id in self.id_to_special_token:
                    text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
                    text += self.id_to_special_token[id] + " "
                    last_i = i + 1

            text += self.tokenizer.decode_ids(ids[last_i:])
            return text.strip()

        return self.tokenizer.decode(ids)

    def token_to_id(self, token):
        if self.legacy and token in self.special_token_to_id:
            return self.special_token_to_id[token]
        return self.tokenizer.piece_to_id(token)

    def ids_to_tokens(self, ids):
        tokens = []
        for id in ids:
            if id >= self.original_vocab_size:
                tokens.append(self.id_to_special_token[id])
            else:
                tokens.append(self.tokenizer.id_to_piece(id))
        return tokens

    def tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        if isinstance(tokens, str):
            tokens = [tokens]
        ids = []
        for token in tokens:
            ids.append(self.token_to_id(token))
        return ids

    def add_special_tokens(self, special_tokens):
        if not self.legacy:
            raise AttributeError("Special Token addition does not work when legacy is set to False.")

        if isinstance(special_tokens, list):
            for token in special_tokens:
                # Only register tokens the underlying model does not already know.
                if (
                    self.tokenizer.piece_to_id(token) == self.tokenizer.unk_id()
                    and token not in self.special_token_to_id
                ):
                    self.special_token_to_id[token] = self.vocab_size
                    self.id_to_special_token[self.vocab_size] = token
                    self.vocab_size += 1
        elif isinstance(special_tokens, dict):
            for token_name, token in special_tokens.items():
                setattr(self, token_name, token)
                if (
                    self.tokenizer.piece_to_id(token) == self.tokenizer.unk_id()
                    and token not in self.special_token_to_id
                ):
                    self.special_token_to_id[token] = self.vocab_size
                    self.id_to_special_token[self.vocab_size] = token
                    self.vocab_size += 1

    @property
    def pad_id(self):
        if self.legacy:
            pad_id = self.tokens_to_ids([self.pad_token])[0]
        else:
            pad_id = self.tokenizer.pad_id()
        return pad_id

    @property
    def bos_token_id(self):
        if self.legacy:
            bos_id = self.tokens_to_ids([self.bos_token])[0]
        else:
            bos_id = self.tokenizer.bos_id()
        return bos_id

    @property
    def eos_token_id(self):
        if self.legacy:
            eos_id = self.tokens_to_ids([self.eos_token])[0]
        else:
            eos_id = self.tokenizer.eos_id()
        return eos_id

    @property
    def sep_id(self):
        if self.legacy:
            return self.tokens_to_ids([self.sep_token])[0]
        else:
            raise NameError("Use function token_to_id to retrieve special tokens other than unk, pad, bos, and eos.")

    @property
    def cls_id(self):
        if self.legacy:
            return self.tokens_to_ids([self.cls_token])[0]
        else:
            raise NameError("Use function token_to_id to retrieve special tokens other than unk, pad, bos, and eos.")

    @property
    def mask_id(self):
        if self.legacy:
            return self.tokens_to_ids([self.mask_token])[0]
        else:
            raise NameError("Use function token_to_id to retrieve special tokens other than unk, pad, bos, and eos.")

    @property
    def unk_id(self):
        return self.tokenizer.unk_id()

    @property
    def additional_special_tokens_ids(self):
        """Returns a list of the additional special tokens (excluding bos, eos, pad, unk). Used to return sentinel tokens for e.g. T5."""
        special_tokens = set(
            [self.bos_token, self.eos_token, self.pad_token, self.mask_token, self.cls_token, self.sep_token]
        )
        return [v for k, v in self.special_token_to_id.items() if k not in special_tokens]

    @property
    def vocab(self):
        main_vocab = [self.tokenizer.id_to_piece(id) for id in range(self.tokenizer.get_piece_size())]
        special_tokens = [
            self.id_to_special_token[self.original_vocab_size + i]
            for i in range(self.vocab_size - self.original_vocab_size)
        ]
        return main_vocab + special_tokens

    # Below are a few methods that mimic transformers.PreTrainedTokenizer for vLLM

    def convert_ids_to_tokens(self, ids, skip_special_tokens: bool = False):
        return self.ids_to_tokens(ids)  # TODO: support skip_special_tokens

    def convert_tokens_to_string(self, tokens: List[str]):
        return self.tokens_to_text(tokens)

    def __len__(self):
        return self.vocab_size

    @property
    def is_fast(self):
        return True

    def get_added_vocab(self):
        return None
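

# ---------------------------------------------------------------------------
# Minimal usage sketch. The path "tokenizer.model" and the special tokens
# "<sep>"/"<cls>" below are hypothetical placeholders, not values defined by
# this module; any trained SentencePiece model file can stand in for them.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Non-legacy mode: the wrapper is a thin pass-through to SentencePiece.
    tok = SentencePieceTokenizer(model_path="tokenizer.model")
    ids = tok.encode("Hello world")             # list of piece ids
    pieces = tok.text_to_tokens("Hello world")  # list of subword pieces
    print(pieces, ids, tok.batch_decode(ids))

    # Legacy mode: extra special tokens may be registered after training and
    # are appended to the end of the vocabulary (see add_special_tokens).
    legacy_tok = SentencePieceTokenizer(
        model_path="tokenizer.model",
        special_tokens={"sep_token": "<sep>", "cls_token": "<cls>"},
        legacy=True,
    )
    print(len(legacy_tok), legacy_tok.original_vocab_size)
    print(legacy_tok.encode("first part <sep> second part"))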