Spaces:
Sleeping
Sleeping
Commit
·
20077d0
1
Parent(s):
3cb25e1
debug
Browse files
- app.py +0 -1
- revq/models/revq.py +19 -19
app.py
CHANGED
|
@@ -144,7 +144,6 @@ class Handler:
|
|
| 144 |
self.vae.to(self.device)
|
| 145 |
self.vae.eval()
|
| 146 |
self.preprocesser = load_preprocessor(self.device)
|
| 147 |
-
# 待修改
|
| 148 |
self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
|
| 149 |
self.revq.to(self.device)
|
| 150 |
self.revq.eval()
|
|
|
|
| 144 |
self.vae.to(self.device)
|
| 145 |
self.vae.eval()
|
| 146 |
self.preprocesser = load_preprocessor(self.device)
|
|
|
|
| 147 |
self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
|
| 148 |
self.revq.to(self.device)
|
| 149 |
self.revq.eval()
|
revq/models/revq.py
CHANGED
|
@@ -36,7 +36,7 @@ class Viewer:
|
|
| 36 |
x = x.view(B, C, H, W)
|
| 37 |
return x
|
| 38 |
|
| 39 |
-
class ReVQ(nn.Module, PyTorchModelHubMixin):
|
| 40 |
def __init__(self,
|
| 41 |
decoder: dict = {},
|
| 42 |
quantize: dict = {},
|
|
@@ -52,6 +52,23 @@ class ReVQ(nn.Module, PyTorchModelHubMixin):
|
|
| 52 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 53 |
self.decoder.load_state_dict(checkpoint["decoder"])
|
| 54 |
self.quantizer.load_state_dict(checkpoint["quantizer"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def setup_quantizer(self, quantizer_config):
|
| 57 |
quantizer = Quantizer(**quantizer_config)
|
|
@@ -70,21 +87,4 @@ class ReVQ(nn.Module, PyTorchModelHubMixin):
|
|
| 70 |
def forward(self, x):
|
| 71 |
quant = self.quantize(x)
|
| 72 |
rec = self.decode(quant)
|
| 73 |
-
return quant, rec
|
| 74 |
-
|
| 75 |
-
@classmethod
|
| 76 |
-
def _from_pretrained(cls, repo_id: str, **kwargs):
|
| 77 |
-
config_path = hf_hub_download(repo_id=repo_id, filename="512T_NC=16384.yaml")
|
| 78 |
-
ckpt_path = hf_hub_download(repo_id=repo_id, filename="ckpt.pth")
|
| 79 |
-
|
| 80 |
-
full_cfg = OmegaConf.load(config_path)
|
| 81 |
-
model_cfg = full_cfg.get("model", {})
|
| 82 |
-
decoder_config = model_cfg.get("decoder", {})
|
| 83 |
-
quantize_config = model_cfg.get("quantizer", {})
|
| 84 |
-
|
| 85 |
-
model = cls(decoder=decoder_config,
|
| 86 |
-
quantize=quantize_config,
|
| 87 |
-
ckpt_path=ckpt_path,
|
| 88 |
-
**kwargs)
|
| 89 |
-
|
| 90 |
-
return model
|
|
|
|
| 36 |
x = x.view(B, C, H, W)
|
| 37 |
return x
|
| 38 |
|
| 39 |
+
class ReVQ(PyTorchModelHubMixin, nn.Module):
|
| 40 |
def __init__(self,
|
| 41 |
decoder: dict = {},
|
| 42 |
quantize: dict = {},
|
|
|
|
| 52 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 53 |
self.decoder.load_state_dict(checkpoint["decoder"])
|
| 54 |
self.quantizer.load_state_dict(checkpoint["quantizer"])
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def _from_pretrained(cls, repo_id: str, **kwargs):
|
| 58 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="512T_NC=16384.yaml")
|
| 59 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename="ckpt.pth")
|
| 60 |
+
|
| 61 |
+
full_cfg = OmegaConf.load(config_path)
|
| 62 |
+
model_cfg = full_cfg.get("model", {})
|
| 63 |
+
decoder_config = model_cfg.get("decoder", {})
|
| 64 |
+
quantize_config = model_cfg.get("quantizer", {})
|
| 65 |
+
|
| 66 |
+
model = cls(decoder=decoder_config,
|
| 67 |
+
quantize=quantize_config,
|
| 68 |
+
ckpt_path=ckpt_path,
|
| 69 |
+
**kwargs)
|
| 70 |
+
|
| 71 |
+
return model
|
| 72 |
|
| 73 |
def setup_quantizer(self, quantizer_config):
|
| 74 |
quantizer = Quantizer(**quantizer_config)
|
|
|
|
| 87 |
def forward(self, x):
|
| 88 |
quant = self.quantize(x)
|
| 89 |
rec = self.decode(quant)
|
| 90 |
+
return quant, rec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|