Carlos GO / RNAmigos_GCN
Commit 39d75522, authored Jan 19, 2020 by Carlos GO
decoys in training almost ready
parent 49311fe1
Changes 3
learning/learn.py
...
...
@@ -14,6 +14,7 @@ if __name__ == '__main__':
     sys.path.append('../')
     from learning.utils import dgl_to_nx
+    from learning.decoy_utils import *
     from post.drawing import rna_draw

 def send_graph_to_device(g, device):
...
...
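(Note: the wildcard import added above is presumably what brings `decoy_test`, used in the evaluation loop further down, into scope from learning.decoy_utils.)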
@@ -60,7 +61,7 @@ def print_gradients(model):
        name, p = param
        print(name, p.grad)
    pass

-def test(model, test_loader, device):
+def test(model, test_loader, device, decoys=None):
    """
    Compute accuracy and loss of model over given dataset

    :param model:
...
...
@@ -71,8 +72,9 @@ def test(model, test_loader, device):
    """
    model.eval()
    test_loss, motif_loss_tot, recons_loss_tot = (0,)*3
+   all_graphs = test_loader.dataset.dataset.all_graphs
    test_size = len(test_loader)
-   for batch_idx, (graph, K, fp) in enumerate(test_loader):
+   for batch_idx, (graph, K, fp, idx) in enumerate(test_loader):
        # Get data on the devices
        K = K.to(device)
        fp = fp.to(device)
...
...
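The double `.dataset.dataset` hop suggests the test DataLoader wraps a torch.utils.data.Subset (e.g. from random_split), whose own `.dataset` attribute points back at the full V1 dataset. A minimal sketch of that access pattern, with a toy dataset and made-up graph names standing in for the real ones:

    from torch.utils.data import Dataset, DataLoader, random_split

    class ToyDataset(Dataset):
        # all_graphs mimics the V1 attribute; the "pdbid:chain:ligand:..."
        # naming is a guess based on the split(":")[2] lookup in test().
        def __init__(self):
            self.all_graphs = ["1aju:A:ARG:0", "1f1t:A:ROS:0"]
        def __len__(self):
            return len(self.all_graphs)
        def __getitem__(self, i):
            return i

    train_set, test_set = random_split(ToyDataset(), [1, 1])
    loader = DataLoader(test_set, batch_size=1)
    # Subset exposes the wrapped dataset as .dataset, hence the double hop:
    print(loader.dataset.dataset.all_graphs)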
@@ -83,6 +85,10 @@ def test(model, test_loader, device):
        with torch.no_grad():
            fp_pred, embeddings = model(graph)
            loss = model.compute_loss(fp, fp_pred)
+           for i, f in zip(idx, fp_pred):
+               true_lig = all_graphs[i.item()].split(":")[2]
+               rank, sim = decoy_test(f, true_lig, decoys)
+
        del K
        del fp
        del graph
...
...
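decoy_test itself is not shown in this diff (it comes from learning.decoy_utils via the new wildcard import). A rough, hypothetical sketch of what such a helper could compute, consistent with how it is called here: rank the true ligand against a decoy set by similarity between the predicted fingerprint and each candidate. The signature internals, the distance metric, and the decoys layout below are all assumptions, not the repository's code:

    import numpy as np

    def decoy_test_sketch(fp_pred, true_lig, decoys):
        """Hypothetical stand-in for learning.decoy_utils.decoy_test.

        fp_pred:  predicted fingerprint (1-D torch tensor)
        true_lig: ligand id of the true binder
        decoys:   assumed dict mapping ligand id -> fingerprint (np.ndarray)
        """
        pred = fp_pred.detach().cpu().numpy()
        # Score every candidate: smaller Euclidean distance = more similar.
        sims = {lig: -np.linalg.norm(pred - fp) for lig, fp in decoys.items()}
        ranked = sorted(sims, key=sims.get, reverse=True)
        # Normalized rank in [0, 1]; 0.0 means the true ligand scored best.
        rank = ranked.index(true_lig) / max(len(ranked) - 1, 1)
        return rank, sims[true_lig]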
@@ -95,7 +101,8 @@ def test(model, test_loader, device):
 def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
                 writer=None, num_epochs=25, wall_time=None,
-                reconstruction_lam=1, motif_lam=1, embed_only=-1):
+                reconstruction_lam=1, motif_lam=1, embed_only=-1,
+                decoys=None):
    """
    Performs the entire training routine.

    :param model: (torch.nn.Module): the model to train
...
...
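Since decoys=None lands at the end of the keyword list, every existing call to train_model stays valid; callers opt in to decoy evaluation by passing a decoy set explicitly.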
@@ -152,7 +159,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
    num_batches = len(train_loader)

-   for batch_idx, (graph, K, fp) in enumerate(train_loader):
+   for batch_idx, (graph, K, fp, idx) in enumerate(train_loader):
        # Get data on the devices
        batch_size = len(K)
...
...
@@ -201,7 +208,7 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
            # writer.log_scalar("Train accuracy during training", train_accuracy, epoch)

        # Test phase
-       test_loss = test(model, test_loader, device)
+       test_loss = test(model, test_loader, device, decoys=decoys)
        print(">> test loss ", test_loss)
        writer.add_scalar("Test loss during training", test_loss, epoch)
...
...
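The decoys=None defaults keep both signatures backward compatible, and train_model now forwards its own decoys argument into the per-epoch evaluation here. Note that test() calls decoy_test(f, true_lig, decoys) unconditionally, so decoy_utils presumably has to cope with decoys=None itself.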
learning/loader.py
...
...
@@ -91,10 +91,10 @@ class V1(Dataset):
        if self.get_sim_mat:
            # put the rings in same order as the dgl graph
            ring = dict(sorted(ring.items()))
-           return g_dgl, ring, fp
+           return g_dgl, ring, fp, [idx]
        else:
-           return g_dgl, fp
+           return g_dgl, fp, [idx]

    def _get_edge_data(self):
        """
...
...
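Wrapping the index as [idx] means each sample contributes a one-element list, so the batched index tensor assembled in collate_wrapper below has shape (batch_size, 1) and each row still supports the i.item() call in test(). A toy trace of that path (sample contents are made up):

    import numpy as np
    import torch

    # Fake samples shaped like (graph, ring, fp, [idx]):
    samples = [("g0", "r0", [0.1, 0.9], [7]),
               ("g1", "r1", [0.3, 0.2], [3])]
    graphs, rings, fp, idx = map(list, zip(*samples))
    idx = torch.from_numpy(np.array(idx))  # shape (2, 1)
    for i in idx:
        print(i.item())  # 7, then 3: one-element rows, so .item() is valid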
@@ -125,20 +125,21 @@ def collate_wrapper(node_sim_func, get_sim_mat=True):
            # The input `samples` is a list of pairs
            #  (graph, label).
            # print(len(samples))
-           graphs, rings, fp = map(list, zip(*samples))
+           graphs, rings, fp, idx = map(list, zip(*samples))
            fp = np.array(fp)
+           idx = np.array(idx)
            batched_graph = dgl.batch(graphs)
            K = k_block_list(rings, node_sim_func)
-           return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float()
+           return batched_graph, torch.from_numpy(K).detach().float(), torch.from_numpy(fp).detach().float(), torch.from_numpy(idx)
    else:
        def collate_block(samples):
            # The input `samples` is a list of pairs
            #  (graph, label).
            # print(len(samples))
-           graphs, _, fp = map(list, zip(*samples))
+           graphs, _, fp, idx = map(list, zip(*samples))
            fp = np.array(fp)
            batched_graph = dgl.batch(graphs)
-           return batched_graph, [1 for _ in samples], torch.from_numpy(fp)
+           return batched_graph, [1 for _ in samples], torch.from_numpy(fp), torch.from_numpy(idx)

    return collate_block

class Loader():
...
...
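One loose end in the get_sim_mat=False branch: idx is still a plain Python list when it reaches torch.from_numpy, which only accepts NumPy arrays, so this path would raise a TypeError until it gets the same np.array(idx) conversion as the other branch (the commit message does say "almost ready"). The four-way unpack there also appears to assume four-tuples, while the dataset's else branch now returns three values. A quick demonstration of the conversion issue:

    import numpy as np
    import torch

    idx = [[3], [7]]
    # torch.from_numpy(idx)              # TypeError: expected np.ndarray, got list
    t = torch.from_numpy(np.array(idx))  # the working conversion
    print(t)                             # tensor([[3], [7]])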
learning/main.py
...
...
@@ -100,7 +100,6 @@ print('Created data loader')
Model loading
'''

-#increase output embeddings by 1 for nuc info
if args.nucs:
    dim_add = 1
...
...