Commit 820d4c6c authored by Carlos GO's avatar Carlos GO
Browse files

requirements

parent 2b95014c
......@@ -40,6 +40,26 @@ def send_graph_to_device(g, device):
return g
def send_graph_to_device(g, device):
"""
Send dgl graph to device
:param g: :param device:
:return:
"""
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
# nodes
labels = g.node_attr_schemes()
for l in labels.keys():
g.ndata[l] = g.ndata.pop(l).to(device, non_blocking=True)
# edges
labels = g.edge_attr_schemes()
for i, l in enumerate(labels.keys()):
g.edata[l] = g.edata.pop(l).to(device, non_blocking=True)
return g
def set_gradients(model, embedding=True, attributor=True):
"""
Set the gradients to the embedding and the attributor networks.
......
# Name Version Build Channel
biopython 1.73 py36h1de35cc_0
dgl 0.4.1 pypi_0 pypi
networkx 2.4 py_0
numpy 1.16.3 py36hacdab7b_0
openbabel 2.4.1 py36_5 openbabel
pandas 0.24.2 py36h0a44026_0
pytorch 1.1.0 py3.6_0 pytorch
scipy 1.2.1 py36h1410ff5_0
seaborn 0.9.0 py36_0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment