Commit 16e9faef authored by Carlos GO

violins working

parent 433ea597
......@@ -245,7 +245,7 @@ def annotate_all(fp_file="../data/all_ligs_maccs.p", dump_path='../data/annotate
res = annotate_one((graph, graph_path, dump_path, get_label(graph, mode, fp_dict), ablate))
except KeyError:
failed += 1
print("missing fingerprint: {lig_name(graph)}")
print(f"missing fingerprint: {graph_path}")
continue
if res[0]:
failed += 1
......@@ -255,5 +255,7 @@ def annotate_all(fp_file="../data/all_ligs_maccs.p", dump_path='../data/annotate
if __name__ == '__main__':
# annotate_all(parallel=False, graph_path="../data/pockets_nx_2", dump_path="../data/annotated/pockets_nx_2", ablate="")
annotate_all(parallel=False, graph_path="../data/pockets_nx_symmetric", dump_path="../data/annotated/pockets_nx_symmetric",
ablate="", mode='fp')
annotate_all(parallel=False, graph_path="../data/pockets_nx_symmetric", dump_path="../data/annotated/pockets_nx_symmetric_wc-bb",
ablate="wc-bb", mode='fp')
# annotate_all(parallel=False, graph_path="../data/pockets_nx_symmetric", dump_path="../data/annotated/pockets_nx_symmetric_clust",
# ablate="", mode='fp', fp_file='../data/fp_dict_8clusters.p')
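The two live `annotate_all` calls above differ only in ablation mode and dump path; a minimal sketch of folding them into one loop, following the path pattern used in this file:

    for ablate in ('', 'wc-bb'):
        suffix = f'_{ablate}' if ablate else ''
        annotate_all(parallel=False,
                     graph_path='../data/pockets_nx_symmetric',
                     dump_path=f'../data/annotated/pockets_nx_symmetric{suffix}',
                     ablate=ablate, mode='fp')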
......@@ -176,7 +176,7 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
pass
done_pdbs = {f.split('_')[0] for f in os.listdir(dump_path)}
print(f">>> skipping {done_pdbs}")
print(f">>> skipping {len(done_pdbs)}")
failed = []
empties = 0
......@@ -197,6 +197,11 @@ def get_binding_site_graphs_all(lig_dict_path, dump_path, non_binding=False):
print(f"{pdbid} had {len(ligs)} binding sites")
missing_graphs.append(pdbid)
continue
try:
    pdb_graph.nodes()
except AttributeError:
    # loader returned a non-graph (e.g. None); skip this structure
    print("empty graph")
    continue
# print(f"new guy: {pdbid}")
# continue
for lig in ligs:
......
......@@ -86,6 +86,9 @@ def graph_ablations(G, mode):
H = nx.Graph()
H.add_nodes_from(G.nodes(data=True))
# nx.set_node_attributes(H, 'pdb_pos', {n:d['pdb_pos'] for n,d in G.nodes(data=True)})
# nx.set_node_attributes(H, 'nt', {n:d['nt'] for n,d in G.nodes(data=True)})
if mode == 'label-shuffle':
#assign a random label from the same graph to each edge.
labels = [d['label'] for _,_,d in G.edges(data=True)]
......
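For context, a self-contained sketch of the 'label-shuffle' ablation described by the comment above, assuming edges carry a 'label' attribute (helper name is illustrative):

    import random
    import networkx as nx

    def shuffle_edge_labels(G, seed=0):
        # collect all edge labels, shuffle them, and write them back
        # onto the same edges, preserving the label distribution
        H = G.copy()
        labels = [d['label'] for _, _, d in H.edges(data=True)]
        random.Random(seed).shuffle(labels)
        for (u, v), lab in zip(H.edges(), labels):
            H[u][v]['label'] = lab
        return H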
......@@ -72,7 +72,10 @@ def fp_dict(smiles_file, include_ions=False, bits=False, fptype='FP2'):
with open(smiles_file, "r") as sms:
for s in sms:
smile, name = s.split()
mol = readstring('smi', smile)
try:
    mol = readstring('smi', smile)
except Exception:
    # unparseable SMILES string; skip this ligand
    continue
fp = mol.calcfp(fptype=fptype)
if bits:
fp = index_to_vec(fp.bits, nbits=nbits[fptype])
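`index_to_vec` is defined elsewhere in this module; a hedged sketch of what it plausibly does, assuming Open Babel's 1-based set-bit indices and a fixed width per fingerprint type (166 bits for MACCS):

    import numpy as np

    def index_to_vec(bit_indices, nbits=166):
        # map 1-based set-bit indices to a dense 0/1 vector
        vec = np.zeros(nbits, dtype=np.int8)
        for b in bit_indices:
            vec[b - 1] = 1
        return vec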
......@@ -101,5 +104,6 @@ if __name__ == "__main__":
# all_ligs = fp_dict("../data/ligs", bits=True, fptype='maccs')
# smiles = smiles_dict("../data/ligs")
# pickle.dump(smiles, open("../data/smiles_ligs_dict.p", "wb"))
all_ligs = fp_dict("../data/pdb_rna_smiles.txt", bits=True, fptype='maccs')
# pickle.dump(all_ligs, open("../data/all_ligs_maccs.p", "wb"))
# all_ligs = fp_dict("../data/pdb_rna_smiles.txt", bits=True, fptype='maccs')
all_ligs = fp_dict("../data/all_ligs_pdb.txt", bits=True, fptype='maccs')
pickle.dump(all_ligs, open("../data/all_ligs_pdb_maccs.p", "wb"))
......@@ -61,7 +61,7 @@ def print_gradients(model):
name, p = param
print(name, p.grad)
pass
def test(model, test_loader, device, decoys=None, fp_draw=False):
def test(model, test_loader, device, fp_draw=False):
"""
Compute accuracy and loss of model over given dataset
:param model:
......@@ -77,6 +77,12 @@ def test(model, test_loader, device, decoys=None, fp_draw=False):
for batch_idx, (graph, K, fp, idx) in enumerate(test_loader):
# Get data on the devices
K = K.to(device)
if model.clustered:
clust_hots = torch.zeros((len(fp), model.num_clusts))
for i,f in enumerate(fp):
clust_hots[i][int(f)] = 1.
fp = clust_hots
fp = fp.to(device)
K = torch.ones(K.shape).to(device) - K
graph = send_graph_to_device(graph, device)
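The int-to-one-hot loop added above can also be written with torch.nn.functional.one_hot; an equivalent sketch with an assumed cluster count of 8:

    import torch
    import torch.nn.functional as F

    fp = torch.tensor([2, 0, 5])                       # integer cluster ids
    clust_hots = F.one_hot(fp, num_classes=8).float()  # shape (3, 8)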
......@@ -90,15 +96,6 @@ def test(model, test_loader, device, decoys=None, fp_draw=False):
'vmin': 0,
'vmax': 1}
jaccards = []
enrichments = []
for i,f in zip(idx, fp_pred):
true_lig = all_graphs[i.item()].split(":")[2]
rank,sim = decoy_test(f, true_lig, decoys)
enrichments.append(rank)
jaccards.append(sim)
mean_ranks = np.mean(enrichments)
mean_jaccard = np.mean(jaccards)
del K
del graph
......@@ -121,12 +118,12 @@ def test(model, test_loader, device, decoys=None, fp_draw=False):
del fp
return test_loss / test_size, mean_ranks, mean_jaccard
return test_loss / test_size
def train_model(model, criterion, optimizer, device, train_loader, test_loader, save_path,
writer=None, num_epochs=25, wall_time=None,
reconstruction_lam=1, fp_lam=1, embed_only=-1,
decoys=None, early_stop_threshold=10, fp_draw=False):
early_stop_threshold=10, fp_draw=False):
"""
Performs the entire training routine.
:param model: (torch.nn.Module): the model to train
......@@ -148,8 +145,6 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
edge_map = train_loader.dataset.dataset.edge_map
all_graphs = train_loader.dataset.dataset.all_graphs
decoys = get_decoys(mode='pdb', annots_dir=train_loader.dataset.dataset.path)
epochs_from_best = 0
start_time = time.time()
......@@ -158,6 +153,9 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
fp_lam_orig = fp_lam
reconstruction_lam_orig = reconstruction_lam
dec_mode = 'pdb-whole'
batch_size = train_loader.batch_size
# if we delay the attributor, start with the attributor OFF
# if embed_only <= -1, both are always ON
if embed_only > -1:
......@@ -185,12 +183,16 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
num_batches = len(train_loader)
train_enrichments = []
train_jaccards = []
for batch_idx, (graph, K, fp, idx) in enumerate(train_loader):
# Get data on the devices
batch_size = len(K)
# convert integer cluster ids to one-hot vectors
if model.clustered:
clust_hots = torch.zeros((len(fp), model.num_clusts))
for i,f in enumerate(fp):
clust_hots[i][int(f)] = 1.
fp = clust_hots
fp = fp.to(device)
graph = send_graph_to_device(graph, device)
......@@ -216,11 +218,6 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
plt.show()
loss = model.compute_loss(fp, fp_pred)
for i,f in zip(idx, fp_pred):
true_lig = all_graphs[i.item()].split(":")[2]
rank,sim = decoy_test(f, true_lig, decoys)
train_enrichments.append(rank)
train_jaccards.append(sim)
# l = model.rec_loss(embeddings, K, similarity=False)
# print(l)
......@@ -254,17 +251,13 @@ def train_model(model, criterion, optimizer, device, train_loader, test_loader,
# Log training metrics
train_loss = running_loss / num_batches
writer.add_scalar("Training epoch loss", train_loss, epoch)
print(">> train enrichments", np.mean(train_enrichments))
print(">> train jaccards", np.mean(train_jaccards))
# train_accuracy = running_corrects / num_batches
# writer.log_scalar("Train accuracy during training", train_accuracy, epoch)
# Test phase
test_loss, enrichments, jaccards = test(model, test_loader, device, decoys=decoys)
test_loss = test(model, test_loader, device)
print(">> test loss ", test_loss)
print(">> test enrichments", enrichments)
print(">> test jaccards ", jaccards)
writer.add_scalar("Test loss during training", test_loss, epoch)
......
......@@ -23,7 +23,9 @@ class V1(Dataset):
nucs=True,
depth=3,
shuffle=False,
seed=0):
seed=0,
clustered=False,
num_clusters=8):
"""
Setup for data loader.
......@@ -45,6 +47,8 @@ class V1(Dataset):
self.edge_map, self.edge_freqs = self._get_edge_data()
self.num_edge_types = len(self.edge_map)
self.nucs = nucs
self.clustered = clustered
self.num_clusters = num_clusters
if nucs:
print(">>> storing nucleotide IDs")
self.nuc_map = {n:i for i,n in enumerate(['A', 'C', 'G', 'N', 'U'])}
......@@ -88,6 +92,11 @@ class V1(Dataset):
g_dgl.from_networkx(nx_graph=graph, edge_attrs=['one_hot'], node_attrs=['one_hot'])
g_dgl.title = self.all_graphs[idx]
if self.clustered:
    one_hot_label = torch.zeros((1, self.num_clusters))
    one_hot_label[0, fp] = 1.  # index the class column, not the row
    fp = one_hot_label
if self.get_sim_mat:
# put the rings in same order as the dgl graph
ring = dict(sorted(ring.items()))
......
......@@ -30,6 +30,8 @@ parser.add_argument("-nu", "--nucs", default=True, help="Use nucleotide IDs for
parser.add_argument('-rs', '--seed', type=int, default=0, help='Random seed to use (if > 0, else no seed is set).')
parser.add_argument('-kf', '--kfold', type=int, default=0, help='Run k-fold cross-validation and evaluate decoys on each fold.')
parser.add_argument('-es', '--early_stop', type=int, default=10, help='Early stop epoch threshold (default=10)')
parser.add_argument('-cl', '--clustered', action='store_true', default=False, help='Predict ligand cluster (default=False)')
parser.add_argument('-cn', '--num_clusts', type=int, default=8, help='Number of clusters (default=8)')
args = parser.parse_args()
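A hypothetical invocation with the new clustering flags (entry-point name assumed):

    python train.py --clustered --num_clusts 8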
......@@ -120,7 +122,8 @@ for k, (train_loader, test_loader) in enumerate(data):
num_rels=loader.num_edge_types,
num_bases=-1, pool=args.pool,
pos_weight=args.pos_weight,
nucs=args.nucs)
nucs=args.nucs, clustered=args.clustered,
num_clusts=args.num_clusts)
# if pre-trained, initialize matching layers
if args.warm_start:
......
......@@ -48,10 +48,11 @@ class Attributor(nn.Module):
Linear/ReLu layers with Sigmoid in output since fingerprints between 0 and 1.
"""
def __init__(self, dims):
def __init__(self, dims, clustered=False):
super(Attributor, self).__init__()
# self.num_nodes = num_nodes
self.dims = dims
self.clustered = clustered
# create layers
self.build_model()
......@@ -69,7 +70,11 @@ class Attributor(nn.Module):
layers.append(nn.Dropout(0.5))
# hidden to output
layers.append(nn.Linear(last_hidden, last))
layers.append(nn.Sigmoid())
# clustered mode predicts a single class, so use a softmax output
if self.clustered:
layers.append(nn.Softmax(dim=1))
else:
layers.append(nn.Sigmoid())
self.net = nn.Sequential(*layers)
def forward(self, x):
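The branch above swaps the multi-label sigmoid for a softmax when a single cluster is predicted; a minimal sketch of the two head shapes (dimensions illustrative):

    import torch.nn as nn

    multi_label_head = nn.Sequential(nn.Linear(64, 166), nn.Sigmoid())     # per-bit probabilities
    multi_class_head = nn.Sequential(nn.Linear(64, 8), nn.Softmax(dim=1))  # one cluster per input

If the loss were nn.CrossEntropyLoss, the final Softmax would be dropped in favor of raw logits, since that loss applies log-softmax internally.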
......@@ -137,7 +142,7 @@ class Embedder(nn.Module):
# ~~~~~~~~~~~~~~~~~~~~~~~
class Model(nn.Module):
def __init__(self, dims, device, attributor_dims, num_rels, pool='att', num_bases=-1,
pos_weight=0, nucs=True):
pos_weight=0, nucs=True, clustered=False, num_clusts=8):
"""
:param dims: the embeddings dimensions
......@@ -158,6 +163,8 @@ class Model(nn.Module):
self.pos_weight = pos_weight
self.device = device
self.nucs = nucs
self.clustered = clustered
self.num_clusts = num_clusts
if pool == 'att':
pooling_gate_nn = nn.Linear(attributor_dims[0], 1)
......@@ -167,7 +174,7 @@ class Model(nn.Module):
self.embedder = Embedder(dims=dims, num_rels=num_rels, num_bases=num_bases)
self.attributor = Attributor(attributor_dims)
self.attributor = Attributor(attributor_dims, clustered=clustered)
def forward(self, g):
embeddings = self.embedder(g)
......
......@@ -121,10 +121,7 @@ def decoy_test(model, decoys, edge_map, embed_dim,
# fp_pred = np.random.choice([0, 1], size=(166,), p=[1./2, 1./2])
if shuffle:
orig = true_id
print(orig)
true_id = np.random.choice(ligs, replace=False)
print(true_id)
print(true_id == orig)
active = decoys[true_id][0]
decs = decoys[true_id][1]
rank = distance_rank(active, fp_pred, decs)
......@@ -145,93 +142,104 @@ def generic_fp(annot_dir):
consensus = np.unique(fps, axis=0)
pass
def make_violins(df, x='method', y='rank', save=None, show=True):
    # draw the violins in grey, then raise lines and points above them
    # so the overlaid strip plot stays visible
    ax = sns.violinplot(x=x, y=y, data=df, color='0.8', bw=.1)
    for artist in ax.lines:
        artist.set_zorder(10)
    for artist in ax.findobj(PathCollection):
        artist.set_zorder(11)
    sns.stripplot(data=df, x=x, y=y, jitter=True, alpha=0.6)
    if save is not None:
        plt.savefig(save, format="pdf")
    if show:
        plt.show()
def ablation_results():
# modes = ['', '_bb-only', '_wc-bb', '_wc-bb-nc', '_no-label', '_label-shuffle', 'pair-shuffle']
modes = ['', 'pair-shuffle']
modes = ['raw', 'bb', 'wc-bb', 'pair-shuffle']
decoys = get_decoys(mode='pdb')
ranks, methods = [], []
ranks, methods, jaccards = [], [], []
graph_dir = '../data/annotated/pockets_nx_symmetric'
# graph_dir = '../data/annotated/pockets_nx_2'
run = "small_no_rec_2"
run = 'ppp'
run = 'ismb'
num_folds = 10
for m in modes:
# if m in ['', 'pair-shuffle']:
# graph_dir = "../data/annotated/pockets_nx"
# run = 'small_no_rec_2'
# else:
# graph_dir = "../data/annotated/pockets_nx" + m
# run = 'small_no_rec' + m
if m in ['raw', 'pair-shuffle']:
graph_dir = "../data/annotated/pockets_nx_symmetric"
run = 'ismb-raw'
else:
graph_dir = "../data/annotated/pockets_nx_symmetric_" + m
run = 'ismb-' + m
model, meta = load_model(run)
edge_map = meta['edge_map']
embed_dim = meta['embedding_dims'][-1]
num_edge_types = len(edge_map)
graph_ids = pickle.load(open(f'../results/trained_models/{run}/splits.p', 'rb'))
for fold in range(num_folds):
model, meta = load_model(run +"_" + str(fold))
edge_map = meta['edge_map']
embed_dim = meta['embedding_dims'][-1]
num_edge_types = len(edge_map)
shuffle = False
if m == 'pair-shuffle':
shuffle = True
ranks_this,sims_this = decoy_test(model, decoys, edge_map, embed_dim,
shuffle=shuffle,
nucs=meta['nucs'],
test_graphlist=graph_ids['train'],
test_graph_path=graph_dir)
test_ligs = []
ranks.extend(ranks_this)
methods.extend([m]*len(ranks_this))
graph_ids = pickle.load(open(f'../results/trained_models/{run}_{fold}/splits_{fold}.p', 'rb'))
#decoy distance distribution
dists = []
for _,(active, decs) in decoys.items():
for d in decs:
dists.append(jaccard(active, d))
plt.scatter(ranks_this, sims_this)
plt.xlabel("ranks")
plt.ylabel("distance")
plt.show()
sns.distplot(dists, label='decoy distance')
sns.distplot(sims_this, label='pred distance')
plt.xlabel("distance")
plt.legend()
plt.show()
shuffle = False
if m == 'pair-shuffle':
shuffle = True
ranks_this,sims_this = decoy_test(model, decoys, edge_map, embed_dim,
shuffle=shuffle,
nucs=meta['nucs'],
test_graphlist=graph_ids['test'],
test_graph_path=graph_dir)
test_ligs = []
ranks.extend(ranks_this)
jaccards.extend(sims_this)
methods.extend([m]*len(ranks_this))
rank_cut = 0.9
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.4 and r > 0.8]
# print(cool, len(ranks_this))
#4v6q_#0:BB:FME:3001.nx_annot.p
# test_ligs = set([f.split(":")[2] for f in graph_ids['test']])
# train_ligs = set([f.split(":")[2] for f in graph_ids['train']])
# print("ligands not in train set", test_ligs - train_ligs)
# points = []
# tot = len([x for x in ranks_this if x >= rank_cut])
# for sim_cut in np.arange(0,1.1,0.1):
# pos = 0
# for s,r in zip(sims_this, ranks_this):
# if s < sim_cut and r > rank_cut:
# pos += 1
# points.append(pos / tot)
# from sklearn.metrics import auc
# plt.title(f"Top 20% Accuracy {auc(np.arange(0, 1.1, 0.1), points)}, {m}")
# plt.plot(points, label=m)
# plt.plot([x for x in np.arange(0,1.1, 0.1)], '--')
# plt.ylabel("Positives")
# plt.xlabel("Distance threshold")
# plt.xticks(np.arange(10), [0, 0.1, 0.2, 0.3, 0.4, 0.5,0.6, 0.7, 0.9, 1.0])
# plt.legend()
# plt.show()
#decoy distance distribution
# dists = []
# for _,(active, decs) in decoys.items():
# for d in decs:
# dists.append(jaccard(active, d))
# plt.scatter(ranks_this, sims_this)
# plt.xlabel("ranks")
# plt.ylabel("distance")
# plt.show()
# sns.distplot(dists, label='decoy distance')
# sns.distplot(sims_this, label='pred distance')
# plt.xlabel("distance")
# plt.legend()
# plt.show()
df = pd.DataFrame({'rank': ranks, 'method':methods})
ax = sns.violinplot(x="method", y="rank", data=df, color='0.8')
for artist in ax.lines:
artist.set_zorder(10)
for artist in ax.findobj(PathCollection):
artist.set_zorder(11)
sns.stripplot(data=df, x='method', y='rank', jitter=True, alpha=0.6)
# plt.savefig("../tex/Figs/violins_gcn_2.pdf", format="pdf")
plt.show()
# # rank_cut = 0.9
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.4 and r > 0.8]
# cool = [graph_ids['test'][i] for i,(d,r) in enumerate(zip(sims_this, ranks_this)) if d <0.3]
# print(cool)
# print(cool, len(ranks_this))
#4v6q_#0:BB:FME:3001.nx_annot.p
# test_ligs = set([f.split(":")[2] for f in graph_ids['test']])
# train_ligs = set([f.split(":")[2] for f in graph_ids['train']])
# print("ligands not in train set", test_ligs - train_ligs)
# points = []
# tot = len([x for x in ranks_this if x >= rank_cut])
# for sim_cut in np.arange(0,1.1,0.1):
# pos = 0
# for s,r in zip(sims_this, ranks_this):
# if s < sim_cut and r > rank_cut:
# pos += 1
# points.append(pos / tot)
# from sklearn.metrics import auc
# plt.title(f"Top 20% Accuracy {auc(np.arange(0, 1.1, 0.1), points)}, {m}")
# plt.plot(points, label=m)
# plt.plot([x for x in np.arange(0,1.1, 0.1)], '--')
# plt.ylabel("Positives")
# plt.xlabel("Distance threshold")
# plt.xticks(np.arange(10), [0, 0.1, 0.2, 0.3, 0.4, 0.5,0.6, 0.7, 0.9, 1.0])
# plt.legend()
# plt.show()
df = pd.DataFrame({'rank': ranks, 'jaccard': jaccards, 'method':methods})
make_violins(df, x='method', y='jaccard')
make_violins(df, x='method', y='rank')
def structure_scanning(pdb, ligname, graph, model, edge_map, embed_dim):
"""
......