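"""Gradio demo for ReVQ.

Demo 1 compares image reconstructions from BaseVQ, VQGAN, and ReVQ.
Demo 2 visualizes nearest-neighbor vs. optimal-transport (OptVQ) code
assignment on 2D toy data.
"""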
import gradio as gr
from io import BytesIO
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import torch
from torchvision import transforms as T

from revq.models.quantizer import sinkhorn
from revq.models.preprocessor import Preprocessor
from revq.models.revq import ReVQ
from revq.models.vqgan_hf import VQModelHF
from revq.utils.init import seed_everything
from diffusers import AutoencoderDC

seed_everything(42)
matplotlib.rcParams['font.family'] = 'Times New Roman'
#################
N_data = 50  # number of data points in the 2D demo
N_code = 20  # number of code vectors
dim = 2      # dimensionality of the toy example
handler = None
device = torch.device("cpu")
#################
def load_preprocessor(device, is_eval: bool = True, ckpt_path: str = "./ckpt/preprocessor.pth"):
    preprocessor = Preprocessor(
        input_data_size=[32, 8, 8]
    ).to(device)
    preprocessor.load_state_dict(
        torch.load(ckpt_path, map_location=device, weights_only=True)
    )
    if is_eval:
        preprocessor.eval()
    return preprocessor
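# ---- Demo 2 helpers: nearest-neighbor vs. optimal-transport matching in 2D ----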
def nearest(src, trg):
    # index of the nearest target (code) vector for each source (data) point
    dis_mat = torch.cdist(src, trg)
    min_idx = torch.argmin(dis_mat, dim=-1)
    return min_idx
def normalize(A, dim, mode="all"):
    if mode == "all":
        # standardize, then shift so the minimum is zero
        A = (A - A.mean()) / (A.std() + 1e-6)
        A = A - A.min()
    elif mode == "dim":
        A = A / dim
    elif mode == "null":
        pass
    return A
def draw_NN(data, code):
    # hard assignment: each data point is matched to its nearest code vector
    indices = nearest(data, code)
    data = data.numpy()
    code = code.numpy()
    plt.figure(figsize=(3, 2.5), dpi=400)
    # draw matching arrows in red, alpha=0.6
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
                  ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Nearest neighbor")
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()  # avoid leaking figures in a long-running app
    buf.seek(0)
    image = Image.open(buf)
    return image
def draw_optvq(data, code):
    # soft assignment via entropic optimal transport (Sinkhorn), then argmax
    cost = torch.cdist(data, code, p=2.0)
    cost = normalize(cost, dim, mode="all")
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)
    data = data.numpy()
    code = code.numpy()
    plt.figure(figsize=(3, 2.5), dpi=400)
    # draw matching arrows in green, alpha=0.6
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
                  ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Optimal Transport (OptVQ)")
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()  # avoid leaking figures in a long-running app
    buf.seek(0)
    image = Image.open(buf)
    return image
def draw_process(x, y, std):
    # sample fresh data and codes; codes are scaled by std and shifted to (x, y)
    data = torch.randn(N_data, dim)
    code = torch.randn(N_code, dim) * std
    code[:, 0] += x
    code[:, 1] += y
    image_NN = draw_NN(data, code)
    image_optvq = draw_optvq(data, code)
    return image_NN, image_optvq
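# ---- Demo 1: image reconstruction with BaseVQ, VQGAN, and ReVQ ----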
class Handler:
    def __init__(self, device):
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device
        self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
        self.basevq.to(self.device)
        self.basevq.eval()
        self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
        self.vqgan.to(self.device)
        self.vqgan.eval()
        self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
        self.vae.to(self.device)
        self.vae.eval()
        self.preprocessor = load_preprocessor(self.device)
        # TODO: to be revised
        self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
        self.revq.to(self.device)
        self.revq.eval()
    def tensor_to_image(self, tensor):
        # map from [-1, 1] to [0, 255] uint8, clipping to avoid wrap-around
        img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = np.clip(img, 0, 255).astype("uint8")
        return img
    def process_image(self, img: np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            img = 2 * img - 1  # scale to [-1, 1]
            # basevq
            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)
            # vqgan
            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)
            # revq: DC-AE latent -> preprocess -> quantize -> decode -> inverse preprocess -> DC-AE decode
            lat = self.vae.encode(img).latent
            lat = lat.contiguous()
            lat = self.preprocessor(lat)
            lat = self.revq.quantize(lat)
            revq_rec = self.revq.decode(lat)
            revq_rec = revq_rec.contiguous()
            revq_rec = self.preprocessor.inverse(revq_rec)
            revq_rec = self.vae.decode(revq_rec).sample
        # tensors to displayable uint8 images
        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        revq_rec = self.tensor_to_image(revq_rec)
        return img, basevq_rec, vqgan_rec, revq_rec
if __name__ == "__main__":
    # create the model handler
    handler = Handler(device=device)

    # create the interface
    with gr.Blocks() as demo:
        gr.Textbox(value="This demo shows the image reconstruction comparison between ReVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_revq = gr.Image(label="ReVQ rec.")
        btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_input, image_basevq, image_vqgan, image_revq])

        gr.Textbox(value="This demo shows the 2D visualizations of the nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")
        # redraw whenever a slider changes or the button is clicked
        input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])

    demo.launch()