Commit 7ae793bd authored by Carlos GO

drawing option

parent 70cdc345
@@ -39,7 +39,8 @@ def get_valids(lig_dict, max_dist, min_conc, min_size=4):
unique_ligs = set()
for pdb, ligands in lig_dict.items():
for lig_id,lig_cuts in ligands:
lig_name = lig_id.split(":")[1]
# lig_name = lig_id.split(":")[1]
lig_name = lig_id.split(":")[2]
#go over each distance cutoff
for c in lig_cuts:
tot = c['rna'] + c['protein']
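For reference, a minimal sketch of the validity test this loop presumably applies; the acceptance rule and the lig_id layout below are assumptions inferred from the parameters, and only the use of split(":")[2] as the ligand name comes from this commit:

# Hedged sketch: assumed acceptance rule for one distance cutoff `c`,
# where `c` is a dict of residue counts such as {'rna': 7, 'protein': 2}.
def is_valid_cutoff(c, min_conc=0.6, min_size=4):
    tot = c['rna'] + c['protein']
    if tot < min_size:                   # pocket too small at this radius (assumed)
        return False
    return c['rna'] / tot >= min_conc    # enough of the neighbourhood is RNA (assumed)

# lig_id fields are colon-separated; index 2 is the ligand name per this commit.
lig_id = "1aju:A:ARG:47"                 # hypothetical ID, layout assumed
lig_name = lig_id.split(":")[2]          # -> "ARG"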
@@ -73,9 +74,9 @@ def ligs_to_txt(d, dest="../data/ligs.txt"):
o.write(" ".join([pdb, *ligs]) + "\n")
pass
if __name__ == "__main__":
d = pickle.load(open('../data/lig_dict.p', 'rb'))
d = pickle.load(open('../data/lig_dict_ismb.p', 'rb'))
c = 10
conc = .6
ligs = get_valids(d, c, conc, min_size=5)
# pickle.dump(ligs, open("../data/lig_dict_r10_d06.p", "wb"))
ligs = get_valids(d, c, conc, min_size=4)
pickle.dump(ligs, open("../data/lig_dict_ismb_rna06_rad10.p", "wb"))
# ligs_to_txt(ligs)
@@ -122,11 +122,12 @@ def get_pocket_graph(pdb_structure_path, ligand_id, graph,
assert labels.issubset(valid_edges)
print(pocket)
rna_draw(G, title="BINDING")
# rna_draw(G, title="BINDING")
if len(G.nodes()) < 4:
return None
# if dump_path and (len(G.nodes()) > 4):
# nx.write_gpickle(G, os.path.join(dump_path, f"{pdbid}_{ligand_id}_BIND.nx"))
if dump_path:
nx.write_gpickle(G, os.path.join(dump_path, f"{pdbid}_{ligand_id}_BIND.nx"))
#sample and build non-binding graph.
if non_binding:
@@ -164,6 +165,7 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
lig_dict = pickle.load(open(lig_dict_path, 'rb'))
print(f">>> building graphs for {len(lig_dict)} PDBs")
print(f">>> dumping in {dump_path}")
print(f">>> and {sum(map(len, lig_dict.values()))} binding sites.")
failed = 0
@@ -177,12 +179,15 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
print(f">>> skipping {done_pdbs}")
failed = []
empties = 0
num_found = 0
missing_graphs = []
for pdbid, ligs in tqdm(lig_dict.items()):
pdbid = pdbid.split(".")[0]
pdb_path = f"../data/all_rna_prot_lig_2019/{pdbid}.cif"
# if pdbid in done_pdbs:
# continue
# pdb_path = f"../data/all_rna_prot_lig_2019/{pdbid}.cif"
pdb_path = f"../../carlos_docking/data/all_rna_with_lig_2019/{pdbid}.cif"
if pdbid in done_pdbs:
continue
# try:
print(">>> ", pdbid)
try:
@@ -197,19 +202,26 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
for lig in ligs:
#dump binding site graphs
try:
get_pocket_graph(pdb_path, lig,
g = get_pocket_graph(pdb_path, lig,
pdb_graph, dump_path=dump_path,
non_binding=non_binding)
if g is None:
empties += 1
else:
num_found += 1
print(f">>> pockets so far {num_found}")
except FileNotFoundError:
print(f"{pdbid} not found")
failed.append(pdbid)
print(f">>> missing graphs for {missing_graphs}")
print(failed)
print(f">>> failed on {len(failed)} graphs")
print(f">>> got {empties} empty graphs")
if __name__ == "__main__":
# take all ligands within a 10 angstrom sphere and 0.6 RNA concentration, build a graph for each.
# get_binding_site_graphs_all('../data/lig_dict_c_8A_06rna.p','../data/pockets_nx_pfind',
# non_binding=True)
get_binding_site_graphs_all('../data/lig_dict_c_10A_08rna.p', '../data/pockets_nx_large', non_binding=False)
get_binding_site_graphs_all('../data/lig_dict_ismb_rna06_rad10.p', '../data/pockets_nx_ismb', non_binding=False)
pass
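The pocket graphs written above can be read back with the matching networkx call; a minimal sketch, assuming every *_BIND.nx file in the dump directory is a binding-site graph:

# Hedged sketch: load the *_BIND.nx pocket graphs written by get_pocket_graph.
import os
import networkx as nx

pocket_dir = "../data/pockets_nx_ismb"   # dump_path used in __main__ above
pockets = {}
for fname in os.listdir(pocket_dir):
    if fname.endswith("_BIND.nx"):
        pockets[fname] = nx.read_gpickle(os.path.join(pocket_dir, fname))
print(f"loaded {len(pockets)} pocket graphs")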
@@ -13,14 +13,18 @@ from sklearn.cluster import AgglomerativeClustering
def ligands_cluster(bs_dict, fp_dict, n_clusters=8):
"""
Assign cluster labels to each ligand found in the binding-site dictionary.
Return a new fingerprint dictionary {'lig_id': cluster_id}.
"""
#get which ligands to use in clustering
binding_sites = pickle.load(open(bs_dict, 'rb'))
fingerprints = pickle.load(open(fp_dict, 'rb'))
ligs_2_cluster = []
for _,ligs in binding_sites.items():
ligs_2_cluster.extend([f.split(":")[2] for f in ligs])
ligs_2_cluster = list(set(ligs_2_cluster))
pocket_ids = [f.split(":")[2] for f in ligs]
ligs_2_cluster.extend(pocket_ids)
# ligs_2_cluster_unique = list(set(ligs_2_cluster))
fps = []
for l in ligs_2_cluster:
@@ -33,10 +37,13 @@ def ligands_cluster(bs_dict, fp_dict, n_clusters=8):
clusterer = AgglomerativeClustering(n_clusters=n_clusters)
clusterer.fit(fps)
labels = clusterer.labels_
clustered_fp_dict = dict(zip(ligs_2_cluster, labels))
sns.distplot(labels)
plt.show()
return clustered_fp_dict
pass
if __name__ == "__main__":
ligands_cluster("../data/lig_dict_c_8A_06rna.p", "../data/all_ligs_maccs.p")
clustered_fp_dict = ligands_cluster("../data/lig_dict_c_8A_06rna.p", "../data/all_ligs_maccs.p")
pickle.dump(clustered_fp_dict, open("../data/fp_dict_8clusters.p", 'wb'))
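A self-contained toy version of the clustering step, with random bit vectors standing in for the MACCS fingerprints (the ligand names and n_clusters below are placeholders):

# Hedged sketch: cluster toy binary fingerprints the same way ligands_cluster does.
import numpy as np
from sklearn.cluster import AgglomerativeClustering

rng = np.random.default_rng(0)
lig_names = [f"LIG{i}" for i in range(20)]        # hypothetical ligand names
fps = rng.integers(0, 2, size=(20, 166))          # MACCS keys are 166 bits
labels = AgglomerativeClustering(n_clusters=3).fit(fps).labels_
clustered_fp_dict = dict(zip(lig_names, labels))  # {'lig_id': cluster_id}
print(clustered_fp_dict)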
@@ -61,7 +61,7 @@ def print_gradients(model):
name, p = param
print(name, p.grad)
pass
def test(model, test_loader, device, decoys=None):
def test(model, test_loader, device, decoys=None, fp_draw=False):
"""
Compute accuracy and loss of model over given dataset
:param model:
@@ -85,27 +85,48 @@ def test(model, test_loader, device, decoys=None):
with torch.no_grad():
fp_pred, embeddings = model(graph)
loss = model.compute_loss(fp, fp_pred)
kws = {'cbar': False,
'square':False,
'vmin': 0,
'vmax': 1}
jaccards = []
enrichments = []
for i,f in zip(idx, fp_pred):
true_lig = all_graphs[i.item()].split(":")[2]
rank,sim = decoy_test(f, true_lig, decoys)
enrichments.append(rank)
decoy_ranks = np.mean(enrichments)
jaccards.append(sim)
mean_ranks = np.mean(enrichments)
mean_jaccard = np.mean(jaccards)
del K
del fp
del graph
test_loss += loss.item()
del loss
if fp_draw:
fig, (ax1, ax2, ax3) = plt.subplots(1,3)
sns.heatmap(fp, ax=ax1, **kws)
bina = fp_pred > 0.5
fp_true = fp.clone().detach()
fp_true = fp_true.int()
bina = bina.int()
sns.heatmap(bina, ax=ax2, **kws)
sns.heatmap(fp_true != bina, ax=ax3, **kws)
ax1.set_title("True")
ax2.set_title("Pred")
ax3.set_title("Diff")
plt.show()
del fp
return test_loss / test_size, decoy_ranks
return test_loss / test_size, mean_ranks, mean_jaccard
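decoy_test is defined elsewhere in the repo; the sketch below is one plausible reading consistent with how rank and sim are averaged here, not the actual implementation (the fingerprint lookup and the normalised-rank convention are assumptions):

# Hedged sketch of a decoy_test-style ranking.
import numpy as np

def jaccard_sim(a, b):
    a = np.asarray(a) > 0.5
    b = np.asarray(b) > 0.5
    union = np.logical_or(a, b).sum()
    return np.logical_and(a, b).sum() / union if union else 0.0

def decoy_rank(fp_pred, true_lig, decoys, lig_fps):
    # decoys[true_lig] is assumed to list decoy ligand names and lig_fps to map
    # ligand name -> binary fingerprint (both hypothetical containers).
    cands = [true_lig] + list(decoys[true_lig])
    sims = [jaccard_sim(fp_pred, lig_fps[c]) for c in cands]
    order = np.argsort(sims)[::-1]                       # best-scoring first
    rank = 1.0 - np.where(order == 0)[0][0] / (len(cands) - 1)
    return rank, sims[0]                                 # (enrichment-style rank, sim to true ligand)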
def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
writer=None, num_epochs=25, wall_time=None,
reconstruction_lam=1, motif_lam=1, embed_only=-1,
decoys=None):
reconstruction_lam=1, fp_lam=1, embed_only=-1,
decoys=None, early_stop_threshold=10, fp_draw=False):
"""
Performs the entire training routine.
:param model: (torch.nn.Module): the model to train
@@ -119,22 +140,22 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
:param num_epochs: int number of epochs
:param wall_time: The number of hours you want the model to run
:param reconstruction_lam: how much to enforce pairwise similarity conservation
:param motif_lam: how much to enforce motif assignment
:param fp_lam: how much to weight the fingerprint prediction loss
:param embed_only: number of epochs before starting attributor training.
:return:
"""
edge_map = train_loader.dataset.dataset.edge_map
all_graphs = train_loader.dataset.dataset.all_graphs
decoys = get_decoys(mode='pdb', annots_dir=train_loader.dataset.dataset.path)
epochs_from_best = 0
early_stop_threshold = 10
start_time = time.time()
best_loss = sys.maxsize
motif_lam_orig = motif_lam
fp_lam_orig = fp_lam
reconstruction_lam_orig = reconstruction_lam
#if we delay attributor, start with attributor OFF
@@ -142,7 +163,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
if embed_only > -1:
print("Switching attriutor OFF. Embeddings still ON.")
set_gradients(model, attributor=False)
motif_lam = 0
fp_lam = 0
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch + 1, num_epochs))
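set_gradients comes from elsewhere in the repo; a plausible implementation of this kind of freeze/unfreeze toggle, assuming the model exposes embedder and attributor submodules (names hypothetical), would be:

# Hedged sketch only: toggle requires_grad on the two halves of the model.
def set_gradients_sketch(model, embedding=True, attributor=True):
    for p in model.embedder.parameters():      # submodule name assumed
        p.requires_grad = embedding
    for p in model.attributor.parameters():    # submodule name assumed
        p.requires_grad = attributor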
@@ -156,7 +177,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
print("Switching attributor ON, embeddings OFF.")
set_gradients(model, embedding=False, attributor=True)
reconstruction_lam = 0
motif_lam = motif_lam_orig
fp_lam = fp_lam_orig
running_loss = 0.0
@@ -164,6 +185,8 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
num_batches = len(train_loader)
train_enrichments = []
train_jaccards = []
for batch_idx, (graph, K, fp, idx) in enumerate(train_loader):
# Get data on the devices
@@ -173,7 +196,31 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
fp_pred, embeddings = model(graph)
if fp_draw:
fig, (ax1, ax2, ax3) = plt.subplots(1,3)
kws = {'cbar': False,
'square':False,
'vmin': 0,
'vmax': 1}
sns.heatmap(fp, ax=ax1, **kws)
bina = fp_pred > 0.5
fp_true = fp.clone().detach()
fp_true = fp_true.int()
bina = bina.int()
sns.heatmap(bina, ax=ax2, **kws)
sns.heatmap(fp_true != bina, ax=ax3, **kws)
ax1.set_title("True")
ax2.set_title("Pred")
ax3.set_title("Diff")
plt.show()
loss = model.compute_loss(fp, fp_pred)
for i,f in zip(idx, fp_pred):
true_lig = all_graphs[i.item()].split(":")[2]
rank,sim = decoy_test(f, true_lig, decoys)
train_enrichments.append(rank)
train_jaccards.append(sim)
# l = model.rec_loss(embeddings, K, similarity=False)
# print(l)
@@ -188,7 +235,6 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
batch_loss = loss.item()
running_loss += batch_loss
# running_corrects += labels.eq(target.view_as(out)).sum().item()
if batch_idx % 20 == 0:
time_elapsed = time.time() - start_time
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Time: {:.2f}'.format(
@@ -208,14 +254,17 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# Log training metrics
train_loss = running_loss / num_batches
writer.add_scalar("Training epoch loss", train_loss, epoch)
print(">> train enrichments", np.mean(train_enrichments))
print(">> train jaccards", np.mean(train_jaccards))
# train_accuracy = running_corrects / num_batches
# writer.log_scalar("Train accuracy during training", train_accuracy, epoch)
# Test phase
test_loss, enrichments = test(model, test_loader, device, decoys=decoys)
test_loss, enrichments, jaccards = test(model, test_loader, device, decoys=decoys)
print(">> test loss ", test_loss)
print(">> test enrichments", enrichments)
print(">> test jaccards ", jaccards)
writer.add_scalar("Test loss during training", test_loss, epoch)
@@ -79,8 +79,8 @@ class V1(Dataset):
one_hot_nucs = {node: torch.tensor(self.nuc_map[label], dtype=torch.float32) for node, label in
(nx.get_node_attributes(graph, 'nt')).items()}
else:
one_hot_nucs = {node: torch.tensor(0, dtype=torch.float32) for node, label in
(nx.get_node_attributes(graph, 'nt')).items()}
one_hot_nucs = {node: torch.tensor(0, dtype=torch.float32) for node in
graph.nodes()}
nx.set_node_attributes(graph, name='one_hot', values=one_hot_nucs)
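The point of this change: get_node_attributes only yields nodes that actually carry the 'nt' key, so nodes without it received no 'one_hot' feature at all, while iterating graph.nodes() covers every node. A toy graph makes the difference visible:

# Hedged illustration on a toy graph.
import networkx as nx
import torch

g = nx.Graph()
g.add_nodes_from([1, 2, 3])
nx.set_node_attributes(g, {1: "A", 2: "U"}, name="nt")   # node 3 has no 'nt'

old = {n: torch.tensor(0.) for n, _ in nx.get_node_attributes(g, "nt").items()}
new = {n: torch.tensor(0.) for n in g.nodes()}
print(sorted(old))   # [1, 2]      node 3 silently dropped
print(sorted(new))   # [1, 2, 3]   every node gets a feature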
@@ -212,9 +212,9 @@ class Loader():
collate_block = collate_wrapper(self.dataset.node_sim_func)
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=self.batch_size,
train_loader = DataLoader(dataset=train_set, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
test_loader = DataLoader(dataset=test_set, shuffle=True, batch_size=self.batch_size,
test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
# return train_loader, valid_loader, test_loader
@@ -16,9 +16,8 @@ parser.add_argument("-da", "--annotated_data", default='pockets_nx_symmetric')
parser.add_argument("-bs", "--batch_size", type=int, default=8, help="choose the batch size")
parser.add_argument("-nw", "--workers", type=int, default=20, help="Number of workers to load data")
parser.add_argument("-n", "--name", type=str, default='default_name', help="Name for the logs")
parser.add_argument("-t", "--timed", help="to use timed learn", action='store_true')
parser.add_argument("-ep", "--num_epochs", type=int, help="number of epochs to train", default=3)
parser.add_argument("-ml", "--motif_lam", type=float, help="motif lambda", default=1.0)
parser.add_argument("-fl", "--fp_lam", type=float, help="fingerprint lambda", default=1.0)
parser.add_argument("-rl", "--reconstruction_lam", type=float, help="reconstruction lambda", default=1.0)
parser.add_argument('-ad','--attributor_dims', nargs='+', type=int, help='Dimensions for attributor.', default=[16,166])
parser.add_argument('-ed','--embedding_dims', nargs='+', type=int, help='Dimensions for embeddings.', default=[16]*3)
@@ -30,10 +29,12 @@ parser.add_argument('-po', '--pool', type=str, default='sum', help='Pooling func
parser.add_argument("-nu", "--nucs", default=True, help="Use nucleotide IDs for learn", action='store_false')
parser.add_argument('-rs', '--seed', type=int, default=0, help='Random seed to use (if > 0, else no seed is set).')
parser.add_argument('-kf', '--kfold', type=int, default=0, help='Do k-fold crossval and do decoys on each fold..')
parser.add_argument('-es', '--early_stop', type=int, default=10, help='Early stop epoch threshold (default=10)')
args = parser.parse_args()
print(f"OPTIONS USED: {args}")
print("OPTIONS USED")
print("\n".join(map(str, zip(vars(args).items()))))
# Torch imports
import torch
import torch.optim as optim
@@ -110,7 +111,7 @@ else:
if dims[-1] != attributor_dims[0] - dim_add:
raise ValueError(f"Final embedding size must match first attributor dimension: {dims[-1]} != {attributor_dims[0]}")
motif_lam = args.motif_lam
fp_lam = args.fp_lam
reconstruction_lam = args.reconstruction_lam
data = loader.get_data(k_fold=args.kfold)
@@ -183,5 +184,6 @@ for k, (train_loader, test_loader) in enumerate(data):
writer=writer,
num_epochs=num_epochs,
reconstruction_lam=reconstruction_lam,
motif_lam=motif_lam,
embed_only=args.embed_only)
fp_lam=fp_lam,
embed_only=args.embed_only,
early_stop_threshold=args.early_stop)
@@ -12,6 +12,34 @@ from dgl.nn.pytorch.glob import SumPooling,GlobalAttentionPooling
from dgl import mean_nodes
from dgl.nn.pytorch.conv import RelGraphConv
class JaccardDistanceLoss(torch.nn.Module):
def __init__(self, smooth=100, dim=1, size_average=True, reduce=True):
"""
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
= sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|))
The Jaccard distance loss is useful for unbalanced datasets. It has been
shifted so it converges on 0 and is smoothed to avoid exploding or disappearing
gradients.
Ref: https://en.wikipedia.org/wiki/Jaccard_index
@url: https://gist.github.com/wassname/d1551adac83931133f6a84c5095ea101
@author: wassname
"""
super(JaccardDistanceLoss, self).__init__()
self.smooth = smooth
self.dim = dim
self.size_average = size_average
self.reduce = reduce
def forward(self, y_true, y_pred):
intersection = (y_true * y_pred).abs().sum(self.dim)
sum_ = (y_true.abs() + y_pred.abs()).sum(self.dim)
jac = (intersection + self.smooth) / (sum_ - intersection + self.smooth)
losses = (1 - jac) * self.smooth
if self.reduce:
return losses.mean() if self.size_average else losses.sum()
else:
return losses
class Attributor(nn.Module):
"""
@@ -206,6 +234,7 @@ class Model(nn.Module):
# pw = torch.tensor([self.pos_weight], dtype=torch.float, requires_grad=False).to(self.device)
# loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(pred_fp, target_fp)
loss = torch.nn.BCELoss()(pred_fp, target_fp)
# loss = JaccardDistanceLoss()(pred_fp, target_fp)
return loss
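A quick sanity check of the JaccardDistanceLoss added above, should it replace the BCELoss here: identical fingerprints give 0, disjoint ones give a small positive value because of the smoothing term.

# Hedged usage sketch for the new loss on toy fingerprints.
import torch

loss_fn = JaccardDistanceLoss(smooth=100, dim=1)
same = torch.tensor([[1., 1., 0., 0.]])
disjoint = torch.tensor([[0., 0., 1., 1.]])
print(loss_fn(same, same).item())       # 0.0, perfect overlap
print(loss_fn(same, disjoint).item())   # ~3.85, no overlap, smoothed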