Commit 12ce2a5a authored by Carlos GO's avatar Carlos GO
Browse files

proper cluster loss

parent 16e9faef
......@@ -78,10 +78,7 @@ def test(model, test_loader, device, fp_draw=False):
# Get data on the devices
K = K.to(device)
if model.clustered:
clust_hots = torch.zeros((len(fp), model.num_clusts))
for i,f in enumerate(fp):
clust_hots[i][int(f)] = 1.
fp = clust_hots
fp = fp.long()
fp = fp.to(device)
K = torch.ones(K.shape).to(device) - K
......@@ -187,14 +184,11 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# Get data on the devices
#convert ints to one hots
if model.clustered:
clust_hots = torch.zeros((len(fp), model.num_clusts))
for i,f in enumerate(fp):
clust_hots[i][int(f)] = 1.
fp = clust_hots
fp = fp.to(device)
graph = send_graph_to_device(graph, device)
if model.clustered:
fp = fp.long()
fp = fp.to(device)
fp_pred, embeddings = model(graph)
......
......@@ -71,9 +71,7 @@ class Attributor(nn.Module):
# hidden to output
layers.append(nn.Linear(last_hidden, last))
#predict one class
if self.clustered:
layers.append(nn.Softmax(dim=1))
else:
if not self.clustered:
layers.append(nn.Sigmoid())
self.net = nn.Sequential(*layers)
......@@ -240,7 +238,10 @@ class Model(nn.Module):
"""
# pw = torch.tensor([self.pos_weight], dtype=torch.float, requires_grad=False).to(self.device)
# loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(pred_fp, target_fp)
loss = torch.nn.BCELoss()(pred_fp, target_fp)
if self.clustered:
loss = torch.nn.CrossEntropyLoss()(pred_fp, target_fp)
else:
loss = torch.nn.BCELoss()(pred_fp, target_fp)
# loss = JaccardDistanceLoss()(pred_fp, target_fp)
return loss
......
......@@ -102,20 +102,6 @@ def decoy_test(model, decoys, edge_map, embed_dim,
nx_graph, dgl_graph = nx_to_dgl(g, edge_map, nucs=nucs)
fp_pred, _ = model(dgl_graph)
if False:
n_nodes = len(dgl_graph.nodes)
att= get_attention_map(dgl_graph, src_nodes=dgl_graph.nodes(), dst_nodes=dgl_graph.nodes(), h=1)
att_g0 = att[0] # get attn weights only for g0
# Select atoms with highest attention weights and plot them
tops = np.unique(np.where(att_g0>0.51)) # get top atoms in attention
print(f"tops {tops}")
g0 = dgl_to_nx(dgl_graph, edge_map)
nodelist = list(g0.nodes())
highlight_edges = list(g0.subgraph([nodelist[t] for t in tops]).edges())
rna_draw(g0, highlight_edges=highlight_edges)
fp_pred = fp_pred.detach().numpy() > 0.5
# print(fp_pred)
# fp_pred = np.random.choice([0, 1], size=(166,), p=[1./2, 1./2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment