implementing predict smiles
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch
|
|
| 5 |
import numpy as np
|
| 6 |
#import matplotlib.pyplot as plt
|
| 7 |
#import pathlib
|
| 8 |
-
|
| 9 |
#from utils_graph import *
|
| 10 |
from Object_Smiles import Objects_Smiles
|
| 11 |
|
|
@@ -112,5 +112,126 @@ if image_file is not None:
|
|
| 112 |
plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
|
| 113 |
image_vis = Image.open("example_image.png")
|
| 114 |
col2.image(image_vis, use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
#x = st.slider('Select a value')
|
| 116 |
#st.write(x, 'squared is', x * x)
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
#import matplotlib.pyplot as plt
|
| 7 |
#import pathlib
|
| 8 |
+
from AtomLenz import *
|
| 9 |
#from utils_graph import *
|
| 10 |
from Object_Smiles import Objects_Smiles
|
| 11 |
|
|
|
|
| 112 |
plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
|
| 113 |
image_vis = Image.open("example_image.png")
|
| 114 |
col2.image(image_vis, use_column_width=True)
|
| 115 |
+
for image_idx, bonds in enumerate(bond_preds):
|
| 116 |
+
count_bonds_preds = np.zeros(8)
|
| 117 |
+
count_atoms_preds = np.zeros(18)
|
| 118 |
+
atom_boxes = atom_preds[image_idx]['boxes'][0]
|
| 119 |
+
atom_labels = atom_preds[image_idx]['preds'][0]
|
| 120 |
+
atom_scores = atom_preds[image_idx]['scores'][0]
|
| 121 |
+
charge_boxes = charges_preds[image_idx]['boxes'][0]
|
| 122 |
+
charge_labels = charges_preds[image_idx]['preds'][0]
|
| 123 |
+
charge_mask=torch.where(charge_labels>1)
|
| 124 |
+
filtered_ch_labels=charge_labels[charge_mask]
|
| 125 |
+
filtered_ch_boxes=charge_boxes[charge_mask]
|
| 126 |
+
#import ipdb; ipdb.set_trace()
|
| 127 |
+
filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
|
| 128 |
+
#for atom_label in filtered_labels:
|
| 129 |
+
# count_atoms_preds[atom_label] += 1
|
| 130 |
+
#import ipdb; ipdb.set_trace()
|
| 131 |
+
mol_graph = np.zeros((len(filtered_bboxes),len(filtered_bboxes)))
|
| 132 |
+
stereo_atoms = np.zeros(len(filtered_bboxes))
|
| 133 |
+
charge_atoms = np.ones(len(filtered_bboxes))
|
| 134 |
+
for index,box_atom in enumerate(filtered_bboxes):
|
| 135 |
+
for box_charge,label_charge in zip(filtered_ch_boxes,filtered_ch_labels):
|
| 136 |
+
if bb_box_intersects(box_atom,box_charge) == 1:
|
| 137 |
+
charge_atoms[index]=label_charge
|
| 138 |
+
|
| 139 |
+
for bond_idx, bond_box in enumerate(bonds['boxes'][0]):
|
| 140 |
+
label_bond = bonds['preds'][0][bond_idx]
|
| 141 |
+
if label_bond > 1:
|
| 142 |
+
try:
|
| 143 |
+
count_bonds_preds[label_bond] += 1
|
| 144 |
+
except:
|
| 145 |
+
count_bonds_preds=count_bonds_preds
|
| 146 |
+
#import ipdb; ipdb.set_trace()
|
| 147 |
+
result = []
|
| 148 |
+
limit = 0
|
| 149 |
+
#TODO: the search limit (80) and the growth step (5) should be made dependent on the mean size of atom_boxes
|
| 150 |
+
while result.count(1) < 2 and limit < 80:
|
| 151 |
+
result=[]
|
| 152 |
+
bigger_bond_box = [bond_box[0]-limit,bond_box[1]-limit,bond_box[2]+limit,bond_box[3]+limit]
|
| 153 |
+
for atom_box in filtered_bboxes:
|
| 154 |
+
result.append(bb_box_intersects(atom_box,bigger_bond_box))
|
| 155 |
+
limit+=5
|
| 156 |
+
indices = [i for i, x in enumerate(result) if x == 1]
|
| 157 |
+
if len(indices) == 2:
|
| 158 |
+
#import ipdb; ipdb.set_trace()
|
| 159 |
+
mol_graph[indices[0],indices[1]]=label_bond
|
| 160 |
+
mol_graph[indices[1],indices[0]]=label_bond
|
| 161 |
+
if len(indices) > 2:
|
| 162 |
+
#we have more than two candidate atoms for one bond, so we filter them down to a pair with dist_filter_bboxes
|
| 163 |
+
cand_bboxes = filtered_bboxes[indices,:]
|
| 164 |
+
cand_indices = dist_filter_bboxes(cand_bboxes)
|
| 165 |
+
#import ipdb; ipdb.set_trace()
|
| 166 |
+
mol_graph[indices[cand_indices[0]],indices[cand_indices[1]]]=label_bond
|
| 167 |
+
mol_graph[indices[cand_indices[1]],indices[cand_indices[0]]]=label_bond
|
| 168 |
+
#print("more than 2 indices")
|
| 169 |
+
#if len(indices) < 2:
|
| 170 |
+
# print("less than 2 indices")
|
| 171 |
+
#import ipdb; ipdb.set_trace()
|
| 172 |
+
# else:
|
| 173 |
+
# result=[]
|
| 174 |
+
# for atom_box in filtered_bboxes:
|
| 175 |
+
# result.append(bb_box_intersects(atom_box,bond_box))
|
| 176 |
+
# indices = [i for i, x in enumerate(result) if x == 1]
|
| 177 |
+
# if len(indices) == 1:
|
| 178 |
+
# stereo_atoms[indices[0]]=label_bond
|
| 179 |
+
stereo_bonds = np.where(mol_graph>4, True, False)
|
| 180 |
+
if np.any(stereo_bonds):
|
| 181 |
+
stereo_boxes = stereo_preds[image_idx]['boxes'][0]
|
| 182 |
+
stereo_labels= stereo_preds[image_idx]['preds'][0]
|
| 183 |
+
for stereo_box in stereo_boxes:
|
| 184 |
+
result=[]
|
| 185 |
+
for atom_box in filtered_bboxes:
|
| 186 |
+
result.append(bb_box_intersects(atom_box,stereo_box))
|
| 187 |
+
indices = [i for i, x in enumerate(result) if x == 1]
|
| 188 |
+
if len(indices) == 1:
|
| 189 |
+
stereo_atoms[indices[0]]=1
|
| 190 |
+
|
| 191 |
+
molecule = dict()
|
| 192 |
+
molecule['graph'] = mol_graph
|
| 193 |
+
#molecule['atom_labels'] = atom_preds[image_idx]['preds'][0]
|
| 194 |
+
molecule['atom_labels'] = filtered_labels
|
| 195 |
+
molecule['atom_boxes'] = filtered_bboxes
|
| 196 |
+
molecule['stereo_atoms'] = stereo_atoms
|
| 197 |
+
molecule['charge_atoms'] = charge_atoms
|
| 198 |
+
mol_graphs.append(molecule)
|
| 199 |
+
base_path = pathlib.Path(args.data_path)
|
| 200 |
+
image_dir = base_path.joinpath("images")
|
| 201 |
+
smiles_dir = base_path.joinpath("smiles")
|
| 202 |
+
impath = image_dir.joinpath(f"{image_idx}.png")
|
| 203 |
+
smilespath = smiles_dir.joinpath(f"{image_idx}.txt")
|
| 204 |
+
save_mol_to_file(molecule,'molfile')
|
| 205 |
+
mol = Chem.MolFromMolFile('molfile',sanitize=False)
|
| 206 |
+
problematic = 0
|
| 207 |
+
try:
|
| 208 |
+
problems = Chem.DetectChemistryProblems(mol)
|
| 209 |
+
if len(problems) > 0:
|
| 210 |
+
mol = solve_mol_problems(mol,problems)
|
| 211 |
+
problematic = 1
|
| 212 |
+
#import ipdb; ipdb.set_trace()
|
| 213 |
+
try:
|
| 214 |
+
Chem.SanitizeMol(mol)
|
| 215 |
+
except:
|
| 216 |
+
problems = Chem.DetectChemistryProblems(mol)
|
| 217 |
+
if len(problems) > 0:
|
| 218 |
+
mol = solve_mol_problems(mol,problems)
|
| 219 |
+
try:
|
| 220 |
+
Chem.SanitizeMol(mol)
|
| 221 |
+
except:
|
| 222 |
+
pass
|
| 223 |
+
except:
|
| 224 |
+
problematic = 1
|
| 225 |
+
try:
|
| 226 |
+
pred_smiles = Chem.MolToSmiles(mol)
|
| 227 |
+
except:
|
| 228 |
+
pred_smiles = ""
|
| 229 |
+
problematic = 1
|
| 230 |
+
predictions+=1
|
| 231 |
+
predictions_list.append([image_idx,pred_smiles,problematic])
|
| 232 |
+
#import ipdb; ipdb.set_trace()
|
| 233 |
+
file_preds = open('preds_atomlenz','w')
|
| 234 |
+
for pred in predictions_list:
|
| 235 |
+
print(pred)
|
| 236 |
#x = st.slider('Select a value')
|
| 237 |
#st.write(x, 'squared is', x * x)
|