Commit f83b8d3b authored by Carlos GO's avatar Carlos GO
Browse files

decoy utils

parent 12ce2a5a
......@@ -2,9 +2,11 @@
Swap the fingerprints around from an annotated folder.
"""
import os, pickle
import random
import numpy as np
def scramble_fingerprints(anot_dir, dump_dir):
def scramble_fingerprints(annot_dir, dump_dir):
"""
Assign a fingerprint to each graph chosen at random from all other
fingerprints in annot_dir.
......@@ -13,9 +15,32 @@ def scramble_fingerprints(anot_dir, dump_dir):
datas = os.listdir(annot_dir)
indices = range(len(datas))
fps = []
for i,g in datas:
g,tree,ring,fp = pickle.load(open(os.path.join(annot_dir, g)))
for i,g in enumerate(datas):
_,_,_,fp = pickle.load(open(os.path.join(annot_dir, g), 'rb'))
fps.append(fp)
for i,g in datas:
pickle.dump((g,tree,ring,fps[random.choice(indices)]), os.path.join(dump_dir, g))
for i,g in enumerate(datas):
G,tree,ring,fp = pickle.load(open(os.path.join(annot_dir, g), 'rb'))
new_fp = fps[random.choice(indices)]
print(f"old fp {fp}")
print(f"new fp {new_fp}")
pickle.dump((G,tree,ring,new_fp), open(os.path.join(dump_dir, g), 'wb'))
pass
def random_fingerprints(annot_dir, dump_dir):
    """
    Re-dump each annotated graph with a uniformly random binary fingerprint.

    Serves as a "random" baseline: the graph, tree and ring data are preserved
    and only the 166-bit (MACCS-sized) fingerprint is replaced.

    :param annot_dir: directory of pickled ``(graph, tree, ring, fp)`` tuples.
    :param dump_dir: existing directory where the modified pickles are written
        under the same filenames.
    """
    for g in os.listdir(annot_dir):
        # read the annotation, discarding the original fingerprint
        with open(os.path.join(annot_dir, g), 'rb') as f:
            G, tree, ring, _ = pickle.load(f)
        # 166 bits matches the MACCS key length used elsewhere in the project
        rand_fp = np.random.randint(2, size=166)
        with open(os.path.join(dump_dir, g), 'wb') as f:
            pickle.dump((G, tree, ring, rand_fp), f)
if __name__ == "__main__":
    # Build the fingerprint-swapped ("scramble") baseline dataset from the
    # symmetric annotated pockets; the commented line below builds the
    # fully-random baseline instead. Paths are relative to this script's dir.
    scramble_fingerprints('../data/annotated/pockets_nx_symmetric', '../data/annotated/pockets_nx_symmetric_scramble')
    # random_fingerprints('../data/annotated/pockets_nx_symmetric', '../data/annotated/pockets_nx_symmetric_random')
    pass
"""
Functions for getting decoys.
"""
import os
import pickle
import random
import numpy as np
from scipy.spatial.distance import jaccard
def get_decoys(mode='pdb', annots_dir='../data/annotated/pockets_nx'):
    """
    Build decoy set for validation.

    :param mode: ``'pdb-whole'`` loads a precomputed ligand->fingerprint dict
        from disk; ``'pdb'`` builds the decoy dict from annotated graph files
        in ``annots_dir``.
    :param annots_dir: directory of pickled ``(graph, tree, ring, fp)`` tuples
        whose filenames encode the ligand id as the third ':'-separated field.
    :return: for ``'pdb-whole'``, ``{lig_id: fp}``; for ``'pdb'``,
        ``{lig_id: (fp, [fps of all other ligands])}``; ``None`` otherwise.
    """
    if mode == 'pdb-whole':
        # NOTE(review): path is relative to the working directory — confirm callers.
        with open('data/all_ligs_pdb_maccs.p', 'rb') as f:
            return pickle.load(f)
    if mode == 'pdb':
        fp_dict = {}
        for g in os.listdir(annots_dir):
            # ligand id is the third ':' field of the filename; skip malformed names
            try:
                lig_id = g.split(":")[2]
            except IndexError as e:
                print(f"failed on {g}, {e}")
                continue
            with open(os.path.join(annots_dir, g), 'rb') as f:
                _, _, _, fp = pickle.load(f)
            fp_dict[lig_id] = fp
        # fp_dict = {'lig_id': fp, ...}
        # a ligand's decoys are the fingerprints of all *other* ligands
        decoy_dict = {k: (v, [f for lig, f in fp_dict.items() if lig != k])
                      for k, v in fp_dict.items()}
        return decoy_dict
    return None
def distance_rank(active, pred, decoys, dist_func=jaccard):
    """
    Normalized rank of `pred` among `decoys`, measured against `active`.

    Counts how many decoys sit strictly closer to the active ligand than the
    prediction does, then maps that count to [0, 1], where 1.0 means the
    prediction beat every decoy.
    """
    target = dist_func(active, pred)
    beaten_by = sum(1 for dec in decoys if dist_func(active, dec) < target)
    return 1 - beaten_by / (len(decoys) + 1)
def decoy_test(fp_pred, true_id, decoys,
               shuffle=False):
    """
    Check performance against decoy set.

    decoys --> {'ligand_id': ('expected_FP', [decoy_fps])}
    :fp_pred predicted fingerprint; assumed to be a torch tensor (it is
        `.detach().numpy()`-ed below) — TODO confirm at call sites.
    :true_id ligand id of the known active; must be a key of `decoys`.
    :decoys dictionary with (active fp, decoy fps) for each ligand.
    :shuffle if True, replace the true ligand with one picked at random
        (used as a shuffled-pairs baseline).
    :return: (rank, sim) — normalized rank of the prediction among the decoys
        and its Jaccard distance to the active fingerprint.
    """
    # NOTE(review): ranks/sims are never used in this function — dead locals.
    ranks = []
    sims = []
    ligs = list(decoys.keys())
    # only warns on a missing id; the decoys[true_id] lookup below will still
    # raise KeyError if true_id is absent (unless shuffle replaces it first)
    try:
        decoys[true_id]
    except KeyError:
        print("missing fp", true_id)
    # binarize the predicted fingerprint at 0.5
    fp_pred = fp_pred.detach().numpy() > 0.5
    fp_pred = fp_pred.astype(int)
    if shuffle:
        #pick a random ligand to be the true one
        # NOTE(review): orig is assigned but never used
        orig = true_id
        true_id = np.random.choice(ligs, replace=False)
    active = decoys[true_id][0]
    decs = decoys[true_id][1]
    rank = distance_rank(active, fp_pred, decs)
    sim = jaccard(active, fp_pred)
    return rank, sim
def decoy_test_(fp_pred, true_fp, decoys,
                shuffle=False):
    """
    Score a predicted fingerprint against a known fingerprint and decoy list.

    :fp_pred predicted fingerprint (torch tensor; binarized at 0.5 here).
    :true_fp ground-truth binary fingerprint.
    :decoys list of decoy fingerprints.
    :shuffle unused; kept for interface compatibility.
    :return: (rank, sim) — enrichment-style rank among decoys and Jaccard
        distance to the true fingerprint.
    """
    binarized = (fp_pred.detach().numpy() > 0.5).astype(int)
    print(binarized, true_fp)
    rank = distance_rank(true_fp, binarized, decoys)
    return rank, jaccard(true_fp, binarized)
......@@ -2,7 +2,9 @@ import time
import torch
import torch.nn.functional as F
import sys
import dgl
import networkx as nx
#debug modules
import numpy as np
......@@ -150,8 +152,6 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
fp_lam_orig = fp_lam
reconstruction_lam_orig = reconstruction_lam
dec_mode = 'pdb-whole'
batch_size = train_loader.batch_size
#if we delay attributor, start with attributor OFF
#if <= -1, both always ON.
......@@ -190,6 +190,9 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
fp = fp.long()
fp = fp.to(device)
# for f, i in zip(fp, idx):
# print(all_graphs[i.item()], f)
fp_pred, embeddings = model(graph)
if fp_draw:
......
......@@ -36,6 +36,7 @@ class V1(Dataset):
nucs (bool): whether to include nucleotide ID in node (default=False).
depth (int): number of hops to use in the node kernel.
"""
print(f">>> fetching data from {annotated_path}")
self.path = annotated_path
self.all_graphs = sorted(os.listdir(annotated_path))
if seed:
......@@ -139,7 +140,7 @@ def collate_wrapper(node_sim_func, get_sim_mat=True):
idx = np.array(idx)
batched_graph = dgl.batch(graphs)
K = k_block_list(rings, node_sim_func)
return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float(), torch.from_numpy(idx)
return batched_graph, torch.from_numpy(K).float(), torch.from_numpy(fp).float(), torch.from_numpy(idx)
else:
def collate_block(samples):
# The input `samples` is a list of pairs
......@@ -188,6 +189,7 @@ class Loader():
def get_data(self, k_fold=0):
n = len(self.dataset)
indices = list(range(n))
collate_block = collate_wrapper(self.dataset.node_sim_func)
if k_fold > 1:
from sklearn.model_selection import KFold
......@@ -196,11 +198,9 @@ class Loader():
train_set = Subset(self.dataset, train_indices)
test_set = Subset(self.dataset, test_indices)
collate_block = collate_wrapper(self.dataset.node_sim_func)
train_loader = DataLoader(dataset=train_set, batch_size=self.batch_size,
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
test_loader = DataLoader(dataset=test_set, shuffle=True, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
yield train_loader, test_loader
......@@ -219,11 +219,9 @@ class Loader():
print("training graphs ", len(train_set))
print("testing graphs ", len(test_set))
collate_block = collate_wrapper(self.dataset.node_sim_func)
train_loader = DataLoader(dataset=train_set, batch_size=self.batch_size,
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
test_loader = DataLoader(dataset=test_set, shuffle=True, batch_size=self.batch_size,
num_workers=self.num_workers, collate_fn=collate_block)
# return train_loader, valid_loader, test_loader
......
......@@ -130,10 +130,10 @@ for k, (train_loader, test_loader) in enumerate(data):
print("warm starting")
m = torch.load(args.warm_start, map_location='cpu')['model_state_dict']
#remove keys not related to embeddings
for k in list(m.keys()):
if 'embedder' not in k:
print("killing ", k)
del m[k]
for key in list(m.keys()):
if 'embedder' not in key:
print("killing ", key)
del m[key]
missing = model.load_state_dict(m, strict=False)
print(missing)
......@@ -153,6 +153,7 @@ for k, (train_loader, test_loader) in enumerate(data):
'''
name = f"{args.name}_{k}"
print(name)
result_folder, save_path = mkdirs(name)
print(save_path)
writer = SummaryWriter(result_folder)
......
......@@ -12,35 +12,6 @@ from dgl.nn.pytorch.glob import SumPooling,GlobalAttentionPooling
from dgl import mean_nodes
from dgl.nn.pytorch.conv import RelGraphConv
class JaccardDistanceLoss(torch.nn.Module):
    """
    Smoothed, shifted Jaccard distance loss.

    Jaccard = |X & Y| / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    Useful for unbalanced targets; the smoothing term keeps gradients from
    exploding or vanishing, and the loss converges to 0 on a perfect match.
    Ref: https://en.wikipedia.org/wiki/Jaccard_index
    @url: https://gist.github.com/wassname/d1551adac83931133f6a84c5095ea101
    @author: wassname
    """

    def __init__(self, smooth=100, dim=1, size_average=True, reduce=True):
        super(JaccardDistanceLoss, self).__init__()
        self.smooth = smooth
        self.dim = dim
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, y_true, y_pred):
        # per-sample intersection and magnitude sums along the feature dim
        overlap = (y_true * y_pred).abs().sum(self.dim)
        magnitudes = (y_true.abs() + y_pred.abs()).sum(self.dim)
        smoothed_jac = (overlap + self.smooth) / (magnitudes - overlap + self.smooth)
        per_sample = (1 - smoothed_jac) * self.smooth
        if not self.reduce:
            return per_sample
        return per_sample.mean() if self.size_average else per_sample.sum()
class Attributor(nn.Module):
"""
NN which makes a prediction (fp or binding/non binding) from a pooled
......@@ -241,8 +212,8 @@ class Model(nn.Module):
if self.clustered:
loss = torch.nn.CrossEntropyLoss()(pred_fp, target_fp)
else:
# loss = torch.nn.MSELoss()(pred_fp, target_fp)
loss = torch.nn.BCELoss()(pred_fp, target_fp)
# loss = JaccardDistanceLoss()(pred_fp, target_fp)
return loss
......
......@@ -32,6 +32,10 @@ from learning.utils import dgl_to_nx
from tools.learning_utils import load_model
from post.drawing import rna_draw
def mse(x, y):
    """Squared-error distance between x and y, normalized by len(x)."""
    err = x - y
    return np.sum(err * err) / len(x)
def get_decoys(mode='pdb', annots_dir='../data/annotated/pockets_nx_2'):
"""
Build decoys set for validation.
......@@ -103,15 +107,14 @@ def decoy_test(model, decoys, edge_map, embed_dim,
fp_pred, _ = model(dgl_graph)
fp_pred = fp_pred.detach().numpy() > 0.5
# print(fp_pred)
# fp_pred = np.random.choice([0, 1], size=(166,), p=[1./2, 1./2])
# fp_pred = fp_pred.detach().numpy()
if shuffle:
orig = true_id
true_id = np.random.choice(ligs, replace=False)
# true_id = np.random.choice(ligs, replace=False)
fp_pred = np.random.rand(166)
active = decoys[true_id][0]
decs = decoys[true_id][1]
rank = distance_rank(active, fp_pred, decs)
sim = jaccard(active, fp_pred)
sim = mse(active, fp_pred)
ranks.append(rank)
sims.append(sim)
return ranks, sims
......@@ -144,18 +147,28 @@ def make_violins(df, x='method', y='rank', save=None, show=True):
def ablation_results():
# modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
# modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
modes = ['raw', 'bb', 'wc-bb', 'swap', 'random']
decoys = get_decoys(mode='pdb')
ranks, methods, jaccards = [], [], []
graph_dir = '../data/annotated/pockets_nx_symmetric'
# graph_dir = '../data/annotated/pockets_nx_2'
run = 'ismb'
# run = 'teste'
# run = 'random'
num_folds = 10
for m in modes:
if m in ['raw', 'pair-shuffle']:
graph_dir = "../data/annotated/pockets_nx_symmetric"
run = 'ismb-raw'
# run = 'teste'
elif m == 'swap':
graph_dir = '../data/annotated/pockets_nx_symmetric_scramble'
run = 'ismb-' + m
elif m == 'random':
graph_dir = '../data/annotated/pockets_nx_symmetric_random'
run = 'random'
else:
graph_dir = "../data/annotated/pockets_nx_symmetric_" + m
run = 'ismb-' + m
......@@ -163,11 +176,13 @@ def ablation_results():
for fold in range(num_folds):
model, meta = load_model(run +"_" + str(fold))
# model, meta = load_model(run)
edge_map = meta['edge_map']
embed_dim = meta['embedding_dims'][-1]
num_edge_types = len(edge_map)
graph_ids = pickle.load(open(f'../results/trained_models/{run}_{fold}/splits_{fold}.p', 'rb'))
# graph_ids = pickle.load(open(f'../results/trained_models/{run}/splits.p', 'rb'))
shuffle = False
if m == 'pair-shuffle':
......@@ -182,20 +197,20 @@ def ablation_results():
jaccards.extend(sims_this)
methods.extend([m]*len(ranks_this))
#decoy distance distribution
# dists = []
# for _,(active, decs) in decoys.items():
# for d in decs:
# dists.append(jaccard(active, d))
# decoy distance distribution
dists = []
for _,(active, decs) in decoys.items():
for d in decs:
dists.append(jaccard(active, d))
# plt.scatter(ranks_this, sims_this)
# plt.xlabel("ranks")
# plt.ylabel("distance")
# plt.show()
# sns.distplot(dists, label='decoy distance')
# sns.distplot(sims_this, label='pred distance')
# plt.xlabel("distance")
# plt.legend()
# plt.show()
sns.distplot(dists, label='decoy distance')
sns.distplot(sims_this, label='pred distance')
plt.xlabel("distance")
plt.legend()
plt.show()
# # rank_cut = 0.9
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.4 and r > 0.8]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment