Carlos GO / RNAmigos_GCN / Commits / 7ae793bd

Commit 7ae793bd, authored Jan 21, 2020 by Carlos GO

    drawing option

Parent: 70cdc345
Changes: 7 files
data_processor/binding_pocket_filter.py

@@ -39,7 +39,8 @@ def get_valids(lig_dict, max_dist, min_conc, min_size=4):
     unique_ligs = set()
     for pdb, ligands in lig_dict.items():
         for lig_id, lig_cuts in ligands:
-            lig_name = lig_id.split(":")[1]
+            # lig_name = lig_id.split(":")[1]
+            lig_name = lig_id.split(":")[2]
             #go over each distance cutoff
             for c in lig_cuts:
                 tot = c['rna'] + c['protein']
@@ -73,9 +74,9 @@ def ligs_to_txt(d, dest="../data/ligs.txt"):
             o.write(" ".join([pdb, *ligs]) + "\n")
     pass

 if __name__ == "__main__":
-    d = pickle.load(open('../data/lig_dict.p', 'rb'))
+    d = pickle.load(open('../data/lig_dict_ismb.p', 'rb'))
     c = 10
     conc = .6
-    ligs = get_valids(d, c, conc, min_size=5)
-    # pickle.dump(ligs, open("../data/lig_dict_r10_d06.p", "wb"))
+    ligs = get_valids(d, c, conc, min_size=4)
+    pickle.dump(ligs, open("../data/lig_dict_ismb_rna06_rad10.p", "wb"))
     # ligs_to_txt(ligs)
data_processor/build_dataset.py

@@ -122,11 +122,12 @@ def get_pocket_graph(pdb_structure_path, ligand_id, graph,
     assert labels.issubset(valid_edges)
     print(pocket)
-    rna_draw(G, title="BINDING")
+    # rna_draw(G, title="BINDING")
     if len(G.nodes()) < 4:
         return None
     # if dump_path and (len(G.nodes()) > 4):
         # nx.write_gpickle(G, os.path.join(dump_path, f"{pdbid}_{ligand_id}_BIND.nx"))
     if dump_path:
         nx.write_gpickle(G, os.path.join(dump_path, f"{pdbid}_{ligand_id}_BIND.nx"))

     #sample and build non-binding graph.
     if non_binding:
@@ -164,6 +165,7 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
     lig_dict = pickle.load(open(lig_dict_path, 'rb'))
     print(f">>> building graphs for {len(lig_dict)} PDBs")
     print(f">>> dumping in {dump_path}")
     print(f">>> and {sum(map(len, lig_dict.values()))} binding sites.")
     failed = 0
@@ -177,12 +179,15 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
     print(f">>> skipping {done_pdbs}")
     failed = []
     empties = 0
     num_found = 0
     missing_graphs = []
     for pdbid, ligs in tqdm(lig_dict.items()):
         pdbid = pdbid.split(".")[0]
-        pdb_path = f"../data/all_rna_prot_lig_2019/{pdbid}.cif"
+        # if pdbid in done_pdbs:
+            # continue
+        # pdb_path = f"../data/all_rna_prot_lig_2019/{pdbid}.cif"
+        pdb_path = f"../../carlos_docking/data/all_rna_with_lig_2019/{pdbid}.cif"
         if pdbid in done_pdbs:
             continue
         # try:
         print(">>> ", pdbid)
         try:
@@ -197,19 +202,26 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
             for lig in ligs:
                 #dump binding site graphs
                 try:
-                    get_pocket_graph(pdb_path, lig, pdb_graph, dump_path=dump_path, non_binding=non_binding)
+                    g = get_pocket_graph(pdb_path, lig, pdb_graph, dump_path=dump_path, non_binding=non_binding)
+                    if g is None:
+                        empties += 1
+                    else:
+                        num_found += 1
+                    print(f">>> pockets so far {num_found}")
                 except FileNotFoundError:
                     print(f"{pdbid} not found")
                     failed.append(pdbid)
+    print(f">>> missing graphs for {missing_graphs}")
     print(failed)
     print(f">>> failed on {len(failed)} graphs")
     print(f">>> got {empties} empty graphs")

 if __name__ == "__main__":
     #take all ligands with 8 angstrom sphere and 0.6 RNA concentration, build a graph for each.
     # get_binding_site_graphs_all('../data/lig_dict_c_8A_06rna.p','../data/pockets_nx_pfind',
     #                             non_binding=True)
-    get_binding_site_graphs_all('../data/lig_dict_c_10A_08rna.p', '../data/pockets_nx_large', non_binding=False)
+    get_binding_site_graphs_all('../data/lig_dict_ismb_rna06_rad10.p', '../data/pockets_nx_ismb', non_binding=False)
     pass
data_processor/lig_dict_cluster.py

@@ -13,14 +13,18 @@ from sklearn.cluster import AgglomerativeClustering
 def ligands_cluster(bs_dict, fp_dict, n_clusters=8):
     """
     Assign cluster labels to each ligand in ligand_list.
     Create new fingerprint dictionary {'lig_id': cluster_id}
     """
     #get which ligands to use in clustering
     binding_sites = pickle.load(open(bs_dict, 'rb'))
     fingerprints = pickle.load(open(fp_dict, 'rb'))
     ligs_2_cluster = []
     for _, ligs in binding_sites.items():
-        ligs_2_cluster.extend([f.split(":")[2] for f in ligs])
-    ligs_2_cluster = list(set(ligs_2_cluster))
+        pocket_ids = [f.split(":")[2] for f in ligs]
+        ligs_2_cluster.extend(pocket_ids)
+    # ligs_2_cluster_unique = list(set(ligs_2_cluster))
     fps = []
     for l in ligs_2_cluster:
@@ -33,10 +37,13 @@ def ligands_cluster(bs_dict, fp_dict, n_clusters=8):
     clusterer = AgglomerativeClustering(n_clusters=n_clusters)
     clusterer.fit(fps)
     labels = clusterer.labels_
     clustered_fp_dict = dict(zip(ligs_2_cluster, labels))
+    sns.distplot(labels)
+    plt.show()
     return clustered_fp_dict
     pass

 if __name__ == "__main__":
-    ligands_cluster("../data/lig_dict_c_8A_06rna.p", "../data/all_ligs_maccs.p")
+    clustered_fp_dict = ligands_cluster("../data/lig_dict_c_8A_06rna.p", "../data/all_ligs_maccs.p")
+    pickle.dump(clustered_fp_dict, open("../data/fp_dict_8clusters.p", 'wb'))
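For context, the clustering step this script wraps can be exercised in isolation. The sketch below is a minimal stand-in that uses random binary vectors in place of the real MACCS fingerprints; the ligand names, fingerprint length, and cluster count are assumptions, not values taken from the repository:

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering

    # Stand-in for MACCS keys: one 166-bit binary fingerprint per ligand.
    lig_names = ["ARG", "GTP", "SAM", "PAR"]
    fps = np.random.randint(0, 2, size=(len(lig_names), 166))

    # Same call pattern as ligands_cluster(): fit, read labels_, zip into a dict.
    clusterer = AgglomerativeClustering(n_clusters=2)
    clusterer.fit(fps)
    clustered_fp_dict = dict(zip(lig_names, clusterer.labels_))
    print(clustered_fp_dict)  # e.g. {'ARG': 0, 'GTP': 1, 'SAM': 0, 'PAR': 1}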
learning/learn.py

@@ -61,7 +61,7 @@ def print_gradients(model):
         name, p = param
         print(name, p.grad)
     pass

-def test(model, test_loader, device, decoys=None):
+def test(model, test_loader, device, decoys=None, fp_draw=False):
     """
     Compute accuracy and loss of model over given dataset
     :param model:
@@ -85,27 +85,48 @@ def test(model, test_loader, device, decoys=None):
         with torch.no_grad():
             fp_pred, embeddings = model(graph)
             loss = model.compute_loss(fp, fp_pred)
+            kws = {'cbar': False, 'square': False, 'vmin': 0, 'vmax': 1}
+            jaccards = []
             enrichments = []
             for i, f in zip(idx, fp_pred):
                 true_lig = all_graphs[i.item()].split(":")[2]
                 rank, sim = decoy_test(f, true_lig, decoys)
                 enrichments.append(rank)
-                decoy_ranks = np.mean(enrichments)
+                jaccards.append(sim)
+            mean_ranks = np.mean(enrichments)
+            mean_jaccard = np.mean(jaccards)
             del K
-            del fp
             del graph
             test_loss += loss.item()
             del loss
+            if fp_draw:
+                fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
+                sns.heatmap(fp, ax=ax1, **kws)
+                bina = fp_pred > 0.5
+                fp_true = fp.clone().detach()
+                fp_true = fp_true.int()
+                bina = bina.int()
+                sns.heatmap(bina, ax=ax2, **kws)
+                sns.heatmap(fp_true != bina, ax=ax3, **kws)
+                ax1.set_title("True")
+                ax2.set_title("Pred")
+                ax3.set_title("Diff")
+                plt.show()
+            del fp

-    return test_loss / test_size, decoy_ranks
+    return test_loss / test_size, mean_ranks, mean_jaccard

 def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
                 writer=None, num_epochs=25, wall_time=None,
-                reconstruction_lam=1, motif_lam=1, embed_only=-1, decoys=None):
+                reconstruction_lam=1, fp_lam=1, embed_only=-1, decoys=None,
+                early_stop_threshold=10, fp_draw=False):
     """
     Performs the entire training routine.
     :param model: (torch.nn.Module): the model to train
@@ -119,22 +140,22 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
     :param num_epochs: int number of epochs
     :param wall_time: The number of hours you want the model to run
     :param reconstruction_lam: how much to enforce pariwise similarity conservation
-    :param motif_lam: how much to enforce motif assignment
+    :param fp_lam: how much to enforce motif assignment
     :param embed_only: number of epochs before starting attributor training.
     :return:
     """
     edge_map = train_loader.dataset.dataset.edge_map
     all_graphs = train_loader.dataset.dataset.all_graphs
     decoys = get_decoys(mode='pdb', annots_dir=train_loader.dataset.dataset.path)
     epochs_from_best = 0
-    early_stop_threshold = 10
     start_time = time.time()
     best_loss = sys.maxsize
-    motif_lam_orig = motif_lam
+    fp_lam_orig = fp_lam
     reconstruction_lam_orig = reconstruction_lam
     #if we delay attributor, start with attributor OFF
@@ -142,7 +163,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
     if embed_only > -1:
         print("Switching attriutor OFF. Embeddings still ON.")
         set_gradients(model, attributor=False)
-        motif_lam = 0
+        fp_lam = 0

     for epoch in range(num_epochs):
         print('Epoch {}/{}'.format(epoch + 1, num_epochs))
@@ -156,7 +177,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
                 print("Switching attributor ON, embeddings OFF.")
                 set_gradients(model, embedding=False, attributor=True)
                 reconstruction_lam = 0
-                motif_lam = motif_lam_orig
+                fp_lam = fp_lam_orig

         running_loss = 0.0
@@ -164,6 +185,8 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
         num_batches = len(train_loader)
+        train_enrichments = []
+        train_jaccards = []
         for batch_idx, (graph, K, fp, idx) in enumerate(train_loader):
             # Get data on the devices
@@ -173,7 +196,31 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
             fp_pred, embeddings = model(graph)
+            if fp_draw:
+                fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
+                kws = {'cbar': False, 'square': False, 'vmin': 0, 'vmax': 1}
+                sns.heatmap(fp, ax=ax1, **kws)
+                bina = fp_pred > 0.5
+                fp_true = fp.clone().detach()
+                fp_true = fp_true.int()
+                bina = bina.int()
+                sns.heatmap(bina, ax=ax2, **kws)
+                sns.heatmap(fp_true != bina, ax=ax3, **kws)
+                ax1.set_title("True")
+                ax2.set_title("Pred")
+                ax3.set_title("Diff")
+                plt.show()
             loss = model.compute_loss(fp, fp_pred)
+            for i, f in zip(idx, fp_pred):
+                true_lig = all_graphs[i.item()].split(":")[2]
+                rank, sim = decoy_test(f, true_lig, decoys)
+                train_enrichments.append(rank)
+                train_jaccards.append(sim)
             # l = model.rec_loss(embeddings, K, similarity=False)
             # print(l)
@@ -188,7 +235,6 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
             batch_loss = loss.item()
             running_loss += batch_loss
             # running_corrects += labels.eq(target.view_as(out)).sum().item()
             if batch_idx % 20 == 0:
                 time_elapsed = time.time() - start_time
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Time: {:.2f}'.format(
@@ -208,14 +254,17 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
         # Log training metrics
         train_loss = running_loss / num_batches
         writer.add_scalar("Training epoch loss", train_loss, epoch)
+        print(">> train enrichments", np.mean(train_enrichments))
+        print(">> train jaccards", np.mean(train_jaccards))
         # train_accuracy = running_corrects / num_batches
         # writer.log_scalar("Train accuracy during training", train_accuracy, epoch)

         # Test phase
-        test_loss, enrichments = test(model, test_loader, device, decoys=decoys)
+        test_loss, enrichments, jaccards = test(model, test_loader, device, decoys=decoys)
         print(">> test loss ", test_loss)
+        print(">> test enrichments", enrichments)
+        print(">> test jaccards ", jaccards)
         writer.add_scalar("Test loss during training", test_loss, epoch)
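The new fp_draw option renders a three-panel True/Pred/Diff fingerprint heatmap per batch. A self-contained sketch of that drawing logic, with random tensors standing in for a real batch of fingerprints and predictions (the batch size and the 166-bit fingerprint length are assumptions):

    import torch
    import seaborn as sns
    import matplotlib.pyplot as plt

    fp = torch.randint(0, 2, (8, 166)).float()   # stand-in ground-truth fingerprints
    fp_pred = torch.rand(8, 166)                 # stand-in model outputs in [0, 1]

    kws = {'cbar': False, 'square': False, 'vmin': 0, 'vmax': 1}
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    bina = (fp_pred > 0.5).int()                 # binarize predictions at 0.5
    sns.heatmap(fp.int().numpy(), ax=ax1, **kws)
    sns.heatmap(bina.numpy(), ax=ax2, **kws)
    sns.heatmap((fp.int() != bina).int().numpy(), ax=ax3, **kws)  # disagreement mask
    ax1.set_title("True")
    ax2.set_title("Pred")
    ax3.set_title("Diff")
    plt.show()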
learning/loader.py

@@ -79,8 +79,8 @@ class V1(Dataset):
             one_hot_nucs = {node: torch.tensor(self.nuc_map[label], dtype=torch.float32) for node, label in (nx.get_node_attributes(graph, 'nt')).items()}
         else:
-            one_hot_nucs = {node: torch.tensor(0, dtype=torch.float32) for node, label in (nx.get_node_attributes(graph, 'nt')).items()}
+            one_hot_nucs = {node: torch.tensor(0, dtype=torch.float32) for node in graph.nodes()}

         nx.set_node_attributes(graph, name='one_hot', values=one_hot_nucs)
@@ -212,9 +212,9 @@ class Loader():
         collate_block = collate_wrapper(self.dataset.node_sim_func)

-        train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=self.batch_size,
+        train_loader = DataLoader(dataset=train_set, batch_size=self.batch_size,
                                   num_workers=self.num_workers, collate_fn=collate_block)

-        test_loader = DataLoader(dataset=test_set, shuffle=True, batch_size=self.batch_size,
+        test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
                                  num_workers=self.num_workers, collate_fn=collate_block)

         # return train_loader, valid_loader, test_loader
learning/main.py

@@ -16,9 +16,8 @@ parser.add_argument("-da", "--annotated_data", default='pockets_nx_symmetric')
 parser.add_argument("-bs", "--batch_size", type=int, default=8, help="choose the batch size")
 parser.add_argument("-nw", "--workers", type=int, default=20, help="Number of workers to load data")
 parser.add_argument("-n", "--name", type=str, default='default_name', help="Name for the logs")
 parser.add_argument("-t", "--timed", help="to use timed learn", action='store_true')
 parser.add_argument("-ep", "--num_epochs", type=int, help="number of epochs to train", default=3)
-parser.add_argument("-ml", "--motif_lam", type=float, help="motif lambda", default=1.0)
+parser.add_argument("-fl", "--fp_lam", type=float, help="fingerprint lambda", default=1.0)
 parser.add_argument("-rl", "--reconstruction_lam", type=float, help="reconstruction lambda", default=1.0)
 parser.add_argument('-ad', '--attributor_dims', nargs='+', type=int, help='Dimensions for attributor.', default=[16, 166])
 parser.add_argument('-ed', '--embedding_dims', nargs='+', type=int, help='Dimensions for embeddings.', default=[16]*3)
@@ -30,10 +29,12 @@ parser.add_argument('-po', '--pool', type=str, default='sum', help='Pooling func
 parser.add_argument("-nu", "--nucs", default=True, help="Use nucleotide IDs for learn", action='store_false')
 parser.add_argument('-rs', '--seed', type=int, default=0, help='Random seed to use (if > 0, else no seed is set).')
 parser.add_argument('-kf', '--kfold', type=int, default=0, help='Do k-fold crossval and do decoys on each fold..')
+parser.add_argument('-es', '--early_stop', type=int, default=10, help='Early stop epoch threshold (default=10)')
 args = parser.parse_args()

-print(f"OPTIONS USED: {args}")
+print("OPTIONS USED")
+print("\n".join(map(str, zip(vars(args).items()))))

 # Torch impors
 import torch
 import torch.optim as optim
@@ -110,7 +111,7 @@ else:
 if dims[-1] != attributor_dims[0] - dim_add:
     raise ValueError(f"Final embedding size must match first attributor dimension: {dims[-1]} != {attributor_dims[0]}")

-motif_lam = args.motif_lam
+fp_lam = args.fp_lam
 reconstruction_lam = args.reconstruction_lam

 data = loader.get_data(k_fold=args.kfold)
@@ -183,5 +184,6 @@ for k, (train_loader, test_loader) in enumerate(data):
                             writer=writer, num_epochs=num_epochs, reconstruction_lam=reconstruction_lam,
-                            motif_lam=motif_lam, embed_only=args.embed_only)
+                            fp_lam=fp_lam, embed_only=args.embed_only,
+                            early_stop_threshold=args.early_stop)
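On the command-line side, "-ml/--motif_lam" is renamed to "-fl/--fp_lam" and "-es/--early_stop" is introduced. A trimmed, self-contained sketch of just those two options, with defaults copied from the diff and example values that are arbitrary:

    import argparse

    # Stand-in for the parser in learning/main.py, limited to the flags this commit touches.
    parser = argparse.ArgumentParser()
    parser.add_argument("-fl", "--fp_lam", type=float, default=1.0, help="fingerprint lambda")
    parser.add_argument("-es", "--early_stop", type=int, default=10, help="Early stop epoch threshold (default=10)")

    args = parser.parse_args(["-fl", "0.5", "-es", "15"])
    print(args.fp_lam, args.early_stop)  # 0.5 15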
learning/rgcn.py

@@ -12,6 +12,34 @@ from dgl.nn.pytorch.glob import SumPooling,GlobalAttentionPooling
 from dgl import mean_nodes
 from dgl.nn.pytorch.conv import RelGraphConv

+class JaccardDistanceLoss(torch.nn.Module):
+    def __init__(self, smooth=100, dim=1, size_average=True, reduce=True):
+        """
+        Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
+                = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|))
+
+        The jaccard distance loss is usefull for unbalanced datasets. This has been
+        shifted so it converges on 0 and is smoothed to avoid exploding or disapearing
+        gradient.
+
+        Ref: https://en.wikipedia.org/wiki/Jaccard_index
+        @url: https://gist.github.com/wassname/d1551adac83931133f6a84c5095ea101
+        @author: wassname
+        """
+        super(JaccardDistanceLoss, self).__init__()
+        self.smooth = smooth
+        self.dim = dim
+        self.size_average = size_average
+        self.reduce = reduce
+
+    def forward(self, y_true, y_pred):
+        intersection = (y_true * y_pred).abs().sum(self.dim)
+        sum_ = (y_true.abs() + y_pred.abs()).sum(self.dim)
+        jac = (intersection + self.smooth) / (sum_ - intersection + self.smooth)
+        losses = (1 - jac) * self.smooth
+        if self.reduce:
+            return losses.mean() if self.size_average else losses.sum()
+        else:
+            return losses
+
 class Attributor(nn.Module):
     """
@@ -206,6 +234,7 @@ class Model(nn.Module):
         # pw = torch.tensor([self.pos_weight], dtype=torch.float, requires_grad=False).to(self.device)
         # loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(pred_fp, target_fp)
         loss = torch.nn.BCELoss()(pred_fp, target_fp)
+        # loss = JaccardDistanceLoss()(pred_fp, target_fp)
         return loss
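The new JaccardDistanceLoss is added but left commented out in compute_loss, so BCELoss stays active. A minimal sketch of how it could be exercised on its own, assuming the JaccardDistanceLoss class added above is in scope (random tensors stand in for fingerprints, and the shapes are assumptions):

    import torch

    # JaccardDistanceLoss as defined in the hunk above.
    criterion = JaccardDistanceLoss(smooth=100, dim=1)
    target_fp = torch.randint(0, 2, (4, 166)).float()  # binary target fingerprints
    pred_fp = torch.rand(4, 166)                        # sigmoid-style predictions
    loss = criterion(target_fp, pred_fp)                # forward(y_true, y_pred)
    print(loss.item())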