Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Carlos GO
RNAmigos_GCN
Commits
12ce2a5a
Commit
12ce2a5a
authored
Jan 22, 2020
by
Carlos GO
Browse files
proper cluster loss
parent
16e9faef
Changes
3
Hide whitespace changes
Inline
Side-by-side
learning/learn.py
View file @
12ce2a5a
...
...
@@ -78,10 +78,7 @@ def test(model, test_loader, device, fp_draw=False):
# Get data on the devices
K
=
K
.
to
(
device
)
if
model
.
clustered
:
clust_hots
=
torch
.
zeros
((
len
(
fp
),
model
.
num_clusts
))
for
i
,
f
in
enumerate
(
fp
):
clust_hots
[
i
][
int
(
f
)]
=
1.
fp
=
clust_hots
fp
=
fp
.
long
()
fp
=
fp
.
to
(
device
)
K
=
torch
.
ones
(
K
.
shape
).
to
(
device
)
-
K
...
...
@@ -187,14 +184,11 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# Get data on the devices
#convert ints to one hots
if
model
.
clustered
:
clust_hots
=
torch
.
zeros
((
len
(
fp
),
model
.
num_clusts
))
for
i
,
f
in
enumerate
(
fp
):
clust_hots
[
i
][
int
(
f
)]
=
1.
fp
=
clust_hots
fp
=
fp
.
to
(
device
)
graph
=
send_graph_to_device
(
graph
,
device
)
if
model
.
clustered
:
fp
=
fp
.
long
()
fp
=
fp
.
to
(
device
)
fp_pred
,
embeddings
=
model
(
graph
)
...
...
learning/rgcn.py
View file @
12ce2a5a
...
...
@@ -71,9 +71,7 @@ class Attributor(nn.Module):
# hidden to output
layers
.
append
(
nn
.
Linear
(
last_hidden
,
last
))
#predict one class
if
self
.
clustered
:
layers
.
append
(
nn
.
Softmax
(
dim
=
1
))
else
:
if
not
self
.
clustered
:
layers
.
append
(
nn
.
Sigmoid
())
self
.
net
=
nn
.
Sequential
(
*
layers
)
...
...
@@ -240,7 +238,10 @@ class Model(nn.Module):
"""
# pw = torch.tensor([self.pos_weight], dtype=torch.float, requires_grad=False).to(self.device)
# loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(pred_fp, target_fp)
loss
=
torch
.
nn
.
BCELoss
()(
pred_fp
,
target_fp
)
if
self
.
clustered
:
loss
=
torch
.
nn
.
CrossEntropyLoss
()(
pred_fp
,
target_fp
)
else
:
loss
=
torch
.
nn
.
BCELoss
()(
pred_fp
,
target_fp
)
# loss = JaccardDistanceLoss()(pred_fp, target_fp)
return
loss
...
...
post/validation.py
View file @
12ce2a5a
...
...
@@ -102,20 +102,6 @@ def decoy_test(model, decoys, edge_map, embed_dim,
nx_graph
,
dgl_graph
=
nx_to_dgl
(
g
,
edge_map
,
nucs
=
nucs
)
fp_pred
,
_
=
model
(
dgl_graph
)
if
False
:
n_nodes
=
len
(
dgl_graph
.
nodes
)
att
=
get_attention_map
(
dgl_graph
,
src_nodes
=
dgl_graph
.
nodes
(),
dst_nodes
=
dgl_graph
.
nodes
(),
h
=
1
)
att_g0
=
att
[
0
]
# get attn weights only for g0
# Select atoms with highest attention weights and plot them
tops
=
np
.
unique
(
np
.
where
(
att_g0
>
0.51
))
# get top atoms in attention
print
(
f
"tops
{
tops
}
"
)
g0
=
dgl_to_nx
(
dgl_graph
,
edge_map
)
nodelist
=
list
(
g0
.
nodes
())
highlight_edges
=
list
(
g0
.
subgraph
([
nodelist
[
t
]
for
t
in
tops
]).
edges
())
rna_draw
(
g0
,
highlight_edges
=
highlight_edges
)
fp_pred
=
fp_pred
.
detach
().
numpy
()
>
0.5
# print(fp_pred)
# fp_pred = np.random.choice([0, 1], size=(166,), p=[1./2, 1./2])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment