Instructions to use Elvis-t9/CGE-test with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Elvis-t9/CGE-test with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Elvis-t9/CGE-test", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Elvis-t9/CGE-test", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import gc | |
| import inspect | |
| import math | |
| import multiprocessing as mp | |
| import queue | |
| from multiprocessing import Queue | |
| import warnings | |
| from typing import Any, Union, List, Dict, Literal, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from torch import nn | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| from transformers import PretrainedConfig | |
| from transformers import Qwen2Config | |
| from transformers.activations import ACT2FN | |
| from transformers.cache_utils import Cache, DynamicCache | |
| from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa | |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.utils import ( | |
| add_start_docstrings, | |
| add_start_docstrings_to_model_forward, | |
| is_flash_attn_2_available, | |
| is_flash_attn_greater_or_equal_2_10, | |
| logging, | |
| replace_return_docstrings, | |
| ) | |
| import numpy as np | |
| from transformers import Qwen2Config | |
| from transformers import Qwen2ForCausalLM | |
| import inspect | |
| import math | |
| import os | |
| import warnings | |
| from typing import List, Optional, Tuple, Union | |
| from tqdm import tqdm, trange | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from torch import nn | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| from transformers.activations import ACT2FN | |
| from transformers.cache_utils import Cache, DynamicCache | |
| from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa | |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.utils import ( | |
| add_start_docstrings, | |
| add_start_docstrings_to_model_forward, | |
| is_flash_attn_2_available, | |
| is_flash_attn_greater_or_equal_2_10, | |
| logging, | |
| replace_return_docstrings, | |
| ) | |
| import numpy as np | |
| import torch | |
| import os | |
| import argparse | |
| import json | |
| from tqdm import tqdm | |
| from typing import cast, List, Union, Tuple | |
| from transformers import AutoTokenizer, AutoModel # pylint: disable=C0413 | |
| from peft import LoraConfig, get_peft_model, TaskType | |
| import time | |
| import torch.nn.functional as F | |
| import sys | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from tqdm import tqdm, trange | |
| from collections import defaultdict | |
| from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig | |
| import torch.distributed as dist | |
| from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import re | |
| import logging | |
| logging.getLogger().setLevel(logging.INFO) | |
| from .configuration_cge2 import CGEConfig | |
| from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2Attention | |
class MAB_POST(nn.Module):
    """Multi-head attention block with post-layer-norm residual connections.

    Queries are projected from ``Q``; keys and values are both projected from
    ``K``.  The attention output is added to the *unprojected* ``Q``, so the
    last dimension of ``Q`` has to be broadcast-compatible with ``dim_V``
    (in practice ``dim_Q == dim_V``).

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key/value input.
        dim_V: Output feature size; split evenly across the heads.
        num_heads: Number of attention heads; must divide ``dim_V``.
        ln: When true, a LayerNorm follows each of the two residual steps.

    Raises:
        ValueError: If ``dim_V`` is not a positive multiple of ``num_heads``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        # A non-divisible width makes the per-head split in ``forward``
        # produce unequal chunks that cannot be concatenated; fail early.
        if num_heads <= 0 or dim_V % num_heads != 0:
            raise ValueError(
                f"dim_V ({dim_V}) must be a positive multiple of num_heads ({num_heads})"
            )
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        for layer in (self.fc_q, self.fc_k, self.fc_v, self.fc_o):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, Q, K, pad_mask=None):
        """Attend from ``Q`` over ``K``.

        Args:
            Q: Query tensor indexed as (batch, num_queries, dim_Q).
            K: Key/value tensor indexed as (batch, seq_len, dim_K).
            pad_mask: Optional (batch, seq_len) mask; positions equal to 0 are
                excluded from attention.  ``None`` attends to every position.

        Returns:
            Tensor of shape (batch, num_queries, dim_V).
        """
        q_proj = self.fc_q(Q)
        k_proj, v_proj = self.fc_k(K), self.fc_v(K)

        # Fold the heads into the batch axis: (num_heads * batch, n, head_dim).
        head_dim = self.dim_V // self.num_heads
        q_heads = torch.cat(q_proj.split(head_dim, 2), 0)
        k_heads = torch.cat(k_proj.split(head_dim, 2), 0)
        v_heads = torch.cat(v_proj.split(head_dim, 2), 0)

        # NOTE: the scores are scaled by sqrt(dim_V) (the full width), not by
        # the per-head width.  Kept as is: changing it would alter the output
        # of already-trained weights.
        score = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)

        mask = None
        if pad_mask is not None:
            # Tile the mask in the same head-major order used by the cat above.
            mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
            # Use the dtype's own minimum: a literal -1e12 is not representable
            # in float16 and makes masked_fill raise.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        A = torch.softmax(score, 2)
        if mask is not None:
            # Zero the padded columns explicitly so a fully padded row yields
            # zeros rather than a uniform average.  The cast keeps A's dtype
            # when the mask is floating point.
            A = A * mask.to(A.dtype)

        # Unfold the heads back into the feature axis.
        O = torch.cat(A.bmm(v_heads).split(Q.size(0), 0), 2)
        O = Q + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O
class PMA(nn.Module):
    """Pooling by multi-head attention with learned seed vectors.

    ``num_seeds`` learned vectors of width ``compressed_dim`` act as queries
    that attend over the input sequence, producing ``num_seeds`` pooled
    vectors per batch element.

    Args:
        dim: Feature size of the input sequence.
        compressed_dim: Width of the seed vectors and of the pooled output.
        num_heads: Number of attention heads in the attention block.
        num_seeds: Number of learned seed (query) vectors.
        ln: Forwarded to the attention block to enable LayerNorm.
        pma_mode: Attention-block variant; only ``'post_normal'`` is
            implemented.

    Raises:
        ValueError: If ``pma_mode`` is not ``'post_normal'``.
    """

    def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False, pma_mode=None):
        super().__init__()
        # Reject unsupported modes before allocating any parameter.
        if pma_mode != 'post_normal':
            raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, compressed_dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB_POST(compressed_dim, dim, compressed_dim, num_heads, ln=ln)

    def forward(self, X, pad_mask):
        """Pool ``X`` into ``num_seeds`` vectors per batch element.

        Args:
            X: Input tensor indexed as (batch, seq_len, dim).
            pad_mask: (batch, seq_len) mask; zeros mark positions to ignore.

        Returns:
            Tensor of shape (batch, num_seeds, compressed_dim).
        """
        # Match the input to the parameter dtype.  The previous
        # "cast to float32 unless the seed is bfloat16" rule broke float16
        # modules and bfloat16 modules fed with non-bfloat16 input.
        X = X.to(self.S.dtype)
        # expand() broadcasts the seed over the batch without copying it.
        return self.mab(self.S.expand(X.size(0), -1, -1), X, pad_mask)
class MAB_POST_v2(nn.Module):
    """Multi-head attention block (post-LN) with a projected-query residual.

    The residual connection adds the attention output to the *projected*
    query ``fc_q(Q)``, so ``dim_Q`` may differ from ``dim_V``.

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key/value input.
        dim_V: Output feature size; split evenly across the heads.
        num_heads: Number of attention heads; must divide ``dim_V``.
        ln: When true, a LayerNorm follows each of the two residual steps.

    Raises:
        ValueError: If ``dim_V`` is not a positive multiple of ``num_heads``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        # A non-divisible width makes the per-head split in ``forward``
        # produce unequal chunks that cannot be concatenated; fail early.
        if num_heads <= 0 or dim_V % num_heads != 0:
            raise ValueError(
                f"dim_V ({dim_V}) must be a positive multiple of num_heads ({num_heads})"
            )
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        for layer in (self.fc_q, self.fc_k, self.fc_v, self.fc_o):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, Q, K, pad_mask=None):
        """Attend from ``Q`` over ``K``.

        Args:
            Q: Query tensor indexed as (batch, num_seeds, dim_Q).
            K: Key/value tensor indexed as (batch, seq_len, dim_K).
            pad_mask: Optional (batch, seq_len) mask; positions equal to 0 are
                excluded from attention.  ``None`` attends to every position.

        Returns:
            Tensor of shape (batch, num_seeds, dim_V).
        """
        q_proj = self.fc_q(Q)                        # (B, num_seeds, dim_V)
        k_proj, v_proj = self.fc_k(K), self.fc_v(K)  # (B, L, dim_V)

        # Fold the heads into the batch axis: (num_heads * B, n, head_dim).
        head_dim = self.dim_V // self.num_heads
        q_heads = torch.cat(q_proj.split(head_dim, 2), 0)
        k_heads = torch.cat(k_proj.split(head_dim, 2), 0)
        v_heads = torch.cat(v_proj.split(head_dim, 2), 0)

        # NOTE: the scores are scaled by sqrt(dim_V) (the full width), not by
        # the per-head width.  Kept as is: changing it would alter the output
        # of already-trained weights.
        score = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)

        mask = None
        if pad_mask is not None:
            # (B, L) -> (num_heads * B, num_seeds, L), tiled in the same
            # head-major order used by the cat above.
            mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
            # Use the dtype's own minimum: a literal -1e12 is not representable
            # in float16 and makes masked_fill raise.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        A = torch.softmax(score, 2)
        if mask is not None:
            # Zero the padded columns explicitly so a fully padded row yields
            # zeros rather than a uniform average.  The cast keeps A's dtype
            # when the mask is floating point.
            A = A * mask.to(A.dtype)

        # Unfold the heads back into the feature axis: (B, num_seeds, dim_V).
        O = torch.cat(A.bmm(v_heads).split(Q.size(0), 0), 2)
        O = q_proj + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O
class PMA_v2(nn.Module):
    """Pooling by multi-head attention whose seeds live in the input space.

    The learned seed vectors have width ``dim`` (the input feature size) and
    are projected down to ``compressed_dim`` inside the attention block, which
    is always the post-LN ``MAB_POST_v2`` variant.

    Args:
        dim: Feature size of the input sequence and of the seed vectors.
        compressed_dim: Width of the pooled output.
        num_heads: Number of attention heads in the attention block.
        num_seeds: Number of learned seed (query) vectors.
        ln: Forwarded to the attention block to enable LayerNorm.
    """

    def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False):
        super().__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB_POST_v2(dim, dim, compressed_dim, num_heads, ln=ln)

    def forward(self, X, pad_mask):
        """Pool ``X`` into ``num_seeds`` vectors per batch element.

        Args:
            X: Input tensor indexed as (batch, seq_len, dim).
            pad_mask: (batch, seq_len) mask; zeros mark positions to ignore.

        Returns:
            Tensor of shape (batch, num_seeds, compressed_dim).
        """
        # Match the input to the parameter dtype.  The previous
        # "cast to float32 unless the seed is bfloat16" rule broke float16
        # modules and bfloat16 modules fed with non-bfloat16 input.
        X = X.to(self.S.dtype)
        # expand() broadcasts the seed over the batch without copying it.
        return self.mab(self.S.expand(X.size(0), -1, -1), X, pad_mask)
class CGEModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the CGE model classes.

    Carries no layers of its own: it binds ``CGEConfig`` as the configuration
    class and declares the class-level flags read by the transformers
    loading and attention machinery.
    """

    # Configuration binding.
    config: CGEConfig
    config_class = CGEConfig
    base_model_prefix = "model"

    # Placement hints: decoder layers are listed as non-splittable and the
    # KV cache key is excluded from device placement.
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Declared capabilities.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # Module classes whose outputs are recorded under each output name.
    _can_record_outputs = {
        "hidden_states": Qwen2DecoderLayer,
        "attentions": Qwen2Attention,
    }
class CgeForEmbedding(CGEModel):
    """Text-embedding model: a Qwen2 causal-LM backbone pooled by a PMA head.

    The backbone hidden states at index ``keep_max_layer`` are pooled by a
    single-seed ``PMA_v2`` into one vector per input sequence, optionally
    L2-normalised (``config.pma_norm``).  ``encode`` / ``encode_queries`` /
    ``encode_corpus`` provide a text-in, embedding-out interface and fan out
    over worker processes when more than one target device is detected.
    """

    config_class = CGEConfig
    model_type = "cge2"

    def __init__(self, config):
        """Build the backbone, the pooling head and (optionally) the tokenizer.

        Args:
            config: A ``CGEConfig``.  Read here: ``embedding_method``,
                ``padding_side``, ``pma_num_heads``, ``pma_ln``, ``pma_norm``,
                ``pma_norm_mode``, ``compressed_dim`` and
                ``tokenizer_name_or_path`` (``None`` skips tokenizer loading).
        """
        super().__init__(config)
        # Set first so cleanup (`stop_self_pool` / `__del__`) stays safe even
        # when a later construction step raises.
        self.pool = None

        qwen_cfg = Qwen2Config.from_dict(config.to_dict())
        self.plm_model = AutoModelForCausalLM.from_config(qwen_cfg)
        self.embedding_method = config.embedding_method
        self.inf_seq_length = 1024
        self.padding_side = config.padding_side
        self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
        self.keep_max_layer = self.plm_model.config.num_hidden_layers

        self.num_heads = config.pma_num_heads
        self.ln = config.pma_ln
        self.norm = config.pma_norm
        self.pma_mode = config.pma_norm_mode
        self.compressed_dim = config.compressed_dim
        self.mha_pma_disc = PMA_v2(self.emb_dim, self.compressed_dim, self.num_heads, 1, ln=self.ln)

        self.target_devices = self.get_target_devices(None)
        if config.tokenizer_name_or_path is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.tokenizer_name_or_path, padding_side=config.padding_side
            )
        else:
            self.tokenizer = None
        self.config_class = CGEConfig

    def pma_embedding(self, mha_pma, A, mask):
        """Pool ``A`` with ``mha_pma`` and drop the singleton seed axis (dim 1)."""
        return mha_pma(A, mask).squeeze(1)

    def get_hidden_states(self, **inputs):
        """Run the backbone and return the hidden states at ``keep_max_layer``.

        Args:
            **inputs: Must contain ``input_ids`` and ``attention_mask``; other
                keys are ignored.
        """
        outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True)
        return outputs.hidden_states[self.keep_max_layer]

    def get_sentence_embedding(self, embedding_method, hidden_states, emb_type, attention_mask):
        """Turn token-level hidden states into one embedding per sequence.

        Args:
            embedding_method: Only ``'pma'`` is implemented.
            hidden_states: Token-level states passed to the PMA head.
            emb_type: Only ``'disc'`` is implemented.
            attention_mask: Padding mask passed to the PMA head.

        Returns:
            The pooled embeddings, L2-normalised when ``self.norm`` is truthy.

        Raises:
            NotImplementedError: For any other method / type combination.
        """
        if embedding_method != 'pma':
            raise NotImplementedError(f"embedding method {embedding_method} hasn't been implemented")
        if emb_type != 'disc':
            raise NotImplementedError(f"emb type {emb_type} hasn't been implemented")

        res_embedding = self.pma_embedding(self.mha_pma_disc, hidden_states, attention_mask)
        if self.norm:
            res_embedding = F.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12)
        return res_embedding

    @staticmethod
    def get_target_devices(devices: Union[str, int, List[str], List[int], None]) -> List[str]:
        """Normalise a device specification into a list of device strings.

        Declared as a static method: the function takes no ``self`` yet is
        called as ``self.get_target_devices(None)`` from ``__init__``.

        Args:
            devices: ``None`` (auto-detect: CUDA, then NPU, MUSA, MPS, CPU),
                a device string, a device index, or a list of either.

        Raises:
            ValueError: Devices should be a string or an integer or a list of
                strings or a list of integers (an empty list is rejected too).

        Returns:
            List[str]: The target devices, e.g. ``["cuda:0", "cuda:1"]``.
        """
        invalid = "devices should be a string or an integer or a list of strings or a list of integers."

        def musa_available() -> bool:
            return hasattr(torch, "musa") and torch.musa.is_available()

        def npu_available() -> bool:
            # Imported here because the module header does not import it;
            # treated as "no NPU" when this transformers version lacks it.
            try:
                from transformers.utils import is_torch_npu_available
            except ImportError:
                return False
            return bool(is_torch_npu_available())

        if devices is None:
            if torch.cuda.is_available():
                return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            if npu_available():
                return [f"npu:{i}" for i in range(torch.npu.device_count())]
            if musa_available():
                return [f"musa:{i}" for i in range(torch.musa.device_count())]
            if torch.backends.mps.is_available():
                try:
                    return [f"mps:{i}" for i in range(torch.mps.device_count())]
                except Exception:
                    # Best effort: fall back to the bare device name when
                    # the device count cannot be queried.
                    return ["mps"]
            return ["cpu"]

        if isinstance(devices, str):
            return [devices]

        # Integer indices map to MUSA when available, CUDA otherwise.
        if isinstance(devices, int):
            prefix = "musa" if musa_available() else "cuda"
            return [f"{prefix}:{devices}"]

        if isinstance(devices, list) and devices:
            # As before, only the first element decides how the list is read.
            if isinstance(devices[0], str):
                return devices
            if isinstance(devices[0], int):
                prefix = "musa" if musa_available() else "cuda"
                return [f"{prefix}:{device}" for device in devices]

        raise ValueError(invalid)

    # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
    def start_multi_process_pool(
        self,
        process_target_func: Any,
    ) -> Dict[Literal["input", "output", "processes"], Any]:
        """Start one daemon worker process per entry of ``self.target_devices``.

        The model is moved to CPU and put into shared memory before the
        workers are spawned; every worker receives
        ``(device, self, input_queue, output_queue)``.  Works together with
        ``encode_multi_process`` and ``stop_multi_process_pool``.

        Args:
            process_target_func: Callable executed in every worker process.

        Raises:
            ValueError: If the backbone or the pooling head is missing.

        Returns:
            Dict[str, Any]: The worker processes, an input queue and an
            output queue under ``"processes"``, ``"input"``, ``"output"``.
        """
        if self.plm_model is None or self.mha_pma_disc is None:
            raise ValueError("Model is not initialized.")

        logging.info("Start multi-process pool on devices: %s", ", ".join(map(str, self.target_devices)))

        self.to("cpu")
        self.share_memory()
        ctx = mp.get_context("spawn")
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()

        processes = []
        for device_id in tqdm(self.target_devices, desc='initial target device'):
            worker = ctx.Process(
                target=process_target_func,
                args=(device_id, self, input_queue, output_queue),
                daemon=True,
            )
            worker.start()
            processes.append(worker)

        return {"input": input_queue, "output": output_queue, "processes": processes}

    @staticmethod
    def _encode_multi_process_worker(
        target_device: str, model: 'CgeForEmbedding', input_queue: Queue, results_queue: Queue
    ) -> None:
        """Worker loop: encode ``(chunk_id, sentences, kwargs)`` jobs forever.

        ``input_queue.get()`` blocks without a timeout, so ``queue.Empty`` is
        not expected here; the daemon worker ends when it is terminated by
        ``stop_multi_process_pool`` or when the parent process exits.
        """
        model = model.to(target_device)
        while True:
            try:
                chunk_id, sentences, kwargs = input_queue.get()
                embeddings = model.encode_single_device(
                    sentences,
                    device=target_device,
                    **kwargs
                )
                results_queue.put([chunk_id, embeddings])
            except queue.Empty:
                break

    def encode_multi_process(
        self,
        sentences: List[str],
        pool: Dict[Literal["input", "output", "processes"], Any],
        **kwargs
    ):
        """Split ``sentences`` evenly over the pool and gather the results.

        Args:
            sentences: Texts to encode.
            pool: Pool returned by ``start_multi_process_pool``.
            **kwargs: Forwarded to ``encode_single_device`` in each worker.

        Returns:
            The concatenated embeddings, in the order of ``sentences``.
        """
        # max(1, ...) keeps the slicing step valid for an empty input.
        chunk_size = max(1, math.ceil(len(sentences) / len(pool["processes"])))

        input_queue = pool["input"]
        num_chunks = 0
        for start in range(0, len(sentences), chunk_size):
            input_queue.put([num_chunks, sentences[start:start + chunk_size], kwargs])
            num_chunks += 1

        # Workers finish in arbitrary order; the chunk id restores it.
        output_queue = pool["output"]
        results_list = sorted(
            [output_queue.get() for _ in trange(num_chunks, desc="")],
            key=lambda item: item[0],
        )
        return self._concatenate_results_from_multi_process([result[1] for result in results_list])

    def _concatenate_results_from_multi_process(self, results_list: List[Union[torch.Tensor, np.ndarray, Any]]):
        """Concatenate the per-worker results along the first axis.

        Args:
            results_list (List[Union[torch.Tensor, np.ndarray, Any]]): One
                result per chunk, already in chunk order.

        Raises:
            ValueError: ``results_list`` is empty.
            NotImplementedError: Unsupported type for results_list.

        Returns:
            Union[torch.Tensor, np.ndarray]: The stacked embeddings; tensors
            are first moved to ``self.target_devices[0]``.
        """
        if not results_list:
            raise ValueError("results_list is empty; there is nothing to concatenate.")

        first = results_list[0]
        if isinstance(first, torch.Tensor):
            # Results come from different devices; gather them on one.
            gathered = [res.to(self.target_devices[0]) for res in results_list]
            return torch.cat(gathered, dim=0)
        if isinstance(first, np.ndarray):
            return np.concatenate(results_list, axis=0)
        raise NotImplementedError("Unsupported type for results_list")

    def encode_single_device(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 16,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = True,
        show_progress_bar: bool = True,
        max_seq_length: int = 1024,
        device: Optional[str] = None,
        **kwargs: Any
    ):
        """Encode ``sentences`` in the current process.

        Args:
            sentences: One text or a list of texts.
            batch_size: Number of texts per forward pass.
            convert_to_numpy: Move the embeddings to CPU (bfloat16 is widened
                to float32); the result is a NumPy array only when
                ``convert_to_tensor`` is false.
            convert_to_tensor: Stack the result into one tensor; takes
                precedence over ``convert_to_numpy``.
            show_progress_bar: Show a progress bar over the batches.
            max_seq_length: Truncation length; ``None`` uses
                ``self.inf_seq_length``.
            device: Accepted for interface symmetry with the worker call;
                tokenised inputs always follow the backbone's own device.
            **kwargs: Ignored.

        Returns:
            A tensor, a NumPy array or a list of tensors with one embedding
            per text, in input order; a single embedding for a single text.

        Raises:
            ValueError: If no tokenizer is available.
        """
        if self.tokenizer is None:
            raise ValueError(
                "Tokenizer is not initialized; set `config.tokenizer_name_or_path` or assign `self.tokenizer`."
            )
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        input_is_string = isinstance(sentences, str) or not hasattr(sentences, "__len__")
        if input_is_string:
            sentences = [sentences]

        # Nothing to encode: return an empty container of the requested kind
        # (torch.stack would raise on an empty list).
        if len(sentences) == 0:
            if convert_to_tensor:
                return torch.empty((0, self.compressed_dim))
            if convert_to_numpy:
                return np.empty((0, self.compressed_dim), dtype=np.float32)
            return []

        # Longest first, so each batch holds texts of similar length.
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        collected = []
        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                inputs = self.tokenizer(
                    sentences_batch,
                    padding=True,
                    truncation=True,
                    max_length=max_seq_length,
                    return_tensors='pt',
                ).to(self.plm_model.device)
                hidden_states = self.get_hidden_states(**inputs)
                embeddings = self.get_sentence_embedding(
                    self.embedding_method, hidden_states, 'disc', inputs['attention_mask']
                ).detach()
                if convert_to_numpy:
                    embeddings = embeddings.cpu()
                    # NumPy has no bfloat16.
                    if embeddings.dtype == torch.bfloat16:
                        embeddings = embeddings.to(torch.float32)
                collected.extend(embeddings)

        # Undo the length sort.
        all_embeddings = [collected[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_is_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, return_dict: bool=True, **kwargs):
        """Embed an already tokenised batch.

        Returns:
            ``{"sentence_embedding": embeddings}``, or ``(embeddings,)`` when
            ``return_dict`` is false.
        """
        hidden_states = self.get_hidden_states(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = self.get_sentence_embedding(self.embedding_method, hidden_states, 'disc', attention_mask)
        if not return_dict:
            return (embeddings,)
        return {"sentence_embedding": embeddings}

    def encode(self, sentences, batch_size=16, convert_to_numpy=False,
               convert_to_tensor=True, show_progress_bar=True, max_seq_length=None, **kwargs):
        """Encode text on one device or across the multi-process pool.

        A single string, an empty list, or a single target device takes the
        in-process path; otherwise the pool is started lazily and every batch
        is spread over the workers.

        Args:
            sentences: One text or a list of texts.
            batch_size: Texts per batch (per dispatch round in pool mode).
            convert_to_numpy: Return a NumPy array (only when
                ``convert_to_tensor`` is false).
            convert_to_tensor: Return a tensor.  When both flags are equal
                the tensor output is chosen.
            show_progress_bar: Show progress bars.
            max_seq_length: Truncation length; ``None`` uses
                ``self.inf_seq_length``.
            **kwargs: Forwarded to ``encode_single_device``.

        Returns:
            The embeddings in input order.
        """
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length
        if convert_to_tensor == convert_to_numpy:
            convert_to_tensor, convert_to_numpy = True, False

        if isinstance(sentences, str) or len(self.target_devices) == 1 or len(sentences) == 0:
            return self.encode_single_device(
                sentences,
                batch_size=batch_size,
                convert_to_numpy=convert_to_numpy,
                convert_to_tensor=convert_to_tensor,
                show_progress_bar=show_progress_bar,
                max_seq_length=max_seq_length,
                device=self.target_devices[0],
                **kwargs
            )

        if self.pool is None:
            self.pool = self.start_multi_process_pool(CgeForEmbedding._encode_multi_process_worker)

        # Longest first, so each dispatch round holds texts of similar length.
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        collected = []
        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                embeddings_batch = self.encode_multi_process(
                    sentences_batch,
                    self.pool,
                    convert_to_numpy=convert_to_numpy,
                    convert_to_tensor=convert_to_tensor,
                    show_progress_bar=show_progress_bar,
                    max_seq_length=max_seq_length,
                    **kwargs
                )
                # Workers return NumPy arrays in NumPy mode; wrap them so the
                # tensor handling below applies (`.detach()` on an ndarray
                # used to raise here).
                if isinstance(embeddings_batch, np.ndarray):
                    embeddings_batch = torch.from_numpy(embeddings_batch)
                embeddings_batch = embeddings_batch.detach()
                if convert_to_numpy:
                    embeddings_batch = embeddings_batch.cpu()
                    # NumPy has no bfloat16.
                    if embeddings_batch.dtype == torch.bfloat16:
                        embeddings_batch = embeddings_batch.to(torch.float32)
                collected.extend(embeddings_batch)

        # Undo the length sort.
        all_embeddings = [collected[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings

    def encode_queries(self, sentences, batch_size=16, convert_to_numpy=False,
                       convert_to_tensor=True, show_progress_bar=True, max_seq_length=None, **kwargs):
        """Encode query texts; a thin alias of ``encode``.

        ``encode`` itself resolves ``max_seq_length=None`` and the
        tensor/NumPy flag conflict, so the arguments are passed through as is.
        """
        return self.encode(
            sentences=sentences,
            batch_size=batch_size,
            convert_to_numpy=convert_to_numpy,
            convert_to_tensor=convert_to_tensor,
            show_progress_bar=show_progress_bar,
            max_seq_length=max_seq_length,
            **kwargs
        )

    def encode_corpus(self, sentences, batch_size=16, convert_to_numpy=False,
                      convert_to_tensor=True, show_progress_bar=True, max_seq_length=None, **kwargs):
        """Encode corpus documents given as mappings with ``title`` and ``text``.

        Each document is rendered as ``title + ' ' + text`` and handed to
        ``encode``, which resolves ``max_seq_length=None`` and the
        tensor/NumPy flag conflict.
        """
        documents = [sentence['title'] + ' ' + sentence['text'] for sentence in sentences]
        return self.encode(
            sentences=documents,
            batch_size=batch_size,
            convert_to_numpy=convert_to_numpy,
            convert_to_tensor=convert_to_tensor,
            show_progress_bar=show_progress_bar,
            max_seq_length=max_seq_length,
            **kwargs
        )

    @staticmethod
    def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
        """Stop all processes started with ``start_multi_process_pool``.

        Declared as a static method: the function takes no ``self`` yet is
        called as ``self.stop_multi_process_pool(self.pool)``.

        Args:
            pool (Dict[str, object]): The input queue, output queue and
                process list returned by ``start_multi_process_pool``.

        Returns:
            None
        """
        for p in pool["processes"]:
            p.terminate()

        # join() before close(): a process object cannot be closed while the
        # underlying process is still running.
        for p in pool["processes"]:
            p.join()
            p.close()

        pool["input"].close()
        pool["output"].close()

    def stop_self_pool(self):
        """Shut down the worker pool, if any, and release cached memory.

        The model stays on its current device.  The previous
        ``self.model.to('cpu')`` referenced an attribute that does not exist
        and, inside a bare ``except``, also prevented the CUDA cache from
        being emptied.
        """
        # getattr: may run from __del__ on a partially constructed object.
        pool = getattr(self, "pool", None)
        if pool is not None:
            self.stop_multi_process_pool(pool)
            self.pool = None
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Best effort only; this may run during interpreter shutdown.
            pass
        # Module globals can already be cleared during interpreter shutdown.
        if gc is not None and callable(gc.collect):
            gc.collect()

    def __del__(self):
        # A finaliser must never raise.
        try:
            self.stop_self_pool()
        except Exception:
            pass