"""Streamlit front-end for AtomLenz.

The user uploads a PNG of a chemical structure; four RCNN detectors (atoms,
bonds, stereo centres, charges) are run on it, the detected atom and bond
boxes are drawn next to each other, and the detections are assembled into a
molecular graph from which a SMILES string is derived with RDKit.
"""
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt


def main_page(top_n, model_path):
    """Render the placeholder "About" page; both arguments are currently unused."""
    st.markdown(
        """test
    """
    )


#### TRYOUT MENU #####
page_names_to_funcs = {
    "About AtomLenz": main_page,
}
selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
st.sidebar.markdown('')
selected_model = st.sidebar.selectbox(
    "Select a AtomLenz model to load",
    ("AtomLenz trained on synthetic data (default)",
     "AtomLenz for hand-drawn images",
     "ChemExpert (not available yet)"))
model_dict = {
    "AtomLenz trained on synthetic data (default)": "atomlenz_default.pt",
    "AtomLenz for hand-drawn images": "atomlenz_handdrawn.pt",
    "ChemExpert (not available yet)": "atomlenz_default.pt",
}
# NOTE: neither the selected page nor `model_file` is read anywhere below;
# the detectors are always loaded from the ./models/* checkpoint folders.
model_file = model_dict[selected_model]
######################

colors = ["magenta", "green", "blue", "red", "orange",
          "magenta", "peru", "azure", "slateblue", "plum",
          "magenta", "green", "blue", "red", "orange",
          "magenta", "peru", "azure", "slateblue", "plum"]


def plot_bbox(bbox_XYXY, label, ax=None):
    """Draw one bounding box as a closed rectangle outline.

    Args:
        bbox_XYXY: Four values unpacked as ``xmin, ymin, xmax, ymax``.
        label: Integer-like class label; selects the colour from ``colors``.
        ax: Optional matplotlib axes to draw on. When omitted the current
            pyplot axes are used, as before.
    """
    xmin, ymin, xmax, ymax = bbox_XYXY
    target = plt if ax is None else ax
    target.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        # Wrap around so a label beyond the palette cannot raise IndexError.
        color=colors[int(label) % len(colors)],
        label=str(label))


def _latest_checkpoint(experiment_path):
    """Return the most recently created entry of *experiment_path* whose path contains "ckpt".

    Raises:
        FileNotFoundError: if the directory holds no such entry.
    """
    candidates = [os.path.join(experiment_path, name)
                  for name in os.listdir(experiment_path)]
    candidates = [path for path in candidates if "ckpt" in path]
    if not candidates:
        raise FileNotFoundError(
            f"no checkpoint file found in {experiment_path!r}")
    return max(candidates, key=os.path.getctime)


def _load_detector(model_cls, experiment_path, score_thresh=0.65):
    """Load the newest checkpoint under *experiment_path* and set its ROI score threshold."""
    # NOTE(review): this runs on every Streamlit rerun; st.cache_resource could
    # avoid the reload if sharing the models between sessions is acceptable.
    detector = model_cls.load_from_checkpoint(_latest_checkpoint(experiment_path))
    detector.model.roi_heads.score_thresh = score_thresh
    return detector


def _show_detections(image, boxes, labels, column, out_path="example_image.png"):
    """Overlay *boxes* on *image*, save the plot to *out_path* and show it in *column*."""
    # A dedicated figure per rendering: the implicit global pyplot figure
    # would otherwise keep artists from a previous rendering or script rerun.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image, cmap="gray")
        for bbox, label in zip(boxes, labels):
            plot_bbox(bbox, label, ax=ax)
        ax.axis('off')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    column.image(Image.open(out_path), use_column_width=True)


def _atoms_touching_bond(bond_box, atom_boxes, max_margin=80, step=5):
    """Return indices of the atom boxes intersecting *bond_box*.

    The bond box is grown by *step* on every side until at least two atom
    boxes intersect it or the margin reaches *max_margin*.
    """
    # TODO: max_margin and step should be made dependent on the mean size of atom_boxes
    hits = []
    margin = 0
    while hits.count(1) < 2 and margin < max_margin:
        grown_box = [bond_box[0] - margin, bond_box[1] - margin,
                     bond_box[2] + margin, bond_box[3] + margin]
        hits = [bb_box_intersects(atom_box, grown_box) for atom_box in atom_boxes]
        margin += step
    return [i for i, hit in enumerate(hits) if hit == 1]


def _build_molecule(atom_pred, bond_pred, charge_pred, stereo_pred):
    """Assemble one image's detections into the molecule dict used by ``save_mol_to_file``.

    Each argument is one prediction entry of the corresponding detector,
    indexed as ``entry['boxes'][0]`` / ``entry['preds'][0]`` (and
    ``entry['scores'][0]`` for atoms).
    """
    atom_boxes = atom_pred['boxes'][0]
    atom_labels = atom_pred['preds'][0]
    atom_scores = atom_pred['scores'][0]

    # Only charge detections with label > 1 are kept.
    charge_boxes = charge_pred['boxes'][0]
    charge_labels = charge_pred['preds'][0]
    charge_mask = torch.where(charge_labels > 1)
    filtered_ch_labels = charge_labels[charge_mask]
    filtered_ch_boxes = charge_boxes[charge_mask]

    filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
    n_atoms = len(filtered_bboxes)
    mol_graph = np.zeros((n_atoms, n_atoms))
    stereo_atoms = np.zeros(n_atoms)
    charge_atoms = np.ones(n_atoms)

    # An atom takes the label of a charge box it intersects (last match wins).
    for index, box_atom in enumerate(filtered_bboxes):
        for box_charge, label_charge in zip(filtered_ch_boxes, filtered_ch_labels):
            if bb_box_intersects(box_atom, box_charge) == 1:
                charge_atoms[index] = label_charge

    for bond_idx, bond_box in enumerate(bond_pred['boxes'][0]):
        label_bond = bond_pred['preds'][0][bond_idx]
        if label_bond > 1:
            indices = _atoms_touching_bond(bond_box, filtered_bboxes)
            if len(indices) == 2:
                mol_graph[indices[0], indices[1]] = label_bond
                mol_graph[indices[1], indices[0]] = label_bond
            if len(indices) > 2:
                # More than two candidate atoms for one bond: let
                # dist_filter_bboxes pick the pair.
                cand_bboxes = filtered_bboxes[indices, :]
                cand_indices = dist_filter_bboxes(cand_bboxes)
                first = indices[cand_indices[0]]
                second = indices[cand_indices[1]]
                mol_graph[first, second] = label_bond
                mol_graph[second, first] = label_bond

    # Stereo boxes are only consulted when the graph has an entry > 4.
    if np.any(mol_graph > 4):
        for stereo_box in stereo_pred['boxes'][0]:
            indices = [i for i, atom_box in enumerate(filtered_bboxes)
                       if bb_box_intersects(atom_box, stereo_box) == 1]
            if len(indices) == 1:
                stereo_atoms[indices[0]] = 1

    molecule = dict()
    molecule['graph'] = mol_graph
    molecule['atom_labels'] = filtered_labels
    molecule['atom_boxes'] = filtered_bboxes
    molecule['stereo_atoms'] = stereo_atoms
    molecule['charge_atoms'] = charge_atoms
    return molecule


def _molecule_to_smiles(molecule, molfile_path='molfile'):
    """Write *molecule* to a mol file and return ``(smiles, problematic)``.

    ``problematic`` is 1 when RDKit reported chemistry problems or an
    exception occurred along the way, else 0; ``smiles`` is "" when SMILES
    generation itself failed.
    """
    save_mol_to_file(molecule, molfile_path)
    mol = Chem.MolFromMolFile(molfile_path, sanitize=False)
    problematic = 0
    try:
        problems = Chem.DetectChemistryProblems(mol)
        if len(problems) > 0:
            mol = solve_mol_problems(mol, problems)
            problematic = 1
        # NOTE(review): sanitisation is attempted whether or not problems were
        # detected; the original nesting was ambiguous - confirm.
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            # Second repair round; a remaining failure is tolerated (best effort).
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                pass
    except Exception:
        problematic = 1
    try:
        pred_smiles = Chem.MolToSmiles(mol)
    except Exception:
        pred_smiles = ""
        problematic = 1
    return pred_smiles, problematic


model_cls = RCNN
experiment_path_atoms = "./models/atoms_model/"
model_atom = _load_detector(model_cls, experiment_path_atoms)
experiment_path_bonds = "./models/bonds_model/"
model_bond = _load_detector(model_cls, experiment_path_bonds)
experiment_path_stereo = "./models/stereos_model/"
model_stereo = _load_detector(model_cls, experiment_path_stereo)
experiment_path_charges = "./models/charges_model/"
model_charge = _load_detector(model_cls, experiment_path_charges)

data_cls = Objects_Smiles
dataset = data_cls(data_path="./uploads/", batch_size=1)

st.title("Atom Level Entity Detector")
image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])
if image_file is not None:
    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # Store the upload as uploads/images/0.png, then let the dataset prepare it.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    _show_detections(image, atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0], col1)
    _show_detections(image, bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0], col2)

    mol_graphs = []
    predictions_list = []
    for image_idx, bonds in enumerate(bond_preds):
        molecule = _build_molecule(atom_preds[image_idx], bonds,
                                   charges_preds[image_idx], stereo_preds[image_idx])
        mol_graphs.append(molecule)
        pred_smiles, problematic = _molecule_to_smiles(molecule)
        predictions_list.append([image_idx, pred_smiles, problematic])

    # Print each prediction and persist the same record; the context manager
    # guarantees the file is flushed and closed.
    with open('preds_atomlenz', 'w') as file_preds:
        for pred in predictions_list:
            print(pred)
            file_preds.write(f"{pred}\n")