Spaces:
Running
Running
Update inference.py
Browse files- inference.py +12 -12
inference.py
CHANGED
|
@@ -114,7 +114,7 @@ class Inference(object):
|
|
| 114 |
|
| 115 |
def decoder_load(self, dictionary_name):
|
| 116 |
''' Loading the atom and bond decoders'''
|
| 117 |
-
with open("
|
| 118 |
return pickle.load(f)
|
| 119 |
|
| 120 |
|
|
@@ -140,16 +140,16 @@ class Inference(object):
|
|
| 140 |
self.restore_model(self.submodel, self.inference_model)
|
| 141 |
|
| 142 |
# smiles data for metrics calculation.
|
| 143 |
-
chembl_smiles = [line for line in open("
|
| 144 |
-
chembl_test = [line for line in open("
|
| 145 |
-
drug_smiles = [line for line in open("
|
| 146 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 147 |
drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
|
| 148 |
|
| 149 |
|
| 150 |
# Make directories if not exist.
|
| 151 |
-
if not os.path.exists("
|
| 152 |
-
os.makedirs("
|
| 153 |
if self.correct:
|
| 154 |
correct = smi_correct(self.submodel, "DrugGEN_/experiments/inference/{}".format(self.submodel))
|
| 155 |
search_res = pd.DataFrame(columns=["submodel", "validity",
|
|
@@ -166,7 +166,7 @@ class Inference(object):
|
|
| 166 |
uniqueness_calc = []
|
| 167 |
real_smiles_snn = []
|
| 168 |
nodes_sample = torch.Tensor(size=[1,45,1]).to(self.device)
|
| 169 |
-
f = open("
|
| 170 |
f.write("SMILES")
|
| 171 |
f.write("\n")
|
| 172 |
val_counter = 0
|
|
@@ -226,16 +226,16 @@ class Inference(object):
|
|
| 226 |
f.close()
|
| 227 |
print("Inference completed, starting metrics calculation.")
|
| 228 |
if self.correct:
|
| 229 |
-
corrected = correct.correct("
|
| 230 |
gen_smi = corrected["SMILES"].tolist()
|
| 231 |
|
| 232 |
else:
|
| 233 |
-
gen_smi = pd.read_csv("
|
| 234 |
|
| 235 |
|
| 236 |
et = time.time() - start_time
|
| 237 |
|
| 238 |
-
with open("
|
| 239 |
for i in gen_smi:
|
| 240 |
f.write(i)
|
| 241 |
f.write("\n")
|
|
@@ -265,9 +265,9 @@ if __name__=="__main__":
|
|
| 265 |
|
| 266 |
# Data configuration.
|
| 267 |
parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
|
| 268 |
-
parser.add_argument('--inf_raw_file', type=str, default='
|
| 269 |
parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
|
| 270 |
-
parser.add_argument('--mol_data_dir', type=str, default='
|
| 271 |
parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')
|
| 272 |
|
| 273 |
# Model configuration.
|
|
|
|
| 114 |
|
| 115 |
def decoder_load(self, dictionary_name):
|
| 116 |
''' Loading the atom and bond decoders'''
|
| 117 |
+
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 118 |
return pickle.load(f)
|
| 119 |
|
| 120 |
|
|
|
|
| 140 |
self.restore_model(self.submodel, self.inference_model)
|
| 141 |
|
| 142 |
# smiles data for metrics calculation.
|
| 143 |
+
chembl_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 144 |
+
chembl_test = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
|
| 145 |
+
drug_smiles = [line for line in open("data/akt_inhibitors.smi", 'r').read().splitlines()]
|
| 146 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 147 |
drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
|
| 148 |
|
| 149 |
|
| 150 |
# Make directories if not exist.
|
| 151 |
+
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
| 152 |
+
os.makedirs("experiments/inference/{}".format(self.submodel))
|
| 153 |
if self.correct:
|
| 154 |
correct = smi_correct(self.submodel, "DrugGEN_/experiments/inference/{}".format(self.submodel))
|
| 155 |
search_res = pd.DataFrame(columns=["submodel", "validity",
|
|
|
|
| 166 |
uniqueness_calc = []
|
| 167 |
real_smiles_snn = []
|
| 168 |
nodes_sample = torch.Tensor(size=[1,45,1]).to(self.device)
|
| 169 |
+
f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
|
| 170 |
f.write("SMILES")
|
| 171 |
f.write("\n")
|
| 172 |
val_counter = 0
|
|
|
|
| 226 |
f.close()
|
| 227 |
print("Inference completed, starting metrics calculation.")
|
| 228 |
if self.correct:
|
| 229 |
+
corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
|
| 230 |
gen_smi = corrected["SMILES"].tolist()
|
| 231 |
|
| 232 |
else:
|
| 233 |
+
gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
|
| 234 |
|
| 235 |
|
| 236 |
et = time.time() - start_time
|
| 237 |
|
| 238 |
+
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w") as f:
|
| 239 |
for i in gen_smi:
|
| 240 |
f.write(i)
|
| 241 |
f.write("\n")
|
|
|
|
| 265 |
|
| 266 |
# Data configuration.
|
| 267 |
parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
|
| 268 |
+
parser.add_argument('--inf_raw_file', type=str, default='data/chembl_test.smi')
|
| 269 |
parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
|
| 270 |
+
parser.add_argument('--mol_data_dir', type=str, default='data')
|
| 271 |
parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')
|
| 272 |
|
| 273 |
# Model configuration.
|