Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Carlos GO
RNAmigos_GCN
Commits
29235f10
Commit
29235f10
authored
Feb 13, 2021
by
Carlos GO
Browse files
tools dir
parent
b83befd4
Changes
6
Hide whitespace changes
Inline
Side-by-side
tools/drawing.py
0 → 100644
View file @
29235f10
import
os
,
sys
import
pickle
import
networkx
as
nx
import
matplotlib
matplotlib
.
rcParams
[
'text.usetex'
]
=
True
import
matplotlib.pyplot
as
plt
import
seaborn
as
sns
if
__name__
==
"__main__"
:
sys
.
path
.
append
(
".."
)
from
tools.rna_layout
import
circular_layout
# LaTeX preamble loaded for the edge symbols rendered by this module.
# It is passed as a single string: Matplotlib 3.3 deprecated giving
# 'text.latex.preamble' as a list of strings, while a plain string is accepted
# by both older and newer releases.
params = {
    'text.latex.preamble': r'\usepackage{fdsymbol}\usepackage{xspace}',
}
plt.rc('font', family='serif')
plt.rcParams.update(params)
# LaTeX symbol for each two-letter key.
# NOTE(review): keys look like orientation (C/T) + interacting edge (W/S/H) in
# the Leontis-Westhof nomenclature -- confirm against the annotation source.
labels = {
    'CW': r"$\medblackcircle$\xspace",
    'CS': r"$\medblacktriangleright$\xspace",
    'CH': r"$\medblacksquare$\xspace",
    'TW': r"$\medcircle$\xspace",
    'TS': r"$\medtriangleright$\xspace",
    'TH': r"$\medsquare$\xspace",
}


def make_label(s):
    """
    Build the LaTeX label for a three-character edge code such as ``'CWS'``.

    The symbol for ``s[:2]`` is always used. When the characters after the
    first one are two distinct characters, the symbol for ``s[0::2]`` (first
    and third character) is appended.

    :param s: edge code; its two-letter projections must be keys of ``labels``
    :return: LaTeX string made of one or two symbols
    :raises KeyError: if a required two-letter key is not in ``labels``
    """
    first_symbol = labels[s[:2]]
    if len(set(s[1:])) != 2:
        return first_symbol
    return first_symbol + labels[s[0::2]]
def rna_draw(nx_g, title="", highlight_edges=None, nt_info=False, node_colors=None, num_clusters=None):
    """
    Draw an RNA graph with the edge labels used by Leontis Westhof and show it.

    :param nx_g: networkx graph; every edge must carry a 'label' attribute
    :param title: currently unused (kept for interface compatibility)
    :param highlight_edges: optional list of edges drawn as a wide, translucent yellow overlay
    :param nt_info: if True, node labels are drawn as well
    :param node_colors: optional node colours forwarded to networkx; white nodes when None
    :param num_clusters: currently unused (kept for interface compatibility)
    :return: None, the figure is displayed with ``plt.show()``
    """
    pos = nx.spring_layout(nx_g)

    fill = 'white' if node_colors is None else node_colors
    nodes = nx.draw_networkx_nodes(nx_g, pos, node_size=150, node_color=fill, linewidths=2)
    # Some networkx versions return None when there is nothing to draw.
    if nodes is not None:
        nodes.set_edgecolor('black')
    if nt_info:
        nx.draw_networkx_labels(nx_g, pos, font_color='black')

    edge_labels = {}
    bb_edges = []
    non_bb_edges = []
    for n1, n2, d in nx_g.edges(data=True):
        label = d['label']
        # 'B53' edges are drawn thicker and get no text label.
        if label == 'B53':
            bb_edges.append((n1, n2))
        else:
            non_bb_edges.append((n1, n2))
        try:
            edge_labels[(n1, n2)] = make_label(label)
        except (KeyError, TypeError):
            # Label has no symbol in the lookup table: fall back to plain text.
            edge_labels[(n1, n2)] = '' if label == 'B53' else r"{0}".format(label)

    nx.draw_networkx_edges(nx_g, pos, edgelist=non_bb_edges)
    nx.draw_networkx_edges(nx_g, pos, edgelist=bb_edges, width=2)
    if highlight_edges is not None:
        nx.draw_networkx_edges(nx_g, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5)

    nx.draw_networkx_edge_labels(nx_g, pos, font_size=16, edge_labels=edge_labels)
    plt.axis('off')
    plt.show()
def rna_draw_pair(graphs, estimated_value=None, highlight_edges=None, node_colors=None, num_clusters=None,
                  similarity=False, true_value=None):
    """
    Draw several RNA graphs side by side, one subplot per graph, and show the figure.

    :param graphs: sequence of networkx graphs; every edge must carry a 'label' attribute
    :param estimated_value: value appended to the title after 'similarity : ' or 'distance : '
    :param highlight_edges: optional edge list drawn as a wide, translucent yellow overlay on each graph
    :param node_colors: optional sequence with one node-colour specification per graph; grey nodes when None
    :param num_clusters: currently unused (kept for interface compatibility)
    :param similarity: selects the 'similarity : ' title prefix instead of 'distance : '
    :param true_value: if not None, appended to the title as ' true : <value>'
    :return: None, the figure is displayed with ``plt.show()``
    """
    # squeeze=False keeps the axes array two-dimensional, so indexing also
    # works when a single graph is given.
    fig, ax = plt.subplots(1, len(graphs), num=1, squeeze=False)
    axes = ax[0]
    for i, g in enumerate(graphs):
        pos = nx.spring_layout(g)

        fill = 'grey' if node_colors is None else node_colors[i]
        nodes = nx.draw_networkx_nodes(g, pos, node_size=150, node_color=fill, linewidths=2, ax=axes[i])
        # Some networkx versions return None when there is nothing to draw.
        if nodes is not None:
            nodes.set_edgecolor('black')

        edge_labels = {}
        bb_edges = []
        non_bb_edges = []
        for n1, n2, d in g.edges(data=True):
            label = d['label']
            # 'B53' edges are drawn thicker and get no text label.
            if label == 'B53':
                bb_edges.append((n1, n2))
            else:
                non_bb_edges.append((n1, n2))
            try:
                edge_labels[(n1, n2)] = make_label(label)
            except (KeyError, TypeError):
                # Label has no symbol in the lookup table: fall back to plain text.
                edge_labels[(n1, n2)] = '' if label == 'B53' else r"{0}".format(label)

        nx.draw_networkx_edges(g, pos, edgelist=non_bb_edges, ax=axes[i])
        nx.draw_networkx_edges(g, pos, edgelist=bb_edges, width=2, ax=axes[i])
        if highlight_edges is not None:
            nx.draw_networkx_edges(g, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5,
                                   ax=axes[i])

        nx.draw_networkx_edge_labels(g, pos, font_size=16, edge_labels=edge_labels, ax=axes[i])
        axes[i].set_axis_off()

    plt.axis('off')
    # Parenthesised on purpose: without the parentheses the value was only
    # appended in the 'distance' case.
    title = ('similarity : ' if similarity else 'distance : ') + str(estimated_value)
    if true_value is not None:
        title = title + f' true : {true_value}'
    plt.title(title)
    plt.show()
def generic_draw_pair(graphs, title="", highlight_edges=None, node_colors=None, num_clusters=None):
    """
    Draw several graphs side by side with their raw edge labels and show the figure.

    :param graphs: sequence of networkx graphs; every edge must carry a 'label' attribute
    :param title: text inserted into the figure title after 'distance '
    :param highlight_edges: optional edge list drawn as a wide, translucent yellow overlay on each graph
    :param node_colors: optional sequence with one node-colour specification per graph; grey nodes when None
    :param num_clusters: currently unused (kept for interface compatibility)
    :return: None, the figure is displayed with ``plt.show()``
    """
    # squeeze=False keeps the axes array two-dimensional, so indexing also
    # works when a single graph is given.
    fig, ax = plt.subplots(1, len(graphs), num=1, squeeze=False)
    axes = ax[0]
    for i, g in enumerate(graphs):
        pos = nx.spring_layout(g)

        fill = 'grey' if node_colors is None else node_colors[i]
        nodes = nx.draw_networkx_nodes(g, pos, node_size=150, node_color=fill, linewidths=2, ax=axes[i])
        # Some networkx versions return None when there is nothing to draw.
        if nodes is not None:
            nodes.set_edgecolor('black')

        edge_labels = {(n1, n2): str(d['label']) for n1, n2, d in g.edges(data=True)}

        if highlight_edges is not None:
            nx.draw_networkx_edges(g, pos, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5,
                                   ax=axes[i])
        # Draw the plain edges too, as generic_draw does; otherwise the edge
        # labels float between nodes with no line underneath.
        nx.draw_networkx_edges(g, pos, ax=axes[i])
        nx.draw_networkx_edge_labels(g, pos, font_size=16, edge_labels=edge_labels, ax=axes[i])
        axes[i].set_axis_off()

    plt.axis('off')
    plt.title(f"distance {title}")
    plt.show()
def generic_draw(graph, title="", highlight_edges=None, node_colors=None):
    """
    Draw one graph with its raw edge labels on the left of a two-panel figure and show it.

    :param graph: networkx graph; every edge must carry a 'label' attribute
    :param title: text inserted into the title after 'motif '
    :param highlight_edges: optional edge list drawn as a wide, translucent yellow overlay
    :param node_colors: optional node colours, mapped through the Blues colormap; grey nodes when None
    :return: None, the figure is displayed with ``plt.show()``
    """
    fig, ax = plt.subplots(1, 2, num=1)
    left_panel = ax[0]
    layout = nx.spring_layout(graph)

    if node_colors is None:
        drawn_nodes = nx.draw_networkx_nodes(graph, layout, node_size=150, node_color='grey', linewidths=2,
                                             ax=left_panel)
    else:
        drawn_nodes = nx.draw_networkx_nodes(graph, layout, node_size=150, cmap=plt.cm.Blues,
                                             node_color=node_colors, linewidths=2, ax=left_panel)
    drawn_nodes.set_edgecolor('black')

    # Edge labels are shown verbatim, converted to text.
    label_by_edge = {}
    for src, dst, attrs in graph.edges(data=True):
        label_by_edge[(src, dst)] = str(attrs['label'])

    if highlight_edges is not None:
        nx.draw_networkx_edges(graph, layout, edgelist=highlight_edges, edge_color='y', width=8, alpha=0.5,
                               ax=left_panel)
    nx.draw_networkx_edges(graph, layout, ax=left_panel)
    nx.draw_networkx_edge_labels(graph, layout, font_size=16, edge_labels=label_by_edge, ax=left_panel)

    left_panel.set_axis_off()
    plt.axis('off')
    plt.title(f"motif {title}")
    plt.show()
def ablation_draw():
    """
    Draw the same pocket graph once for every ablation variant.

    For each suffix in ``modes`` the pickle
    ``../data/annotated/pockets_nx<suffix>/<g_name>`` is loaded and its first
    element is passed to ``rna_draw``. Paths are relative to the current
    working directory.

    :return: None
    """
    g_name = "1fmn_#0.1:A:FMN:36.nx_annot.p"
    modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle']
    for m in modes:
        g_dir = "../data/annotated/pockets_nx" + m
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(os.path.join(g_dir, g_name), 'rb') as pickled:
            g, _, _, _ = pickle.load(pickled)
        rna_draw(g, title=m)


if __name__ == "__main__":
    ablation_draw()
tools/graph_utils.py
0 → 100644
View file @
29235f10
import
pickle
import
os
import
itertools
from
tqdm
import
tqdm
import
networkx
as
nx
import
torch
import
dgl
def get_edge_map(graphs_dir):
    """
    Collect every edge label found in a directory of pickled graphs and index them.

    Each file in ``graphs_dir`` is unpickled as a 3-tuple whose first element
    is a graph exposing ``edges(data=True)`` with a 'label' entry per edge.

    :param graphs_dir: directory containing the pickled graphs
    :return: dict mapping each distinct label to its rank in sorted order
    """
    edge_labels = set()
    print("Collecting edge labels.")
    for g in tqdm(os.listdir(graphs_dir)):
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(os.path.join(graphs_dir, g), 'rb') as pickled:
            graph, _, _ = pickle.load(pickled)
        edge_labels.update(e_dict['label'] for _, _, e_dict in graph.edges(data=True))

    return {label: i for i, label in enumerate(sorted(edge_labels))}
def nx_to_dgl_jacques(graph, edge_map):
    """
    Convert a networkx graph to a DGL graph with node features built from node attributes.

    Every edge gets a 'one_hot' attribute holding ``torch.tensor(edge_map[label])``.
    The DGL graph is built from the undirected graph with that edge attribute
    and the node attributes 'angles' and 'identity'; its ``ndata['h']`` is
    'angles' concatenated with the float-cast 'identity' along dim 1.

    :param graph: networkx graph with an edge attribute 'label' and node attributes 'angles' and 'identity'
    :param edge_map: dict mapping edge label to integer index
    :return: tuple ``(undirected networkx graph, DGL graph)``
    """
    graph = nx.to_undirected(graph)
    one_hot = {edge: torch.tensor(edge_map[label])
               for edge, label in nx.get_edge_attributes(graph, 'label').items()}
    nx.set_edge_attributes(graph, name='one_hot', values=one_hot)

    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=graph, edge_attrs=['one_hot'], node_attrs=['angles', 'identity'])

    # Init node embeddings with node features.
    floatid = g_dgl.ndata['identity'].float()
    g_dgl.ndata['h'] = torch.cat([g_dgl.ndata['angles'], floatid], dim=1)

    return graph, g_dgl
def nx_to_dgl_(graph, edge_map, embed_dim):
    """
    Load a pickled networkx graph and convert it to DGL.

    The pickle is expected to hold a 3-tuple whose first element is the graph.
    Every edge gets a 'one_hot' attribute holding ``edge_map[label]``, and the
    DGL graph's ``ndata['h']`` is initialised to ones.

    :param graph: path of the pickle file to load
    :param edge_map: dict mapping edge label to integer index
    :param embed_dim: width of the all-ones node feature matrix
    :return: tuple ``(loaded networkx graph, DGL graph)``
    """
    # NOTE: pickle must only be used on trusted, locally produced files.
    with open(graph, 'rb') as pickled:
        nx_graph, _, _ = pickle.load(pickled)

    one_hot = {edge: edge_map[label]
               for edge, label in nx.get_edge_attributes(nx_graph, 'label').items()}
    nx.set_edge_attributes(nx_graph, name='one_hot', values=one_hot)

    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=nx_graph, edge_attrs=['one_hot'])
    n_nodes = len(g_dgl.nodes())
    g_dgl.ndata['h'] = torch.ones((n_nodes, embed_dim))

    return nx_graph, g_dgl
def dgl_to_nx(graph, edge_map):
    """
    Convert a DGL graph back to networkx and restore the textual edge labels.

    :param graph: DGL graph whose edges carry a 'one_hot' entry supporting ``.item()``
    :param edge_map: dict mapping edge label to index (inverted here)
    :return: networkx graph whose edges have a 'label' attribute
    """
    nx_graph = dgl.to_networkx(graph, edge_attrs=['one_hot'])

    # Invert the label -> index mapping.
    label_of_index = {}
    for label, index in edge_map.items():
        label_of_index[index] = label

    restored = {}
    for src, dst, attrs in nx_graph.edges(data=True):
        restored[(src, dst)] = label_of_index[attrs['one_hot'].item()]

    nx.set_edge_attributes(nx_graph, restored, 'label')
    return nx_graph
def bfs_expand(G, initial_nodes, depth=2):
    """
    Extend a node set by every node reachable within ``depth`` hops.

    Only newly discovered nodes are expanded at each step, so a node's
    neighbours are listed at most once.

    :param G: graph exposing ``neighbors(node)``
    :param initial_nodes: iterable of hashable start nodes
    :param depth: number of hops to expand
    :return: set containing the initial nodes and all nodes within ``depth`` hops of them
    """
    reached = set(initial_nodes)
    frontier = set(reached)
    for _ in range(depth):
        discovered = set()
        for n in frontier:
            for nei in G.neighbors(n):
                if nei not in reached:
                    discovered.add(nei)
        if not discovered:
            # Nothing new: further hops cannot add nodes.
            break
        reached |= discovered
        frontier = discovered
    return reached
def bfs(G, initial_node, depth=2):
    """
    Generator for bfs given graph and initial node.

    Yields, ``depth`` times, the list of nodes first reached at the next hop.
    A node is marked as seen when it is discovered, so it appears in exactly
    one ring and never twice within a ring. Rings are empty once the reachable
    part of the graph is exhausted.

    :param G: graph exposing ``neighbors(node)``
    :param initial_node: hashable start node (never yielded itself)
    :param depth: number of rings to yield
    :return: generator of lists of nodes
    """
    seen = {initial_node}
    frontier = [initial_node]
    for _ in range(depth):
        depth_ring = []
        for n in frontier:
            for nei in G.neighbors(n):
                if nei not in seen:
                    seen.add(nei)
                    depth_ring.append(nei)
        yield depth_ring
        frontier = depth_ring
def graph_ablations(G, mode):
    """
    Remove or relabel edges depending on the mode.

    The result is a new undirected graph built only from edges, so nodes
    without a retained edge and all node attributes are absent from it.

    :param G: binding site graph; every edge must carry a 'label' attribute
    :param mode: one of
        'label-shuffle' (labels of the graph randomly reassigned to its edges),
        'no-label' (every edge labelled 'X'),
        'wc-bb-nc' (labels other than 'CWW' and 'B53' replaced by 'NC'),
        'bb-only' (only 'B53' edges kept),
        'wc-bb' (only 'B53' and 'CWW' edges kept)
    :returns: new graph with edges removed/relabeled
    :raises ValueError: if ``mode`` is not one of the values above
    """
    from random import shuffle

    # Labels retained by the two filtering modes.
    kept_labels = {
        'bb-only': ['B53'],
        'wc-bb': ['B53', 'CWW'],
    }
    relabel_modes = ('label-shuffle', 'no-label', 'wc-bb-nc')
    if mode not in relabel_modes and mode not in kept_labels:
        raise ValueError(f"Unknown ablation mode: {mode!r}")

    H = nx.Graph()

    if mode == 'label-shuffle':
        # assign a random label from the same graph to each edge.
        labels = [d['label'] for _, _, d in G.edges(data=True)]
        shuffle(labels)
        for n1, n2, _ in G.edges(data=True):
            H.add_edge(n1, n2, label=labels.pop())
        return H

    if mode == 'no-label':
        for n1, n2, _ in G.edges(data=True):
            H.add_edge(n1, n2, label='X')
        return H

    if mode == 'wc-bb-nc':
        for n1, n2, d in G.edges(data=True):
            label = d['label']
            if label not in ['CWW', 'B53']:
                label = 'NC'
            H.add_edge(n1, n2, label=label)
        return H

    valid_edges = kept_labels[mode]
    for n1, n2, d in G.edges(data=True):
        if d['label'] in valid_edges:
            H.add_edge(n1, n2, label=d['label'])
    return H
tools/learning_utils.py
0 → 100644
View file @
29235f10
import
os
import
configparser
from
ast
import
literal_eval
import
pickle
from
tqdm
import
tqdm
import
torch
import
numpy
as
np
import
networkx
as
nx
from
learning.loader
import
Loader
,
InferenceLoader
from
learning.learn
import
send_graph_to_device
from
learning.rgcn
import
Model
def remove(name):
    """
    Delete the stored results of an experiment.

    Removes, relative to the directory of this file, the log directory, the
    trained-model directory and the ``.exp`` file of the experiment.

    :param name: experiment name
    :return: True once everything has been deleted
    """
    import shutil

    here = os.path.dirname(__file__)
    directories = (
        os.path.join(here, f'../results/logs/{name}'),
        os.path.join(here, f'../results/trained_models/{name}'),
    )
    experiment_file = os.path.join(here, f'../results/experiments/{name}.exp')

    for directory in directories:
        shutil.rmtree(directory)
    os.remove(experiment_file)
    return True
def setup():
    """
    Create all relevant directories to setup the learning procedure.

    The ``results`` directory and its ``logs``, ``trained_models`` and
    ``experiments`` sub-directories are created relative to the directory of
    this file. Directories that already exist are left untouched, so the
    function can be called repeatedly.

    :return: None
    """
    script_dir = os.path.dirname(__file__)
    for sub_dir in ('', 'logs', 'trained_models', 'experiments'):
        os.makedirs(os.path.join(script_dir, '../results/', sub_dir), exist_ok=True)
def mkdirs_learning(name, permissive=True):
    """
    Create the log and model folders of an experiment.

    Both folders are created through ``tools.utils.makedir`` under ``results/``,
    relative to the current working directory.

    :param name: experiment name
    :param permissive: forwarded to ``makedir``
        (NOTE(review): presumably allows reuse of an existing folder -- confirm in tools.utils)
    :return: tuple ``(log folder, path of the '<name>.pth' file inside the model folder)``
    """
    from tools.utils import makedir

    log_path = os.path.join('results/logs', name)
    save_path = os.path.join('results/trained_models', name)
    for folder in (log_path, save_path):
        makedir(folder, permissive)

    return log_path, os.path.join(save_path, name + '.pth')
def load_model(run):
    """
    Load full trained model with id `run`.

    Reads ``meta.p`` and ``<run>.pth`` from ``../results/trained_models/<run>/``
    (relative to the current working directory), rebuilds the model from the
    metadata and restores its weights on CPU.

    :param run: experiment id
    :return: tuple ``(model, meta)``
    """
    # NOTE: pickle must only be used on trusted, locally produced files.
    with open(f'../results/trained_models/{run}/meta.p', 'rb') as meta_file:
        meta = pickle.load(meta_file)

    edge_map = meta['edge_map']
    num_edge_types = len(edge_map)

    model_dict = torch.load(f'../results/trained_models/{run}/{run}.pth', map_location='cpu')
    model = Model(dims=meta['embedding_dims'],
                  attributor_dims=meta['attributor_dims'],
                  num_rels=num_edge_types,
                  num_bases=-1,
                  device='cpu',
                  pool=meta['pool'])
    model.load_state_dict(model_dict['model_state_dict'])
    return model, meta
def load_data(annotated_path, meta, get_sim_mat=True):
    """
    Build the train and test loaders for a directory of annotated graphs.

    :param annotated_path: forwarded to ``Loader`` as ``annotated_path``
    :param meta: metadata dict; only ``meta['sim_function']`` is read
    :param get_sim_mat: forwarded to ``Loader``; switches off computation of rings
        and K matrix for faster loading
    :return: tuple ``(train_loader, test_loader)``; the middle split is discarded
    """
    loader_kwargs = dict(annotated_path=annotated_path,
                         batch_size=1,
                         num_workers=1,
                         sim_function=meta['sim_function'],
                         get_sim_mat=get_sim_mat)
    splits = Loader(**loader_kwargs).get_data()
    train_loader, _, test_loader = splits
    return train_loader, test_loader
def predict(model, loader, max_graphs=10, device='cpu'):
    """
    Run the model on every batch of a loader without gradient tracking.

    :param model: module moved to ``device`` and called as ``model(graph)``,
        returning a pair of tensors ``(fp, z)``
    :param loader: sized iterable yielding 4-tuples ``(graph, K, fp, graph_index)``;
        only ``graph`` is used
    :param max_graphs: accepted for interface compatibility.
        NOTE(review): not applied here, every batch of ``loader`` is processed
    :param device: device for the model and for each graph
    :return: tuple ``(fps, Z)`` where ``fps`` is ``np.array`` of the per-batch
        ``fp`` outputs and ``Z`` is the per-batch ``z`` outputs joined with
        ``np.concatenate``
    """
    Z = []
    fps = []

    model = model.to(device)
    with torch.no_grad():
        for graph, _, _, _ in tqdm(loader, total=len(loader)):
            graph = send_graph_to_device(graph, device)
            fp_pred, z = model(graph)
            Z.append(z.cpu().numpy())
            fps.append(fp_pred.cpu().numpy())

    Z = np.concatenate(Z)
    fps = np.array(fps)
    return fps, Z
def inference_on_dir(run, graph_dir, ini=True, max_graphs=10, get_sim_mat=False,
                     split_mode='test', attributions=False, device='cpu'):
    """
    Load a trained model and run it on every graph of a directory.

    The model comes from ``meta_load_model(run)`` and the data from
    ``InferenceLoader(graph_dir).get_data()``; the result is whatever
    ``predict`` returns for them.

    :param run: experiment id of the trained model
    :param graph_dir: directory handed to ``InferenceLoader``
    :param ini: not used by this function
    :param max_graphs: forwarded to ``predict``
    :param get_sim_mat: not used by this function
    :param split_mode: not used by this function
    :param attributions: not used by this function
    :param device: forwarded to ``predict``
    :return: the return value of ``predict``
    """
    trained_model, _meta = meta_load_model(run)
    inference_loader = InferenceLoader(graph_dir).get_data()
    outputs = predict(trained_model, inference_loader, max_graphs=max_graphs, device=device)
    return outputs
def meta_load_model(run):
    """
    Load full trained model with id `run` from the ``models`` directory.

    Reads ``models/<run>/meta.p`` and ``models/<run>/<run>.pth`` (relative to
    the current working directory), prints the metadata, rebuilds the model
    from it and restores its weights on CPU.

    :param run: experiment id
    :return: tuple ``(model, meta)``
    """
    # NOTE: pickle must only be used on trusted, locally produced files.
    with open(f'models/{run}/meta.p', 'rb') as meta_file:
        meta = pickle.load(meta_file)
    print(meta)

    edge_map = meta['edge_map']
    num_edge_types = len(edge_map)

    model_dict = torch.load(f'models/{run}/{run}.pth', map_location='cpu')
    model = Model(dims=meta['embedding_dims'],
                  attributor_dims=meta['attributor_dims'],
                  num_rels=num_edge_types,
                  num_bases=-1,
                  device='cpu')
    model.load_state_dict(model_dict['model_state_dict'])
    return model, meta
def model_from_hparams(hparams):
    """
    Rebuild a trained model from an experiment configuration.

    The run name, model dimensions and ``hard_embed`` flag are read from the
    'argparse' section of ``hparams`` and the edge map from its 'edges'
    section. Weights are loaded on CPU from
    ``../results/trained_models/<run>/<run>.pth``.

    :param hparams: configuration object exposing ``get(section, key)``
    :return: model with the stored weights loaded
    """
    num_edge_types = len(hparams.get('edges', 'edge_map'))
    run = hparams.get('argparse', 'name')

    checkpoint = torch.load(f'../results/trained_models/{run}/{run}.pth', map_location='cpu')

    model_kwargs = dict(
        dims=hparams.get('argparse', 'embedding_dims'),
        attributor_dims=hparams.get('argparse', 'attributor_dims'),
        num_rels=num_edge_types,
        num_bases=-1,
        hard_embed=hparams.get('argparse', 'hard_embed'),
    )
    model = Model(**model_kwargs)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
def data_from_hparams(annotated_path, hparams, get_sim_mat=True):
    """
    Build the train and test loaders described by an experiment configuration.

    :param annotated_path: forwarded to ``Loader`` as ``annotated_path``
    :param hparams: configuration object exposing ``get(section, key)``; batch size,
        similarity function, kernel depth and ``hard_embed`` come from its 'argparse' section
    :param get_sim_mat: forwarded to ``Loader``; switches off computation of rings
        and K matrix for faster loading
    :return: tuple ``(train_loader, test_loader)``; the middle split is discarded
    """
    # Read as in the original implementation; the value is not used below.
    dims = hparams.get('argparse', 'embedding_dims')

    loader_kwargs = dict(
        annotated_path=annotated_path,
        batch_size=hparams.get('argparse', 'batch_size'),
        num_workers=1,
        sim_function=hparams.get('argparse', 'sim_function'),
        depth=hparams.get('argparse', 'kernel_depth'),
        hard_embed=hparams.get('argparse', 'hard_embed'),
        hparams=hparams,
        get_sim_mat=get_sim_mat,
    )
    splits = Loader(**loader_kwargs).get_data()
    train_loader, _, test_loader = splits
    return train_loader, test_loader
def
get_rgcn_outputs
(
run
,
graph_dir
,
ini
=
False
,
max_graphs
=
100
,
nc_only
=
False
,
get_sim_mat
=
True
):
"""
Load model and get node embeddings.
:params
:get_sim_mat: switches off computation of rings and K matrix for faster loading.
:max_graphs max number of graphs to get embeddings for
"""
from
tools.graph_utils
import
dgl_to_nx
if
ini
:
hparams
=
ConfParser
(
default_path
=
os
.
path
.
join
(
'../results/experiments'
,
f
'
{
run
}
.exp'
))
model
=
model_from_hparams
(
hparams
)
train_loader
,
test_loader
=
data_from_hparams
(
graph_dir
,
hparams
,
get_sim_mat
=
get_sim_mat
)
edge_map
=
hparams
.
get
(
'edges'
,
'edge_map'
)
similarity
=
hparams
.
get
(
'argparse'
,
'similarity'
)
else
:
model
,
meta
=
load_model
(
run
)
train_loader
,
test_loader
=
load_data
(
graph_dir
,
meta
,
get_sim_mat
=
get_sim_mat
)
edge_map
=
meta
[
'edge_map'
]
similarity
=
False
for
param_tensor
in
model
.
state_dict
():
print
(
param_tensor
,
"
\t
"
,
model
.
state_dict
()[
param_tensor
])
Z
=
[]
fp_mat
=
[]
nx_graphs
=
[]
KS
=
[]
# maps full nodeset index to graph and node index inside graph
node_map
=
{}
ind
=
0
offset
=
0
for
i
,
(
graph
,
K
,
graph_sizes
)
in
enumerate
(
train_loader
):
if
i
>
max_graphs
-
1
:
break
fp
,
z
=
model
(
graph
)
KS
.
append
(
K
)
fp_mat
.
append
(
np
.
array
(
fp
.
detach
().
numpy
()))
for
j
,
emb
in
enumerate
(
z
.
detach
().
numpy
()):
Z
.
append
(
np
.
array
(
emb
))
node_map
[
ind
]
=
(
i
,
j
)
ind
+=
1
# nx_graphs.append(nx_graph)
nx_g
=
dgl_to_nx
(
graph
,
edge_map
)
#assign unique id to graph nodes
nx_g
=
nx
.
relabel_nodes
(
nx_g
,{
node
:
offset
+
k
for
k
,
node
in
enumerate
(
nx_g
.
nodes
())})
offset
+=
len
(
nx_g
.
nodes
())
# print(z)
# rna_draw(nx_g)
nx_graphs
.
append
(
nx_g
)