Commit 464dc769 authored by Carlos GO
Browse files

reconstruction compute

parent f83b8d3b
......@@ -63,7 +63,7 @@ def print_gradients(model):
name, p = param
print(name, p.grad)
pass
def test(model, test_loader, device, fp_draw=False):
def test(model, test_loader, device, fp_lam=1, rec_lam=1):
"""
Compute accuracy and loss of model over given dataset
:param model:
......@@ -89,31 +89,15 @@ def test(model, test_loader, device, fp_draw=False):
# Do the computations for the forward pass
with torch.no_grad():
fp_pred, embeddings = model(graph)
loss = model.compute_loss(fp, fp_pred)
kws = {'cbar': False,
'square':False,
'vmin': 0,
'vmax': 1}
loss = model.compute_loss(fp, fp_pred, embeddings, K,
fp_lam=fp_lam,
rec_lam=rec_lam)
del K
del graph
test_loss += loss.item()
del loss
if fp_draw:
fig, (ax1, ax2, ax3) = plt.subplots(1,3)
sns.heatmap(fp, ax=ax1, **kws)
bina = fp_pred > 0.5
fp_true = fp.clone().detach()
fp_true = fp_true.int()
bina = bina.int()
sns.heatmap(bina, ax=ax2, **kws)
sns.heatmap(fp_true != bina, ax=ax3, **kws)
ax1.set_title("True")
ax2.set_title("Pred")
ax3.set_title("Diff")
plt.show()
del fp
......@@ -122,7 +106,7 @@ def test(model, test_loader, device, fp_draw=False):
def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
writer=None, num_epochs=25, wall_time=None,
reconstruction_lam=1, fp_lam=1, embed_only=-1,
early_stop_threshold=10, fp_draw=False):
early_stop_threshold=10, fp_draw=False):
"""
Performs the entire training routine.
:param model: (torch.nn.Module): the model to train
......@@ -195,26 +179,8 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
fp_pred, embeddings = model(graph)
if fp_draw:
fig, (ax1, ax2, ax3) = plt.subplots(1,3)
kws = {'cbar': False,
'square':False,
'vmin': 0,
'vmax': 1}
sns.heatmap(fp, ax=ax1, **kws)
bina = fp_pred > 0.5
fp_true = fp.clone().detach()
fp_true = fp_true.int()
bina = bina.int()
sns.heatmap(bina, ax=ax2, **kws)
sns.heatmap(fp_true != bina, ax=ax3, **kws)
ax1.set_title("True")
ax2.set_title("Pred")
ax3.set_title("Diff")
plt.show()
loss = model.compute_loss(fp, fp_pred)
loss = model.compute_loss(fp, fp_pred, embeddings, K,
fp_lam=fp_lam, rec_lam=reconstruction_lam)
# l = model.rec_loss(embeddings, K, similarity=False)
# print(l)
......@@ -253,7 +219,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# writer.log_scalar("Train accuracy during training", train_accuracy, epoch)
# Test phase
test_loss = test(model, test_loader, device)
test_loss = test(model, test_loader, device, fp_lam=fp_lam, rec_lam=reconstruction_lam)
print(">> test loss ", test_loss)
writer.add_scalar("Test loss during training", test_loss, epoch)
......
......@@ -169,7 +169,7 @@ class Model(nn.Module):
target_K = torch.ones(target_K.shape, device=target_K.device) - target_K
reconstruction_loss = torch.nn.MSELoss()(K_predict, target_K)
self.draw_rec(target_K, K_predict)
# self.draw_rec(target_K, K_predict)
return reconstruction_loss
# Below are loss computation function related to this model
@staticmethod
......@@ -194,7 +194,17 @@ class Model(nn.Module):
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
return sim_mt
def compute_loss(self, target_fp, pred_fp):
def fp_loss(self, target_fp, pred_fp):
    """
    Compute the fingerprint prediction loss.

    Cross-entropy is used when `self.clustered` is truthy, binary
    cross-entropy otherwise.
    :param target_fp: target handed to the torch criterion
    :param pred_fp: prediction handed to the torch criterion
    :return: the value returned by the selected torch criterion
    """
    # Select the criterion first, then apply it exactly once.
    if self.clustered:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCELoss()
    return criterion(pred_fp, target_fp)
def compute_loss(self, target_fp, pred_fp, embeddings, target_K,
rec_lam=1,
fp_lam=1,
similarity=False):
"""
Compute the total loss of the model.
Includes the reconstruction loss with optional similarity/distance boolean switch
......@@ -207,14 +217,8 @@ class Model(nn.Module):
:param scaled:
:return:
"""
# pw = torch.tensor([self.pos_weight], dtype=torch.float, requires_grad=False).to(self.device)
# loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(pred_fp, target_fp)
if self.clustered:
loss = torch.nn.CrossEntropyLoss()(pred_fp, target_fp)
else:
# loss = torch.nn.MSELoss()(pred_fp, target_fp)
loss = torch.nn.BCELoss()(pred_fp, target_fp)
loss = fp_lam * self.fp_loss(target_fp, pred_fp)\
+ rec_lam * self.rec_loss(embeddings, target_K, similarity=similarity)
return loss
def draw_rec(self, true_K, predicted_K, title=""):
......
......@@ -30,7 +30,7 @@ from post.utils import *
from learning.attn import get_attention_map
from learning.utils import dgl_to_nx
from tools.learning_utils import load_model
from post.drawing import rna_draw
# from post.drawing import rna_draw
def mse(x,y):
d = np.sum((x-y)**2) / len(x)
......@@ -49,30 +49,31 @@ def get_decoys(mode='pdb', annots_dir='../data/annotated/pockets_nx_2'):
print(f"failed on {g}")
_,_,_,fp = pickle.load(open(os.path.join(annots_dir, g), 'rb'))
fp_dict[lig_id] = fp
decoy_list = list(fp_dict.values())
decoy_dict = {k:(v, decoy_list) for k,v in fp_dict.items()}
decoy_dict = {k:(v, [f for lig,f in fp_dict.items() if lig != k]) for k,v in fp_dict.items()}
return decoy_dict
if mode == 'dude':
return pickle.load(open('../data/decoys_zinc.p', 'rb'))
pass
def distance_rank(active, pred, decoys, dist_func=jaccard):
def distance_rank(active, pred, decoys, dist_func=mse):
"""
Get rank of prediction in `decoys` given a known active ligand.
"""
pred_dist = dist_func(active, pred)
rank = 0
for lig in decoys:
d = dist_func(active, lig)
for decoy in decoys:
d = dist_func(pred, decoy)
#if find a decoy closer to prediction, worsen the rank.
if d < pred_dist:
rank += 1
return 1- (rank / (len(decoys) + 1))
return 1 - (rank / (len(decoys) + 1))
def decoy_test(model, decoys, edge_map, embed_dim,
test_graphlist=None,
shuffle=False,
nucs=False,
test_graph_path="../data/annotated/pockets_nx"):
test_graph_path="../data/annotated/pockets_nx",
majority=False):
"""
Check performance against decoy set.
decoys --> {'ligand_id', ('expected_FP', [decoy_fps])}
......@@ -91,34 +92,56 @@ def decoy_test(model, decoys, edge_map, embed_dim,
test_graphlist = os.listdir(test_graph_path)
ligs = list(decoys.keys())
if majority:
generic = generic_fp("../data/annotated/pockets_nx_symmetric_orig")
for g_path in test_graphlist:
g,_,_,_ = pickle.load(open(os.path.join(test_graph_path, g_path), 'rb'))
g,_,_,true_fp = pickle.load(open(os.path.join(test_graph_path, g_path), 'rb'))
try:
true_id = g_path.split(":")[2]
except:
print(f">> failed on {g_path}")
continue
try:
decoys[true_id]
except KeyError:
print("missing fp", true_id)
continue
nx_graph, dgl_graph = nx_to_dgl(g, edge_map, nucs=nucs)
fp_pred, _ = model(dgl_graph)
with torch.no_grad():
fp_pred, _ = model(dgl_graph)
fp_pred = fp_pred.detach().numpy() > 0.5
fp_pred = fp_pred.astype(int)
if majority:
fp_pred = generic
# fp_pred = fp_pred.detach().numpy()
if shuffle:
# true_id = np.random.choice(ligs, replace=False)
fp_pred = np.random.rand(166)
active = decoys[true_id][0]
decs = decoys[true_id][1]
rank = distance_rank(active, fp_pred, decs)
sim = mse(active, fp_pred)
rank = distance_rank(active, fp_pred, decs, dist_func=mse)
sim = jaccard(true_fp, fp_pred)
ranks.append(rank)
sims.append(sim)
return ranks, sims
def wilcoxon_all_pairs(df):
    """
    Compute pairwise Wilcoxon signed-rank tests on all runs and plot them.

    The 'rank' values of every pair of distinct 'method' groups are compared
    and the resulting p-values are shown as a lower-triangular heatmap.

    :param df: DataFrame with a 'method' column (group label) and a 'rank'
               column (values compared between groups). The test is paired,
               so the groups are expected to hold the same number of rows.
    :return: None; the heatmap is displayed with `plt.show()`.
    """
    from scipy.stats import wilcoxon

    # Group once instead of re-running groupby inside the inner loop.
    ranks_by_method = {method: sub['rank'].values
                       for method, sub in df.groupby('method')}

    wilcoxons = {'method_1': [], 'method_2': [], 'p-value': []}
    for method_1, ranks_1 in ranks_by_method.items():
        for method_2, ranks_2 in ranks_by_method.items():
            if method_1 == method_2:
                # A sample tested against itself has only zero differences,
                # which the test cannot handle meaningfully; the diagonal is
                # masked in the plot anyway, so no test is run for it.
                p_val = np.nan
            else:
                p_val = wilcoxon(ranks_1, ranks_2)[1]
            wilcoxons['method_1'].append(method_1)
            wilcoxons['method_2'].append(method_2)
            wilcoxons['p-value'].append(p_val)

    wil_df = pd.DataFrame(wilcoxons)
    # Keyword arguments: positional use of pivot() is rejected by recent pandas.
    pvals = wil_df.pivot(index="method_1", columns="method_2", values="p-value")
    # fillna returns a new object; the result has to be kept for it to apply.
    pvals = pvals.fillna(0)

    # Hide the upper triangle (diagonal included): the table is symmetric.
    mask = np.zeros(pvals.shape, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    g = sns.heatmap(pvals, cmap="Reds_r", annot=True, mask=mask, cbar=True)
    g.set_facecolor('grey')
    plt.show()
def generic_fp(annot_dir):
"""
Compute generic fingerprint by majority over dimensions.
......@@ -126,10 +149,13 @@ def generic_fp(annot_dir):
"""
fps = []
for g in os.listdir(annot_dir):
_,_,fp,_ = pickle.load(open(os.path.join(annot_dir, g), 'rb'))
_,_,_,fp = pickle.load(open(os.path.join(annot_dir, g), 'rb'))
fps.append(fp)
consensus = np.unique(fps, axis=0)
pass
counts = np.sum(fps, axis=0)
consensus = np.zeros(166)
ones = counts > len(fps) / 2
consensus[ones] = 1
return consensus
def make_violins(df, x='method', y='rank', save=None, show=True):
ax = sns.violinplot(x=x, y=y, data=df, color='0.8', bw=.1)
......@@ -145,32 +171,66 @@ def make_violins(df, x='method', y='rank', save=None, show=True):
pass
def make_ridge(df, x='method', y='rank', save=None, show=True):
    """
    Draw a ridge plot: one overlapping density of `y` per value of `x`.

    :param df: DataFrame holding the columns named by `x` and `y`
    :param x: column whose values define the rows (and hues) of the grid
    :param y: column whose density is drawn on each row
    :param save: when not None, handed to `FacetGrid.savefig` before the
                 figure is shown (presumably an output path -- confirm
                 against callers)
    :param show: when True (default), display the figure with `plt.show()`
    """
    # Transparent axes background so overlapping rows do not hide each other
    sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
    pal = sns.cubehelix_palette(10, rot=-.25, light=.7)
    g = sns.FacetGrid(df, row=x, hue=x, aspect=15, height=.5, palette=pal)

    # Filled density, then a white outline on top, then the baseline
    g.map(sns.kdeplot, y, clip_on=False, shade=True, alpha=1, lw=1.5, bw=.2)
    g.map(sns.kdeplot, y, clip_on=False, color="w", lw=2, bw=.2)
    g.map(plt.axhline, y=0, lw=2, clip_on=False)

    # FacetGrid.map passes `color` and `label` by keyword, so these two
    # parameter names must stay as they are.
    def label(x, color, label):
        ax = plt.gca()
        ax.text(0, .2, label, fontweight="bold", color=color,
                ha="left", va="center", transform=ax.transAxes)

    g.map(label, x)

    # Negative spacing makes the rows overlap
    g.fig.subplots_adjust(hspace=-.25)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[])
    g.despine(bottom=True, left=True)

    # Honour the save/show parameters; the defaults keep the old behaviour
    # (nothing saved, figure shown). Save first, before the figure is shown.
    if save is not None:
        g.savefig(save)
    if show:
        plt.show()
def ablation_results():
# modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
# modes = h'', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
# modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
modes = ['raw', 'bb', 'wc-bb', 'swap', 'random']
# modes = ['raw', 'warm', 'wc-bb', 'bb', 'majority', 'swap', 'random']
modes = ['raw', 'wc-bb', 'bb', 'majority', 'swap', 'random']
decoys = get_decoys(mode='pdb')
ranks, methods, jaccards = [], [], []
graph_dir = '../data/annotated/pockets_nx_symmetric'
graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
# graph_dir = '../data/annotated/pockets_nx_2'
run = 'ismb'
# run = 'teste'
# run = 'random'
num_folds = 10
num_folds = 10
majority = False
for m in modes:
print(m)
if m in ['raw', 'pair-shuffle']:
graph_dir = "../data/annotated/pockets_nx_symmetric"
graph_dir = "../data/annotated/pockets_nx_symmetric_orig"
run = 'ismb-raw'
# run = 'teste'
elif m == 'swap':
graph_dir = '../data/annotated/pockets_nx_symmetric_scramble'
graph_dir = '../data/annotated/pockets_nx_symmetric_scramble_orig'
run = 'ismb-' + m
elif m == 'majority':
run = 'ismb-raw'
majority = True
elif m == 'random':
graph_dir = '../data/annotated/pockets_nx_symmetric_random'
graph_dir = '../data/annotated/pockets_nx_symmetric_random_orig'
run = 'random'
elif m == 'warm':
graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
run = 'ismb-warm'
else:
graph_dir = "../data/annotated/pockets_nx_symmetric_" + m
graph_dir = "../data/annotated/pockets_nx_symmetric_" + m + "_orig"
run = 'ismb-' + m
......@@ -184,33 +244,31 @@ def ablation_results():
graph_ids = pickle.load(open(f'../results/trained_models/{run}_{fold}/splits_{fold}.p', 'rb'))
# graph_ids = pickle.load(open(f'../results/trained_models/{run}/splits.p', 'rb'))
shuffle = False
if m == 'pair-shuffle':
shuffle = True
ranks_this,sims_this = decoy_test(model, decoys, edge_map, embed_dim,
shuffle=shuffle,
nucs=meta['nucs'],
test_graphlist=graph_ids['test'],
test_graph_path=graph_dir)
test_graph_path=graph_dir,
majority=majority)
test_ligs = []
ranks.extend(ranks_this)
jaccards.extend(sims_this)
methods.extend([m]*len(ranks_this))
# decoy distance distribution
dists = []
for _,(active, decs) in decoys.items():
for d in decs:
dists.append(jaccard(active, d))
# dists = []
# for _,(active, decs) in decoys.items():
# for d in decs:
# dists.append(jaccard(active, d))
# plt.scatter(ranks_this, sims_this)
# plt.xlabel("ranks")
# plt.ylabel("distance")
# plt.show()
sns.distplot(dists, label='decoy distance')
sns.distplot(sims_this, label='pred distance')
plt.xlabel("distance")
plt.legend()
plt.show()
# sns.distplot(dists, label='decoy distance')
# sns.distplot(sims_this, label='pred distance')
# plt.xlabel("distance")
# plt.legend()
# plt.show()
# # rank_cut = 0.9
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.4 and r > 0.8]
......@@ -239,8 +297,11 @@ def ablation_results():
# plt.legend()
# plt.show()
df = pd.DataFrame({'rank': ranks, 'jaccard': jaccards, 'method':methods})
make_violins(df, x='method', y='jaccard')
make_violins(df, x='method', y='rank')
wilcoxon_all_pairs(df)
# make_ridge(df, x='method', y='rank')
# make_ridge(df, x='method', y='jaccard')
# make_violins(df, x='method', y='jaccard')
# make_violins(df, x='method', y='rank')
def structure_scanning(pdb, ligname, graph, model, edge_map, embed_dim):
"""
......@@ -307,7 +368,7 @@ def scanning_analyze():
r = find_residue(structure[chain], pos)
r_center = lig_center(r.get_atoms())
dists.append(euclidean(r_center, lig_c))
jaccards.append(jaccard(true_fp, fp))
jaccards.append(mse(true_fp, fp))
plt.title(f)
plt.distplot(dists, jaccards)
plt.xlabel("dist to binding site")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment