Commit 2779ff55 authored by Carlos GO

first commit

parent 04a39122
"""
Extract binding pockets from a directory of mmCIFs.
Writes a PDB for each binding pocket identified.
"""
import os
import sys
import numpy as np
import pandas as pd
from Bio.PDB import *
from Bio.PDB.MMCIF2Dict import MMCIF2Dict
from sdf_parse import ligand_dict as sdf_info
# PDB_PATH = os.path.join("..", "Data", "rna_ligand_full_structures")
PDB_PATH = os.path.join("..", "Data", "rna_ligand_with_protein")
def find_atom(res, atomname):
    """Return the atom in `res` whose cleaned fullname matches `atomname`."""
    for a in res.get_atoms():
        atomclean = a.fullname.strip("'").strip()
        if atomclean == atomname:
            return a
def res_dist_to_ligand(rna_res, ligand_res):
    """
    Return the smallest distance between any atom of the ligand residue
    `ligand_res` and any atom of the RNA residue `rna_res`.
    """
    distances = []
    for atom in ligand_res:
        # TODO: check that rna_res is an RNA residue and not a ligand
        for a in rna_res.child_list:
            diff_vector = a.coord - atom.coord
            distances.append(np.sqrt(np.sum(diff_vector * diff_vector)))
    return min(distances)
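# Note: the explicit sqrt-of-squared-differences above is equivalent to
# np.linalg.norm(a.coord - atom.coord); kept as written to match the original.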
class LigandSelect(Select):
    """
    Override the Select class to keep ligand atoms and RNA atoms within a
    distance threshold of the ligand.
    """
    def __init__(self, chain, distance_threshold, ligand_res):
        self.chain = chain
        self.threshold = distance_threshold
        self.ligand_res = ligand_res

    def accept_residue(self, residue):
        if residue == self.ligand_res:
            return True
        return res_dist_to_ligand(residue, self.ligand_res) < self.threshold
def ligand_info(res, lig_dict):
    """
    Return the tuple (is_ligand, mol_weight) for a residue name.
    """
    try:
        return lig_dict[res]
    except KeyError:
        return (False, 0)
def get_binding(strucpath, name,
                ligands_path="../Data/Ligands_noHydrogens_withMissing_6036_Instances.sdf",
                dest="pockets_extracted", distance_threshold=5):
    """
    Extract a PDB of the binding pocket around each ligand in the structure.
    """
    cif = os.path.join(strucpath, name + ".cif")
    #load mmCIF dictionary
    struc_dict = MMCIF2Dict(cif)
    #load structure
    parser = MMCIFParser(QUIET=False)
    pdbstruc = parser.get_structure(name, cif)
    lig_dict = sdf_info(ligands_path)
    ligand_dict = {}
    try:
        ligand_dict['position'] = struc_dict['_pdbx_nonpoly_scheme.pdb_seq_num']
        ligand_dict['res_name'] = struc_dict['_pdbx_nonpoly_scheme.mon_id']
        ligand_dict['chain'] = struc_dict['_pdbx_nonpoly_scheme.pdb_strand_id']
        ligand_dict['unique_id'] = struc_dict['_pdbx_nonpoly_scheme.asym_id']
        is_ion = lambda x: ligand_info(x, lig_dict)[0]
        ligand_dict['is_ion'] = [is_ion(x) for x in ligand_dict['res_name']]
    except KeyError:
        print("Ligand not detected.")
    model = pdbstruc[0]
    #pandas complains when dictionary values are scalars rather than lists;
    #this happens when there is only one ligand in the PDB
    try:
        ligand_df = pd.DataFrame.from_dict(ligand_dict)
    except ValueError:
        ligand_df = pd.DataFrame(ligand_dict, index=[0])
    ligand_df['position'] = pd.to_numeric(ligand_df['position'])
    ligand_df['pdbid'] = name
    io = PDBIO()
    #TODO: redo this for each chain in the PDB
    for ligand in ligand_df.itertuples():
        ligand_res = None
        #find the residue corresponding to the ligand
        for res in model[ligand.chain].get_residues():
            if res.id[1] == ligand.position:
                ligand_res = res
        #skip water and common ions
        if ligand.res_name in ["HOH", "MG", "NCO", "K"]:
            continue
        io.set_structure(pdbstruc)
        if not os.path.isdir(dest):
            os.mkdir(dest)
        io.save(os.path.join(dest, f"ligand_{name}_{ligand.res_name}_{ligand.unique_id}.pdb"),
                LigandSelect("A", distance_threshold, ligand_res))
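# Example call (hypothetical PDB id): writes one pocket PDB per ligand found
# in PDB_PATH/1aju.cif, keeping residues within the default cutoff of 5
# (Angstroms, in the structure's coordinate units):
#   get_binding(PDB_PATH, "1aju")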
if __name__ == "__main__":
testpdb = os.path.join(PDB_PATH)
pdb_ids = [p.split(".")[0] for p in os.listdir(PDB_PATH) if ".cif" in p]
lig_list = []
for i, p in enumerate(pdb_ids[pdb_ids.index('3g6e'):]):
print(f"working on {i} of {len(pdb_ids) -1}: {p}")
try:
# l = get_binding(testpdb, p)
get_binding(testpdb, p)
# lig_list.append(l)
except Exception as e:
print("exception")
print(e)
continue
# all_ligands = pd.concat(lig_list)
# all_ligands.to_csv("binding_sites_3A_interchain.csv")
pass
"""
Embed graphs from a GED run.
"""
import pickle
from ligand_knn import prepare_data
from dissimilarity_embed import full_embed
def get_DM(ged_pickle):
    """Load fingerprints and prepare distance data from a GED run."""
    fps = pickle.load(open("rna_smiles_fp_0ent_np_maccs.pickle", "rb"))
    fps_prot = pickle.load(open("rna_with_prot_maccs_fp.pickle", "rb"))
    fps = {**fps, **fps_prot}
    return prepare_data(ged_pickle, fps, non_redundant=True)
def embed(DM, graphlist, m, method='spanning'):
    return full_embed(graphlist, m, dist_mat=True, DM=DM, heuristic=method)
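# Usage sketch (mirrors the commented call in __main__ below): embed all
# graphs against 20 prototypes chosen by the k-centers heuristic:
#   vectors = embed(DM, graphlist, 20, method='k-centers')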
if __name__ == "__main__":
DM, L, graphlist = get_DM('geds_delta.pickle')
pickle.dump(DM, open('delta_DM.pickle', 'wb'))
pickle.dump(L, open('delta_L.pickle', 'wb'))
# print(embed(DM, graphlist, 20, method='k-centers'))
"""
Perform graph dissimilarity embedding of graphs.
1. Select a subset size N of graphs from full set to be 'prototype' graph set P. 2. For each new graph, embedding vector v is a real vector of size N where each
entry v_i in v is the distance from G to P_i
"""
import sys
import time
import logging
from random import randint
import multiprocessing
import cProfile
import pstats
import uuid
import random
import networkx as nx
import numpy as np
from numpy.linalg import eig
from spectral_distance import *
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
PROCS = 15
def row_compute(args):
    """
    Compute one row of the distance matrix from precomputed eigendecompositions.
    """
    D, i = args
    logging.info(f"Computing distance for graph {i+1} of {len(D)}")
    dist = np.zeros((len(D)))
    for j in range(i, len(D)):
        dist[j] = graph_distance_evec(D[i], D[j])
    return dist

def dummy(args):
    """Profile a single row_compute call."""
    cProfile.runctx('row_compute(args)', globals(), locals(),
                    f'prof_{random.randint(0, 20)}.prof')

def graph_eigen(G):
    """Eigendecomposition of a graph's Laplacian."""
    L = graph_laplacian(G)
    return eig(L)
def distance_matrix_para(D):
    """
    Compute the distance matrix for the graphs in D, in parallel.
    """
    eigens = []
    #compute the eigendecomposition of each graph's Laplacian
    print("diagonalizing laplacians")
    with multiprocessing.Pool(PROCS) as pool:
        for i, v in enumerate(pool.map(graph_eigen, D)):
            eigens.append(v)
    #fill the upper triangle of the distance matrix, one row per task
    dist = np.zeros((len(D), len(D)))
    todo = ((eigens, i) for i in range(dist.shape[0]))
    with multiprocessing.Pool(PROCS) as pool:
        for i, v in enumerate(pool.map(row_compute, todo)):
            dist[i] = v
    #make symmetric
    i_lower = np.tril_indices(dist.shape[0], -1)
    dist[i_lower] = dist.T[i_lower]
    return dist
def distance_matrix(D):
    """
    Compute the distance matrix for the list of graphs in D.
    """
    dist = np.zeros((len(D), len(D)))
    for i in range(dist.shape[0]):
        logging.info(f"Computing distance for graph {i+1} of {len(D)}")
        for j in range(i, dist.shape[0]):
            d = graph_distance(D[i], D[j])
            dist[i][j] = d
            dist[j][i] = d
    return dist
def median_graph(DM):
    """
    Return the median graph:
        median(D) = argmin_{g1} sum_{g2} d(g1, g2)
    i.e. the graph whose total distance to all other graphs is minimal.
    """
    return np.argmin(np.sum(DM, axis=1))
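# Worked example (hypothetical 3x3 distance matrix):
#   DM = np.array([[0, 1, 4],
#                  [1, 0, 2],
#                  [4, 2, 0]])
#   median_graph(DM)  # -> 1; row sums are [5, 3, 6] and graph 1 minimizes them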
def center_graph(DM, mask=None, masked=False):
    """
    Return the center graph:
        center(D) = argmin_{g1} max_{g2} d(g1, g2)
    i.e. the graph whose largest distance to any other graph is minimal.
    """
    if not masked:
        mask = np.zeros(len(DM))
    return np.argmin(np.ma.array([max([DM[i][j] for i in range(len(DM))])
                                  for j in range(len(DM))],
                                 mask=mask))
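# Worked example (same hypothetical DM as above): the largest distance from
# each graph to any other is [4, 2, 4], so center_graph(DM) -> 1.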
def spanning_selection(DM, m):
    """
    Spanning prototype selection: seed with the median graph, then repeatedly
    add the graph furthest from the current prototype set.
    """
    median = median_graph(DM)
    proto_indices = [median]
    tot_ind = list(np.arange(len(DM)))
    while len(proto_indices) < m:
        #add the point furthest from the prototype set
        proto = np.argmax(
            np.ma.array([min([DM[i][p] for p in proto_indices])
                         for i in tot_ind],
                        mask=np.isin(tot_ind, proto_indices)))
        proto_indices.append(proto)
    return proto_indices
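# Worked example (same hypothetical DM): with m=2, the median graph 1 seeds
# the set; graph 2 is then furthest from it (DM[2][1]=2 vs. DM[0][1]=1), so
# spanning_selection(DM, 2) -> [1, 2].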
def random_proto(DM, k):
    """
    Random prototype selection.
    """
    return np.random.choice(list(range(len(DM))), size=k, replace=False)
def k_centers(DM, k, return_assignments=False):
    """
    k-centers selection: start from random centers and iteratively replace
    each center with the center graph of its cluster until no center changes.
    """
    inds = list(range(len(DM)))
    protos = np.random.choice(inds, size=k, replace=False)
    while True:
        #assign each graph to its nearest center
        centers = [{p} for p in protos]
        for g in range(len(DM)):
            if g in protos:
                continue
            nearest_proto_ind = np.argmin([DM[g][c] for c in protos])
            centers[nearest_proto_ind].add(g)
        #recompute the center graph of each cluster
        num_changed = 0
        for i, cs in enumerate(centers):
            mask = np.logical_not(np.isin(inds, list(cs)))
            center = center_graph(DM, masked=True, mask=mask)
            if center != protos[i]:
                num_changed += 1
                protos[i] = center
        print(f"changed: {num_changed} of {len(protos)}")
        if num_changed == 0:
            if return_assignments:
                return protos, centers
            return protos
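# Usage sketch: request the cluster assignments along with the centers:
#   protos, clusters = k_centers(DM, 3, return_assignments=True)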
def prototypes(D, m, DM, heuristic='sphere'):
    """
    Compute a set of m prototype graphs.
    Input:
        `list (D)`: list of original graphs.
        `int (m)`: number of prototypes to select.
        `np array (DM)`: distance matrix.
    Returns:
        `list`: indices of the graphs forming the prototype set.
    """
    logging.info(f"Using {heuristic} heuristic")
    if heuristic == 'sphere':
        #select prototypes from a sphere induced on the dataset
        logging.info("Finding center of graph set")
        #get the center graph and the graph furthest from it
        distances = np.sum(DM, axis=1)
        center_index = np.argmin(distances)
        border_index = np.argmax(DM[center_index])
        radius = DM[center_index][border_index]
        #define an interval along the radius
        interval = radius / m
        proto_indices = [center_index, border_index]
        mask = np.zeros(DM.shape[0])
        mask[border_index] = 1
        mask[center_index] = 1
        center_ref = np.ma.MaskedArray(DM[center_index], mask)
        logging.info("Obtaining prototype graphs...")
        for i in range(m - 2):
            #pick the unmasked graph closest to the i-th shell of the sphere
            border_dist = abs(center_ref - (i * interval))
            dist_mask = np.ma.MaskedArray(border_dist, mask)
            proto_index = dist_mask.argmin()
            proto_indices.append(proto_index)
            #mask the prototype we just selected
            mask[proto_index] = 1
        return proto_indices
    if heuristic == 'spanning':
        return spanning_selection(DM, m)
    if heuristic == "k-centers":
        return k_centers(DM, m)
    if heuristic == "random":
        return random_proto(DM, m)
def graph_embed(args):
    """
    Embed graph G given prototype set P.
    Returns:
        `array`: numpy array representing the embedding vector.
    """
    i, G, P = args
    return (i, np.array([graph_distance(G, p) for p in P]))
def full_embed(D, m, DM=None, dist_mat=None, heuristic='spanning'):
    """
    For a dataset of graphs D, perform prototype selection and embedding.
    Returns:
        `matrix`: numpy matrix where each row is the embedding of a graph from `D`.
    """
    if dist_mat is None:
        DM = distance_matrix_para(D)
    P_idx = prototypes(D, m, DM, heuristic=heuristic)
    print(P_idx)
    logging.info("Embedding graphs.")
    #each embedding row is the distance-matrix row restricted to the
    #prototype columns
    embeddings = np.zeros((len(D), m))
    for i, g in enumerate(D):
        embeddings[i] = DM[i][P_idx]
    return embeddings
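# Usage sketch (mirroring the embed() helper in the script above):
#   X = full_embed(D, 10, DM=DM, dist_mat=True)
# returns a len(D) x 10 matrix whose i-th row is the embedding of D[i].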
if __name__ == "__main__":
logging.info("Generating random graphs.")
D = [nx.erdos_renyi_graph(randint(10, 20), .7) for _ in range(5)]
# print(full_embed(D, 5))
cProfile.run('full_embed(D, 5)', 'runstats')
p = pstats.Stats('runstats')
p.sort_stats('time').print_stats(10)
pass
"""
A* GED for RNA graphs.
Given two 2.5D RNA graphs, compute minimum graph edit distance.
Every time a node is added it must be connected to the backbone.
Nodes are also ordered..
Max degree is 3 (maybe 4 for base triples..).
Maybe decompose into planar graphs first and then solve exact for non-nested
keeping secondary structure fixed.
1. Remove pseudoknots.
2. Align SSEs. 3. Get GED of corresponding loop regions.
Just implement brute force A* first and then see where to improve.
"""
import os, sys
import logging
import itertools
from heapq import *
from collections import Counter
import time
import random
import multiprocessing
import pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import cm
import networkx as nx
# import pygraphviz as PG
import rna_layout
from pocket_draw import *
MAX_PROCS = 20
logging.basicConfig(filename='ged.log',
                    filemode='a',
                    format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                    datefmt='%H:%M:%S',
                    level=logging.INFO)
class OpNode:
    """A node in the A* search tree; each node records one edit operation."""
    def __init__(self, op=('NILL', 'NILL'), cost=0, parent=None,
                 source_map={}, target_map={}, depth=0):
        self.op = op
        self.cost = cost
        self.parent = parent
        if parent is None:
            self.source_map = {}
            self.target_map = {}
            self.depth = 0
        else:
            self.source_map = {**parent.source_map, **{op[0]: op[1]}}
            self.target_map = {**parent.target_map, **{op[1]: op[0]}}
            self.depth = parent.depth + 1

    def __str__(self):
        """
        Recursive path print.
        """
        if self.parent is None:
            return str(self.op)
        return f"{self.op} -> {self.parent.__str__()} : {self.cost}"

    def path_iter(self):
        #PEP 479: a bare `raise StopIteration` inside a generator becomes a
        #RuntimeError in Python 3.7+, so simply stop yielding at the root
        if self.op != ('NILL', 'NILL'):
            yield self
            yield from self.parent.path_iter()
        else:
            yield self
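# Usage sketch: walk an edit path from a leaf OpNode back to the root
# (the ('NILL', 'NILL') sentinel):
#   for node in leaf.path_iter():
#       print(node.op, node.cost)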
def label_path(op_node, g1, g2):
    """Pretty-print the edge labels along an edit path."""
    e1, e2 = op_node.op
    label = lambda e, G: 'NILL' if e == 'NILL' else G[e[0]][e[1]]['label']
    short = lambda e: f"{e[0][0]}{e[0][1]}||{e[1][0]}{e[1][1]}"
    if op_node.parent is None:
        return f"{label(e1, g1)}-{label(e2, g2)}"
    return f"{short(e1)} {label(e1, g1)} - {short(e2)} {label(e2, g2)} --> {label_path(op_node.parent, g1, g2)}"
def edge_edit_cost(u, v, g1, g2, node):
    """
    Recursively accumulate edge edit costs along the edit path.
    """
    if node is None:
        return 0
    u1, v1 = node.op
    e1 = (u1, u)
    e2 = (v1, v)
    cost = 0
    if e1 in g1.edges and e2 in g2.edges:
        #edge substitution
        l1 = g1[e1[0]][e1[1]]['label']
        l2 = g2[e2[0]][e2[1]]['label']
        if l1 != l2:
            cost = 1
            #substituting a backbone (B53) edge is forbidden
            if 'B53' in (l1, l2):
                cost = sys.maxsize
    elif e1 not in g1.edges and e2 in g2.edges:
        #edge insertion
        cost = 2
    elif e1 in g1.edges and e2 not in g2.edges:
        #edge deletion
        cost = 2
    else:
        cost = 0
    return cost + edge_edit_cost(u, v, g1, g2, node.parent)
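# Cost summary implied above: matching labels cost 0, a label substitution
# costs 1 (sys.maxsize if it would rewrite a backbone B53 edge), and an edge
# insertion or deletion costs 2.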
def node_cost(op, G1, G2):
    u, v = op
    u_bb = False
    v_bb = False
    try:
        u_bb = 'L' in u[0]
    except Exception:
        pass
    try:
        v_bb = 'L' in v[0]
    except Exception:
        pass
    if u_bb or v_bb:
        if 'NILL' in op:
            #deleting a loop costs the number of nodes in the loop
            if u_bb:
                return u[2]