AndyRaoTHU commited on
Commit
20077d0
·
1 Parent(s): 3cb25e1
Files changed (2) hide show
  1. app.py +0 -1
  2. revq/models/revq.py +19 -19
app.py CHANGED
@@ -144,7 +144,6 @@ class Handler:
144
  self.vae.to(self.device)
145
  self.vae.eval()
146
  self.preprocesser = load_preprocessor(self.device)
147
- # 待修改
148
  self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
149
  self.revq.to(self.device)
150
  self.revq.eval()
 
144
  self.vae.to(self.device)
145
  self.vae.eval()
146
  self.preprocesser = load_preprocessor(self.device)
 
147
  self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
148
  self.revq.to(self.device)
149
  self.revq.eval()
revq/models/revq.py CHANGED
@@ -36,7 +36,7 @@ class Viewer:
36
  x = x.view(B, C, H, W)
37
  return x
38
 
39
- class ReVQ(nn.Module, PyTorchModelHubMixin):
40
  def __init__(self,
41
  decoder: dict = {},
42
  quantize: dict = {},
@@ -52,6 +52,23 @@ class ReVQ(nn.Module, PyTorchModelHubMixin):
52
  checkpoint = torch.load(ckpt_path, map_location="cpu")
53
  self.decoder.load_state_dict(checkpoint["decoder"])
54
  self.quantizer.load_state_dict(checkpoint["quantizer"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def setup_quantizer(self, quantizer_config):
57
  quantizer = Quantizer(**quantizer_config)
@@ -70,21 +87,4 @@ class ReVQ(nn.Module, PyTorchModelHubMixin):
70
  def forward(self, x):
71
  quant = self.quantize(x)
72
  rec = self.decode(quant)
73
- return quant, rec
74
-
75
- @classmethod
76
- def _from_pretrained(cls, repo_id: str, **kwargs):
77
- config_path = hf_hub_download(repo_id=repo_id, filename="512T_NC=16384.yaml")
78
- ckpt_path = hf_hub_download(repo_id=repo_id, filename="ckpt.pth")
79
-
80
- full_cfg = OmegaConf.load(config_path)
81
- model_cfg = full_cfg.get("model", {})
82
- decoder_config = model_cfg.get("decoder", {})
83
- quantize_config = model_cfg.get("quantizer", {})
84
-
85
- model = cls(decoder=decoder_config,
86
- quantize=quantize_config,
87
- ckpt_path=ckpt_path,
88
- **kwargs)
89
-
90
- return model
 
36
  x = x.view(B, C, H, W)
37
  return x
38
 
39
+ class ReVQ(PyTorchModelHubMixin, nn.Module):
40
  def __init__(self,
41
  decoder: dict = {},
42
  quantize: dict = {},
 
52
  checkpoint = torch.load(ckpt_path, map_location="cpu")
53
  self.decoder.load_state_dict(checkpoint["decoder"])
54
  self.quantizer.load_state_dict(checkpoint["quantizer"])
55
+
56
+ @classmethod
57
+ def _from_pretrained(cls, repo_id: str, **kwargs):
58
+ config_path = hf_hub_download(repo_id=repo_id, filename="512T_NC=16384.yaml")
59
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename="ckpt.pth")
60
+
61
+ full_cfg = OmegaConf.load(config_path)
62
+ model_cfg = full_cfg.get("model", {})
63
+ decoder_config = model_cfg.get("decoder", {})
64
+ quantize_config = model_cfg.get("quantizer", {})
65
+
66
+ model = cls(decoder=decoder_config,
67
+ quantize=quantize_config,
68
+ ckpt_path=ckpt_path,
69
+ **kwargs)
70
+
71
+ return model
72
 
73
  def setup_quantizer(self, quantizer_config):
74
  quantizer = Quantizer(**quantizer_config)
 
87
  def forward(self, x):
88
  quant = self.quantize(x)
89
  rec = self.decode(quant)
90
+ return quant, rec