Spaces:
Sleeping
Sleeping
Commit
·
e7f69aa
1
Parent(s):
0f0afe4
debug
Browse files- revq/models/revq.py +5 -5
revq/models/revq.py
CHANGED
|
@@ -36,11 +36,11 @@ class Viewer:
|
|
| 36 |
|
| 37 |
class ReVQ(PyTorchModelHubMixin, nn.Module):
|
| 38 |
@classmethod
|
| 39 |
-
def _from_pretrained(cls,
|
| 40 |
-
print(f"Loading ReVQ model from {
|
| 41 |
-
config_path = hf_hub_download(repo_id=
|
| 42 |
-
ckpt_path = hf_hub_download(repo_id=
|
| 43 |
-
|
| 44 |
full_cfg = OmegaConf.load(config_path)
|
| 45 |
model_cfg = full_cfg.get("model", {})
|
| 46 |
decoder_config = model_cfg.get("decoder", {})
|
|
|
|
| 36 |
|
| 37 |
class ReVQ(PyTorchModelHubMixin, nn.Module):
|
| 38 |
@classmethod
|
| 39 |
+
def _from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
| 40 |
+
print(f"Loading ReVQ model from {pretrained_model_name_or_path}...")
|
| 41 |
+
config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="512T_NC=16384.yaml")
|
| 42 |
+
ckpt_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="ckpt.pth")
|
| 43 |
+
|
| 44 |
full_cfg = OmegaConf.load(config_path)
|
| 45 |
model_cfg = full_cfg.get("model", {})
|
| 46 |
decoder_config = model_cfg.get("decoder", {})
|