"""Streamlit app: "Atom Level Entity Detector".

Loads four RCNN detectors (atoms, bonds, stereo, charges) from the newest
checkpoint in their model directories, lets the user upload a PNG image of
a chemical structure, runs every detector on it, and shows the atom and
bond detections side by side.
"""
import io
import os

import streamlit as st
import torch
import numpy as np
from Object_Smiles import Objects_Smiles
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt

# Outline colour per class label (the 10-colour palette repeated twice).
colors = ["magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum",
          "magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum"]

# Value assigned to every detector's ``model.roi_heads.score_thresh``.
SCORE_THRESH = 0.65


def plot_bbox(bbox_XYXY, label, ax=None):
    """Draw the outline of one bounding box.

    Args:
        bbox_XYXY: Box corners unpackable as ``(xmin, ymin, xmax, ymax)``.
        label: Integer class label; selects the outline colour from ``colors``.
        ax: Matplotlib axes to draw on. Defaults to the current pyplot axes,
            which matches the original behaviour.
    """
    if ax is None:
        ax = plt.gca()
    xmin, ymin, xmax, ymax = bbox_XYXY
    ax.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        # Wrap around so labels beyond the palette size do not raise IndexError.
        color=colors[int(label) % len(colors)],
        label=str(label),
    )


model_cls = RCNN


# Cached so Streamlit does not reload every checkpoint on each script rerun.
# NOTE(review): the cached model object is shared by all sessions.
@st.cache_resource
def load_latest_model(experiment_path, score_thresh=SCORE_THRESH):
    """Load the most recently created checkpoint found in ``experiment_path``.

    Args:
        experiment_path: Directory containing checkpoint files; any file
            whose name contains "ckpt" is a candidate.
        score_thresh: Value assigned to ``model.model.roi_heads.score_thresh``.

    Returns:
        The model loaded via ``model_cls.load_from_checkpoint``.

    Raises:
        FileNotFoundError: If the directory holds no checkpoint file.
    """
    entries = [os.path.join(experiment_path, name) for name in os.listdir(experiment_path)]
    # Newest first (by os.path.getctime), so the first match is the latest checkpoint.
    entries.sort(key=os.path.getctime, reverse=True)
    checkpoints = [path for path in entries if "ckpt" in os.path.basename(path)]
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint file found in {experiment_path!r}")
    model = model_cls.load_from_checkpoint(checkpoints[0])
    model.model.roi_heads.score_thresh = score_thresh
    return model


def render_detections(image, preds):
    """Return a PIL image of ``image`` overlaid with the predicted boxes.

    Args:
        image: Image to draw under the boxes (shown with ``cmap="gray"``).
        preds: Output of ``trainer.predict``; boxes and labels are read from
            ``preds[0]["boxes"][0]`` and ``preds[0]["preds"][0]``.
    """
    # A fresh figure per render, so boxes from an earlier plot or an earlier
    # Streamlit rerun never leak into this one.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image, cmap="gray")
        for bbox, label in zip(preds[0]["boxes"][0], preds[0]["preds"][0]):
            plot_bbox(bbox, label, ax=ax)
        ax.axis("off")
        # Render in memory rather than via a shared file on disk, which
        # concurrent sessions would overwrite.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()
    return rendered


model_atom = load_latest_model("./models/atoms_model/")
model_bond = load_latest_model("./models/bonds_model/")
model_stereo = load_latest_model("./models/stereos_model/")
model_charge = load_latest_model("./models/charges_model/")

data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)

st.title("Atom Level Entity Detector")
image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])

if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Persist the upload under ./uploads/ (the dataset's data_path) before
    # prepare_data().
    # NOTE(review): the fixed file name means concurrent sessions overwrite
    # each other's upload — confirm acceptable.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    # NOTE(review): stereo and charge predictions are computed but not displayed yet.
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charge_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    col1.image(render_detections(image, atom_preds), use_column_width=True)
    col2.image(render_detections(image, bond_preds), use_column_width=True)