Trouter-Library commited on
Commit
2c59201
·
verified ·
1 Parent(s): d5025ae

Update inference/generate_shards.py

Browse files
Files changed (1) hide show
  1. inference/generate_shards.py +3 -3
inference/generate_shards.py CHANGED
@@ -47,7 +47,7 @@ class ShardGenerator:
47
 
48
  def generate_placeholder_shards(
49
  self,
50
- shard_size_mb: float = 2800,
51
  tensor_dtype: torch.dtype = torch.bfloat16
52
  ):
53
  """
@@ -122,7 +122,7 @@ class ShardGenerator:
122
  def split_large_model(
123
  self,
124
  model_state_dict: Dict[str, torch.Tensor],
125
- max_shard_size_gb: float = 2.8
126
  ):
127
  """
128
  Split a large model into shards
@@ -287,7 +287,7 @@ def main():
287
  parser.add_argument(
288
  "--shard-size",
289
  type=float,
290
- default=2800,
291
  help="Target shard size in MB"
292
  )
293
 
 
47
 
48
  def generate_placeholder_shards(
49
  self,
50
+ shard_size_mb: float = 3010,
51
  tensor_dtype: torch.dtype = torch.bfloat16
52
  ):
53
  """
 
122
  def split_large_model(
123
  self,
124
  model_state_dict: Dict[str, torch.Tensor],
125
+ max_shard_size_gb: float = 3.01
126
  ):
127
  """
128
  Split a large model into shards
 
287
  parser.add_argument(
288
  "--shard-size",
289
  type=float,
290
+ default=3010,
291
  help="Target shard size in MB"
292
  )
293