Carlos GO / RNAmigos_GCN · Commits

Commit 464dc769, authored Jan 24, 2020 by Carlos GO

    reconstrution compute

Parent: f83b8d3b
Changes: 3 files
learning/learn.py
@@ -63,7 +63,7 @@ def print_gradients(model):
         name, p = param
         print(name, p.grad)
     pass

-def test(model, test_loader, device, fp_draw=False):
+def test(model, test_loader, device, fp_lam=1, rec_lam=1):
     """
     Compute accuracy and loss of model over given dataset

     :param model:
@@ -89,31 +89,15 @@ def test(model, test_loader, device, fp_draw=False):
         # Do the computations for the forward pass
         with torch.no_grad():
             fp_pred, embeddings = model(graph)
-            loss = model.compute_loss(fp, fp_pred)
-            kws = {'cbar': False, 'square': False, 'vmin': 0, 'vmax': 1}
+            loss = model.compute_loss(fp, fp_pred, embeddings, K, fp_lam=fp_lam, rec_lam=rec_lam)
         del K
         del graph
         test_loss += loss.item()
         del loss
-        if fp_draw:
-            fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
-            sns.heatmap(fp, ax=ax1, **kws)
-            bina = fp_pred > 0.5
-            fp_true = fp.clone().detach()
-            fp_true = fp_true.int()
-            bina = bina.int()
-            sns.heatmap(bina, ax=ax2, **kws)
-            sns.heatmap(fp_true != bina, ax=ax3, **kws)
-            ax1.set_title("True")
-            ax2.set_title("Pred")
-            ax3.set_title("Diff")
-            plt.show()
         del fp
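
A note on the reworked test loop above: it now folds the fingerprint term and the embedding-reconstruction term into one scalar per batch and accumulates the `.item()` values. A minimal, self-contained sketch of that accumulation with toy tensors (not the repo's Model class; `weighted_test_loss` is an illustrative name):

import torch

def weighted_test_loss(batches, fp_lam=1.0, rec_lam=1.0):
    """Average a weighted fingerprint + reconstruction loss over toy batches.

    Each batch is a (fp_true, fp_pred, K_true, K_pred) tuple of tensors,
    standing in for one pass of the test_loader through the model.
    """
    total = 0.0
    with torch.no_grad():
        for fp_true, fp_pred, K_true, K_pred in batches:
            fp_term = torch.nn.BCELoss()(fp_pred, fp_true)    # fingerprint term
            rec_term = torch.nn.MSELoss()(K_pred, K_true)     # reconstruction term
            total += (fp_lam * fp_term + rec_lam * rec_term).item()
    return total / len(batches)

# Toy data shaped like a 166-bit fingerprint and a small pairwise kernel.
batches = [(torch.rand(1, 166).round(), torch.rand(1, 166),
            torch.rand(4, 4), torch.rand(4, 4)) for _ in range(3)]
print(weighted_test_loss(batches))
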
@@ -122,7 +106,7 @@ def test(model, test_loader, device, fp_draw=False):

 def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
                 writer=None, num_epochs=25, wall_time=None,
                 reconstruction_lam=1, fp_lam=1, embed_only=-1,
-                early_stop_threshold=10, fp_draw=False):
+                early_stop_threshold=10, fp_draw=False):
     """
     Performs the entire training routine.

     :param model: (torch.nn.Module): the model to train
@@ -195,26 +179,8 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
             fp_pred, embeddings = model(graph)

-            if fp_draw:
-                fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
-                kws = {'cbar': False, 'square': False, 'vmin': 0, 'vmax': 1}
-                sns.heatmap(fp, ax=ax1, **kws)
-                bina = fp_pred > 0.5
-                fp_true = fp.clone().detach()
-                fp_true = fp_true.int()
-                bina = bina.int()
-                sns.heatmap(bina, ax=ax2, **kws)
-                sns.heatmap(fp_true != bina, ax=ax3, **kws)
-                ax1.set_title("True")
-                ax2.set_title("Pred")
-                ax3.set_title("Diff")
-                plt.show()

-            loss = model.compute_loss(fp, fp_pred)
+            loss = model.compute_loss(fp, fp_pred, embeddings, K, fp_lam=fp_lam, rec_lam=reconstruction_lam)

             # l = model.rec_loss(embeddings, K, similarity=False)
             # print(l)
@@ -253,7 +219,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
             # writer.log_scalar("Train accuracy during training", train_accuracy, epoch)

         # Test phase
-        test_loss = test(model, test_loader, device)
+        test_loss = test(model, test_loader, device, fp_lam=fp_lam, rec_lam=reconstruction_lam)

         print(">> test loss ", test_loss)
         writer.add_scalar("Test loss during training", test_loss, epoch)

learning/rgcn.py
@@ -169,7 +169,7 @@ class Model(nn.Module):
         target_K = torch.ones(target_K.shape, device=target_K.device) - target_K
         reconstruction_loss = torch.nn.MSELoss()(K_predict, target_K)
-        self.draw_rec(target_K, K_predict)
+        # self.draw_rec(target_K, K_predict)
         return reconstruction_loss

     # Below are loss computation function related to this model
     @staticmethod
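
The `rec_loss` context shown in this hunk is the piece that the weighted `compute_loss` introduced later in this file plugs into: it flips the target kernel and takes an MSE against a predicted pairwise-similarity matrix. A standalone sketch of that computation, with a hypothetical `cosine_kernel` helper in place of the model's internal similarity code:

import torch

def cosine_kernel(embeddings, eps=1e-8):
    # Pairwise cosine similarity: row-normalize, then take the Gram matrix.
    norm = embeddings / (embeddings.norm(dim=1, keepdim=True) + eps)
    return torch.mm(norm, norm.t())

def reconstruction_loss(embeddings, target_K):
    # As in the hunk above: flip the target kernel, then MSE against the prediction.
    K_predict = cosine_kernel(embeddings)
    target_K = torch.ones(target_K.shape, device=target_K.device) - target_K
    return torch.nn.MSELoss()(K_predict, target_K)

emb = torch.rand(5, 16)   # 5 node embeddings of dimension 16
K = torch.rand(5, 5)      # toy target kernel
print(reconstruction_loss(emb, K))
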
@@ -194,7 +194,17 @@ class Model(nn.Module):
         sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
         return sim_mt

-    def compute_loss(self, target_fp, pred_fp):
+    def fp_loss(self, target_fp, pred_fp):
+        if self.clustered:
+            loss = torch.nn.CrossEntropyLoss()(pred_fp, target_fp)
+        else:
+            # loss = torch.nn.MSELoss()(pred_fp, target_fp)
+            loss = torch.nn.BCELoss()(pred_fp, target_fp)
+        return loss
+
+    def compute_loss(self, target_fp, pred_fp, embeddings, target_K, rec_lam=1, fp_lam=1, similarity=False):
         """
         Compute the total loss of the model.

         Includes the reconstruction loss with optional similarity/distance boolean switch
@@ -207,14 +217,8 @@ class Model(nn.Module):
         :param scaled:
         :return:
         """
-        # pw = torch.tensor([self.pos_weight], dtype=torch.float, requires_grad=False).to(self.device)
-        # loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(pred_fp, target_fp)
-        if self.clustered:
-            loss = torch.nn.CrossEntropyLoss()(pred_fp, target_fp)
-        else:
-            # loss = torch.nn.MSELoss()(pred_fp, target_fp)
-            loss = torch.nn.BCELoss()(pred_fp, target_fp)
+        loss = fp_lam * self.fp_loss(target_fp, pred_fp) \
+               + rec_lam * self.rec_loss(embeddings, target_K, similarity=similarity)
         return loss

     def draw_rec(self, true_K, predicted_K, title=""):
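
One consequence of the refactor above is that the previous fingerprint-only objective is the special case rec_lam=0. A quick free-standing check (the module-level `fp_loss`, `rec_loss`, and `compute_loss` below are simplified stand-ins for the class methods, assuming the non-clustered BCE branch and a plain dot-product kernel):

import torch

def fp_loss(target_fp, pred_fp):
    # Non-clustered branch of the class method: binary cross-entropy on the fingerprint.
    return torch.nn.BCELoss()(pred_fp, target_fp)

def rec_loss(embeddings, target_K):
    # Simplified stand-in: dot-product kernel instead of the model's similarity matrix.
    K_predict = torch.mm(embeddings, embeddings.t())
    return torch.nn.MSELoss()(K_predict, target_K)

def compute_loss(target_fp, pred_fp, embeddings, target_K, fp_lam=1, rec_lam=1):
    return fp_lam * fp_loss(target_fp, pred_fp) + rec_lam * rec_loss(embeddings, target_K)

target_fp = torch.rand(1, 166).round()
pred_fp = torch.rand(1, 166)
emb, K = torch.rand(5, 8), torch.rand(5, 5)

full = compute_loss(target_fp, pred_fp, emb, K)
fp_only = compute_loss(target_fp, pred_fp, emb, K, rec_lam=0)
assert torch.isclose(fp_only, fp_loss(target_fp, pred_fp))
print(full.item(), fp_only.item())
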

post/validation.py
@@ -30,7 +30,7 @@ from post.utils import *
 from learning.attn import get_attention_map
 from learning.utils import dgl_to_nx
 from tools.learning_utils import load_model
-from post.drawing import rna_draw
+# from post.drawing import rna_draw

 def mse(x, y):
     d = np.sum((x - y)**2) / len(x)
@@ -49,30 +49,31 @@ def get_decoys(mode='pdb', annots_dir='../data/annotated/pockets_nx_2'):
                 print(f"failed on {g}")
             _, _, _, fp = pickle.load(open(os.path.join(annots_dir, g), 'rb'))
             fp_dict[lig_id] = fp
         decoy_list = list(fp_dict.values())
-        decoy_dict = {k: (v, decoy_list) for k, v in fp_dict.items()}
+        decoy_dict = {k: (v, [f for lig, f in fp_dict.items() if lig != k]) for k, v in fp_dict.items()}
         return decoy_dict
     if mode == 'dude':
         return pickle.load(open('../data/decoys_zinc.p', 'rb'))
     pass

-def distance_rank(active, pred, decoys, dist_func=jaccard):
+def distance_rank(active, pred, decoys, dist_func=mse):
     """
     Get rank of prediction in `decoys` given a known active ligand.
     """
     pred_dist = dist_func(active, pred)
     rank = 0
-    for lig in decoys:
-        d = dist_func(active, lig)
+    for decoy in decoys:
+        d = dist_func(pred, decoy)
         #if find a decoy closer to prediction, worsen the rank.
         if d < pred_dist:
             rank += 1
     return 1 - (rank / (len(decoys) + 1))

 def decoy_test(model, decoys, edge_map, embed_dim,
                test_graphlist=None,
                shuffle=False,
                nucs=False,
-               test_graph_path="../data/annotated/pockets_nx"):
+               test_graph_path="../data/annotated/pockets_nx",
+               majority=False):
     """
     Check performance against decoy set.

     decoys --> {'ligand_id', ('expected_FP', [decoy_fps])}
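
The updated `distance_rank` counts the decoys that sit closer to the prediction than the true active does, then maps that count into a score in (0, 1]. A small numeric example with plain lists and a list-friendly variant of the file's `mse`:

import numpy as np

def mse(x, y):
    return np.sum((np.asarray(x) - np.asarray(y)) ** 2) / len(x)

def distance_rank(active, pred, decoys, dist_func=mse):
    pred_dist = dist_func(active, pred)
    rank = 0
    for decoy in decoys:
        # A decoy closer to the prediction than the active worsens the rank.
        if dist_func(pred, decoy) < pred_dist:
            rank += 1
    return 1 - (rank / (len(decoys) + 1))

active = [1, 0, 1, 0]
pred   = [0.9, 0.1, 0.6, 0.2]
decoys = [[0, 1, 0, 1], [1, 0, 1, 1], [0.9, 0.0, 0.7, 0.1]]
print(distance_rank(active, pred, decoys))   # last decoy is closer, so score = 0.75
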
@@ -91,34 +92,56 @@ def decoy_test(model, decoys, edge_map, embed_dim,
         test_graphlist = os.listdir(test_graph_path)

     ligs = list(decoys.keys())
+    if majority:
+        generic = generic_fp("../data/annotated/pockets_nx_symmetric_orig")
     for g_path in test_graphlist:
-        g, _, _, _ = pickle.load(open(os.path.join(test_graph_path, g_path), 'rb'))
+        g, _, _, true_fp = pickle.load(open(os.path.join(test_graph_path, g_path), 'rb'))
         try:
             true_id = g_path.split(":")[2]
         except:
             print(f">> failed on {g_path}")
             continue
+        try:
+            decoys[true_id]
+        except KeyError:
+            print("missing fp", true_id)
+            continue
         nx_graph, dgl_graph = nx_to_dgl(g, edge_map, nucs=nucs)
-        fp_pred, _ = model(dgl_graph)
+        with torch.no_grad():
+            fp_pred, _ = model(dgl_graph)
+            fp_pred = fp_pred.detach().numpy() > 0.5
+            fp_pred = fp_pred.astype(int)
+            if majority:
+                fp_pred = generic
         # fp_pred = fp_pred.detach().numpy()
         if shuffle:
             # true_id = np.random.choice(ligs, replace=False)
             fp_pred = np.random.rand(166)
         active = decoys[true_id][0]
         decs = decoys[true_id][1]
-        rank = distance_rank(active, fp_pred, decs)
-        sim = mse(active, fp_pred)
+        rank = distance_rank(active, fp_pred, decs, dist_func=mse)
+        sim = jaccard(true_fp, fp_pred)
         ranks.append(rank)
         sims.append(sim)
     return ranks, sims

 def wilcoxon_all_pairs(df):
     """
     Compute pairwise wilcoxon on all runs.
     """
     from scipy.stats import wilcoxon
     wilcoxons = {'method_1': [], 'method_2': [], 'p-value': []}
     for method_1, df1 in df.groupby('method'):
         for method_2, df2 in df.groupby('method'):
             p_val = wilcoxon(df1['rank'], df2['rank'])
             wilcoxons['method_1'].append(method_1)
             wilcoxons['method_2'].append(method_2)
             wilcoxons['p-value'].append(p_val[1])
             pass
     wil_df = pd.DataFrame(wilcoxons)
     wil_df.fillna(0)
     pvals = wil_df.pivot("method_1", "method_2", "p-value")
     pvals.fillna(0)
     mask = np.zeros_like(pvals)
     mask[np.triu_indices_from(mask)] = True
     g = sns.heatmap(pvals, cmap="Reds_r", annot=True, mask=mask, cbar=True)
     g.set_facecolor('grey')
     plt.show()
     pass

 def generic_fp(annot_dir):
     """
     Compute generic fingerprint by majority over dimensions.
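
Inference in `decoy_test` now runs under torch.no_grad(), thresholds the raw prediction into a 0/1 vector, and, for the majority baseline, discards the prediction in favour of a consensus fingerprint. A compact sketch of that post-processing step (toy tensors; `postprocess_prediction` and `consensus_fp` are hypothetical names, not functions from the repo):

import numpy as np
import torch

def postprocess_prediction(fp_pred, consensus_fp=None, majority=False, threshold=0.5):
    """Binarize a predicted fingerprint; optionally replace it with a majority baseline."""
    with torch.no_grad():
        bits = (fp_pred.detach().numpy() > threshold).astype(int)
    if majority and consensus_fp is not None:
        return consensus_fp
    return bits

fp_pred = torch.rand(166)                      # raw sigmoid-like scores
consensus = np.zeros(166, dtype=int)           # e.g. output of a majority vote
print(postprocess_prediction(fp_pred).sum())                   # number of predicted "on" bits
print(postprocess_prediction(fp_pred, consensus, True).sum())  # 0: baseline overrides
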
@@ -126,10 +149,13 @@ def generic_fp(annot_dir):
     """
     fps = []
     for g in os.listdir(annot_dir):
-        _, _, fp, _ = pickle.load(open(os.path.join(annot_dir, g), 'rb'))
+        _, _, _, fp = pickle.load(open(os.path.join(annot_dir, g), 'rb'))
         fps.append(fp)
-    consensus = np.unique(fps, axis=0)
-    pass
+    counts = np.sum(fps, axis=0)
+    consensus = np.zeros(166)
+    ones = counts > len(fps) / 2
+    consensus[ones] = 1
     return consensus

 def make_violins(df, x='method', y='rank', save=None, show=True):
     ax = sns.violinplot(x=x, y=y, data=df, color='0.8', bw=.1)
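
`generic_fp` now builds its consensus by a per-dimension majority vote instead of np.unique: a bit is set when more than half of the stacked fingerprints have it set. The vote itself is a short numpy computation, as in this sketch (random 0/1 arrays stand in for the pickled annotations; `majority_fingerprint` is an illustrative name):

import numpy as np

def majority_fingerprint(fps, n_bits=166):
    """Set a bit when more than half of the fingerprints have it set."""
    fps = np.asarray(fps)
    counts = np.sum(fps, axis=0)
    consensus = np.zeros(n_bits)
    consensus[counts > len(fps) / 2] = 1
    return consensus

rng = np.random.default_rng(0)
fps = rng.integers(0, 2, size=(7, 166))      # 7 random 166-bit fingerprints
print(majority_fingerprint(fps).sum())       # number of bits set in a majority of them
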
@@ -145,32 +171,66 @@ def make_violins(df, x='method', y='rank', save=None, show=True):
     pass

 def make_ridge(df, x='method', y='rank', save=None, show=True):
     # Initialize the FacetGrid object
     sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
     pal = sns.cubehelix_palette(10, rot=-.25, light=.7)
     g = sns.FacetGrid(df, row=x, hue=x, aspect=15, height=.5, palette=pal)

     # Draw the densities in a few steps
     g.map(sns.kdeplot, y, clip_on=False, shade=True, alpha=1, lw=1.5, bw=.2)
     g.map(sns.kdeplot, y, clip_on=False, color="w", lw=2, bw=.2)
     g.map(plt.axhline, y=0, lw=2, clip_on=False)

     # Define and use a simple function to label the plot in axes coordinates
     def label(x, color, label):
         ax = plt.gca()
         ax.text(0, .2, label, fontweight="bold", color=color,
                 ha="left", va="center", transform=ax.transAxes)

     g.map(label, x)

     # Set the subplots to overlap
     g.fig.subplots_adjust(hspace=-.25)

     # Remove axes details that don't play well with overlap
     g.set_titles("")
     g.set(yticks=[])
     g.despine(bottom=True, left=True)
     plt.show()

 def ablation_results():
-    # modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
+    # modes = h'', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
     # modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
-    modes = ['raw', 'bb', 'wc-bb', 'swap', 'random']
+    # modes = ['raw', 'warm', 'wc-bb', 'bb', 'majority', 'swap', 'random']
+    modes = ['raw', 'wc-bb', 'bb', 'majority', 'swap', 'random']
     decoys = get_decoys(mode='pdb')
     ranks, methods, jaccards = [], [], []
-    graph_dir = '../data/annotated/pockets_nx_symmetric'
+    graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
     # graph_dir = '../data/annotated/pockets_nx_2'
     run = 'ismb'
     # run = 'teste'
     # run = 'random'
     num_folds = 10
+    majority = False
     for m in modes:
         print(m)
         if m in ['raw', 'pair-shuffle']:
-            graph_dir = "../data/annotated/pockets_nx_symmetric"
+            graph_dir = "../data/annotated/pockets_nx_symmetric_orig"
             run = 'ismb-raw'
             # run = 'teste'
         elif m == 'swap':
-            graph_dir = '../data/annotated/pockets_nx_symmetric_scramble'
+            graph_dir = '../data/annotated/pockets_nx_symmetric_scramble_orig'
             run = 'ismb-' + m
+        elif m == 'majority':
+            run = 'ismb-raw'
+            majority = True
         elif m == 'random':
-            graph_dir = '../data/annotated/pockets_nx_symmetric_random'
+            graph_dir = '../data/annotated/pockets_nx_symmetric_random_orig'
             run = 'random'
+        elif m == 'warm':
+            graph_dir = '../data/annotated/pockets_nx_symmetric_orig'
+            run = 'ismb-warm'
         else:
-            graph_dir = "../data/annotated/pockets_nx_symmetric_" + m
+            graph_dir = "../data/annotated/pockets_nx_symmetric_" + m + "_orig"
             run = 'ismb-' + m
@@ -184,33 +244,31 @@ def ablation_results():
             graph_ids = pickle.load(open(f'../results/trained_models/{run}_{fold}/splits_{fold}.p', 'rb'))
             # graph_ids = pickle.load(open(f'../results/trained_models/{run}/splits.p', 'rb'))

             shuffle = False
             if m == 'pair-shuffle':
                 shuffle = True
             ranks_this, sims_this = decoy_test(model, decoys, edge_map, embed_dim,
                                                shuffle=shuffle, nucs=meta['nucs'],
                                                test_graphlist=graph_ids['test'],
-                                               test_graph_path=graph_dir)
+                                               test_graph_path=graph_dir, majority=majority)
             test_ligs = []
             ranks.extend(ranks_this)
             jaccards.extend(sims_this)
             methods.extend([m] * len(ranks_this))

-            # decoy distance distribution
-            dists = []
-            for _,(active, decs) in decoys.items():
-                for d in decs:
-                    dists.append(jaccard(active, d))
+            # dists = []
+            # for _,(active, decs) in decoys.items():
+            #     for d in decs:
+            #         dists.append(jaccard(active, d))

             # plt.scatter(ranks_this, sims_this)
             # plt.xlabel("ranks")
             # plt.ylabel("distance")
             # plt.show()

-            sns.distplot(dists, label='decoy distance')
-            sns.distplot(sims_this, label='pred distance')
-            plt.xlabel("distance")
-            plt.legend()
-            plt.show()
+            # sns.distplot(dists, label='decoy distance')
+            # sns.distplot(sims_this, label='pred distance')
+            # plt.xlabel("distance")
+            # plt.legend()
+            # plt.show()

             # # rank_cut = 0.9
             # cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.4 and r > 0.8]
@@ -239,8 +297,11 @@ def ablation_results():
     # plt.legend()
     # plt.show()

     df = pd.DataFrame({'rank': ranks, 'jaccard': jaccards, 'method': methods})
+    make_violins(df, x='method', y='jaccard')
+    make_violins(df, x='method', y='rank')
+    wilcoxon_all_pairs(df)
     # make_ridge(df, x='method', y='rank')
     # make_ridge(df, x='method', y='jaccard')
     # make_violins(df, x='method', y='jaccard')
     # make_violins(df, x='method', y='rank')

 def structure_scanning(pdb, ligname, graph, model, edge_map, embed_dim):
     """
@@ -307,7 +368,7 @@ def scanning_analyze():
             r = find_residue(structure[chain], pos)
             r_center = lig_center(r.get_atoms())
             dists.append(euclidean(r_center, lig_c))
-            jaccards.append(jaccard(true_fp, fp))
+            jaccards.append(mse(true_fp, fp))
         plt.title(f)
         plt.distplot(dists, jaccards)
         plt.xlabel("dist to binding site")