Use torch.inference_mode() and disable gradient checkpointing
#4 opened by prathamj31

Files changed:
- config.json (+4 -1)
- modeling_zeranker.py (+4 -3)
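Why this helps (a minimal sketch, not part of the diff): gradient checkpointing is a training-time trade of recompute for activation memory and has no benefit at inference, while torch.inference_mode() skips autograd bookkeeping entirely. The stand-in nn.Linear below only illustrates the pattern the PR applies to the reranker's inner model.

import torch
from torch import nn

# Stand-in module, not the zerank-2 model: shows the inference-only pattern.
model = nn.Linear(16, 2).eval()
x = torch.randn(4, 16)

with torch.inference_mode():
    assert torch.is_inference_mode_enabled()
    out = model(x)

print(out.requires_grad)   # False: no autograd graph was recorded
print(out.is_inference())  # True: tensor was created under inference mode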
config.json CHANGED
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
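For context, the added auto_map entry is the standard transformers hook for resolving a custom config class that ships inside the model repository. A hedged sketch of how it would be picked up, assuming this change is merged and remote code is trusted:

from transformers import AutoConfig

# auto_map["AutoConfig"] -> "modeling_zeranker.ZEConfig" lets the Auto class
# resolve the custom config; trust_remote_code=True is required because the
# class lives in the model repo rather than in the transformers library.
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)  # expected: ZEConfig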
modeling_zeranker.py CHANGED
@@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS =
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -127,8 +127,8 @@ def predict(
 
         if not hasattr(self, "inner_model"):
             self.inner_tokenizer, self.inner_model = load_model(global_device)
-            self.inner_model.gradient_checkpointing_enable()
             self.inner_model.eval()
+            self.inner_model.gradient_checkpointing_disable()
         self.inner_yes_token_id = self.inner_tokenizer.encode(
             "Yes", add_special_tokens=False
         )[0]
@@ -172,7 +172,8 @@ def predict(
             batch_inputs = batch_inputs.to(global_device)
 
             try:
-
+                with torch.inference_mode():
+                    outputs = model(**batch_inputs, use_cache=False)
             except torch.OutOfMemoryError:
                 print(f"GPU OOM! {torch.cuda.memory_reserved()}")
                 torch.cuda.empty_cache()
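A quick way to check the effect of the inference_mode change (a sketch only: model and batch_inputs stand in for the reranker's inner model and a tokenized batch, and the helper name is hypothetical):

import torch

def peak_forward_memory_mb(model, batch_inputs, use_inference_mode: bool) -> float:
    # Reset the CUDA peak-memory counter, run one forward pass, and report
    # the peak allocation in MiB; with inference_mode the peak should be
    # lower because no autograd state is kept alive.
    torch.cuda.reset_peak_memory_stats()
    if use_inference_mode:
        with torch.inference_mode():
            model(**batch_inputs, use_cache=False)
    else:
        model(**batch_inputs, use_cache=False)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20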