Update inference/generate_shards.py
Browse files
inference/generate_shards.py
CHANGED
|
@@ -47,7 +47,7 @@ class ShardGenerator:
|
|
| 47 |
|
| 48 |
def generate_placeholder_shards(
|
| 49 |
self,
|
| 50 |
-
shard_size_mb: float =
|
| 51 |
tensor_dtype: torch.dtype = torch.bfloat16
|
| 52 |
):
|
| 53 |
"""
|
|
@@ -122,7 +122,7 @@ class ShardGenerator:
|
|
| 122 |
def split_large_model(
|
| 123 |
self,
|
| 124 |
model_state_dict: Dict[str, torch.Tensor],
|
| 125 |
-
max_shard_size_gb: float =
|
| 126 |
):
|
| 127 |
"""
|
| 128 |
Split a large model into shards
|
|
@@ -287,7 +287,7 @@ def main():
|
|
| 287 |
parser.add_argument(
|
| 288 |
"--shard-size",
|
| 289 |
type=float,
|
| 290 |
-
default=
|
| 291 |
help="Target shard size in MB"
|
| 292 |
)
|
| 293 |
|
|
|
|
| 47 |
|
| 48 |
def generate_placeholder_shards(
|
| 49 |
self,
|
| 50 |
+
shard_size_mb: float = 3010,
|
| 51 |
tensor_dtype: torch.dtype = torch.bfloat16
|
| 52 |
):
|
| 53 |
"""
|
|
|
|
| 122 |
def split_large_model(
|
| 123 |
self,
|
| 124 |
model_state_dict: Dict[str, torch.Tensor],
|
| 125 |
+
max_shard_size_gb: float = 3.01
|
| 126 |
):
|
| 127 |
"""
|
| 128 |
Split a large model into shards
|
|
|
|
| 287 |
parser.add_argument(
|
| 288 |
"--shard-size",
|
| 289 |
type=float,
|
| 290 |
+
default=3010,
|
| 291 |
help="Target shard size in MB"
|
| 292 |
)
|
| 293 |
|