Use torch.inference_mode() and disable gradient checkpointing

#4
by prathamj31 - opened
Files changed (2)
  1. config.json +4 -1
  2. modeling_zeranker.py +4 -3
config.json CHANGED
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
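
For context, a rough sketch of how the new "auto_map" entry is consumed at load time. The loading call below is an illustration, not part of this diff, and it assumes modeling_zeranker.py actually defines ZEConfig:

from transformers import AutoConfig

# Sketch only: with the "auto_map" entry above, loading the repo with
# trust_remote_code=True lets transformers import modeling_zeranker.py from the
# Hub and return the custom ZEConfig class instead of a stock config class.
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)  # expected: ZEConfig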
modeling_zeranker.py CHANGED
@@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -127,8 +127,8 @@ def predict(
 
     if not hasattr(self, "inner_model"):
        self.inner_tokenizer, self.inner_model = load_model(global_device)
-        self.inner_model.gradient_checkpointing_enable()
        self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
     self.inner_yes_token_id = self.inner_tokenizer.encode(
        "Yes", add_special_tokens=False
     )[0]
@@ -172,7 +172,8 @@ def predict(
     batch_inputs = batch_inputs.to(global_device)
 
     try:
-        outputs = model(**batch_inputs, use_cache=False)
+        with torch.inference_mode():
+            outputs = model(**batch_inputs, use_cache=False)
     except torch.OutOfMemoryError:
        print(f"GPU OOM! {torch.cuda.memory_reserved()}")
        torch.cuda.empty_cache()
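
To make the intent of the modeling change concrete, here is a small standalone sketch with a toy module (the module and layer sizes are made up; only the pattern matches the PR): under torch.inference_mode() no autograd graph is recorded, so intermediate activations are freed eagerly, and gradient checkpointing has nothing to save because there is no backward pass; it would only add recomputation.

import torch
import torch.nn as nn

# Toy stand-in for the reranker forward pass; illustrative only.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1)).eval()
batch = torch.randn(8, 512)

# Default mode: autograd still records a graph because the weights require grad,
# keeping intermediate activations alive until the output is released.
out_default = model(batch)
print(out_default.requires_grad)  # True

# inference_mode(): no graph is recorded and activations are freed eagerly,
# which is the memory saving this PR is after. Checkpointing is disabled because
# recomputing activations only pays off when a backward pass follows.
with torch.inference_mode():
    out_inference = model(batch)
print(out_inference.requires_grad)  # False

The drop of PER_DEVICE_BATCH_SIZE_TOKENS from 15_000 to 10_000 is a separate knob: a smaller per-device token budget per batch, presumably to leave more headroom before the OutOfMemoryError handler above fires.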