AndyRaoTHU commited on
Commit
e7f69aa
·
1 Parent(s): 0f0afe4
Files changed (1) hide show
  1. revq/models/revq.py +5 -5
revq/models/revq.py CHANGED
@@ -36,11 +36,11 @@ class Viewer:
36
 
37
  class ReVQ(PyTorchModelHubMixin, nn.Module):
38
  @classmethod
39
- def _from_pretrained(cls, repo_id: str, **kwargs):
40
- print(f"Loading ReVQ model from {repo_id}...")
41
- config_path = hf_hub_download(repo_id=repo_id, filename="512T_NC=16384.yaml")
42
- ckpt_path = hf_hub_download(repo_id=repo_id, filename="ckpt.pth")
43
-
44
  full_cfg = OmegaConf.load(config_path)
45
  model_cfg = full_cfg.get("model", {})
46
  decoder_config = model_cfg.get("decoder", {})
 
36
 
37
  class ReVQ(PyTorchModelHubMixin, nn.Module):
38
  @classmethod
39
+ def _from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
40
+ print(f"Loading ReVQ model from {pretrained_model_name_or_path}...")
41
+ config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="512T_NC=16384.yaml")
42
+ ckpt_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="ckpt.pth")
43
+
44
  full_cfg = OmegaConf.load(config_path)
45
  model_cfg = full_cfg.get("model", {})
46
  decoder_config = model_cfg.get("decoder", {})