Commit 743c5992 authored by Carlos GO's avatar Carlos GO
Browse files

proto select

parent d0fbcc0c
......@@ -62,6 +62,28 @@ Just pass the alignment object to the `graph_align` function.
You can obtain a vector representation of a binding pocket by building your own prototype set, or by using our pre-built prototype sets.
### Building your prototype set
The output of the GED computation on a set of graphs is a pickled list object containing all pairwise comparisons and some additional information.
You can convert this output to a distance matrix and a list indicating which graph each entry in the distance matrix corresponds to.
```python
>>> from RNAmigos.post_ged import data_prepare
>>> geds = '../data/geds_delta.pickle'
>>> fps = '../data/all_rna_ligands_fingerprints.pickle'
>>> DM, L, graphlist = data_prepare(geds, fps)
```
The distance matrix can be passed to a prototype selector.
```python
>>> from RNAmigos.dissimilarity_embed import prototype_select
>>> prototypes = prototype_select(DM, 20)
```
## Fingerprint Prediction
......
......@@ -173,7 +173,7 @@ def k_centers(DM, k, return_assignments=False):
# return protos
def prototypes(D, m, DM, heuristic='sphere'):
def prototypes(DM,m, heuristic='spanning'):
"""
Compute set of m prototype graphs.
......@@ -188,51 +188,6 @@ def prototypes(D, m, DM, heuristic='sphere'):
"""
logging.info(f"Using {heuristic} heuristic")
if heuristic == 'sphere':
"""
Select prototypes from a sphere induced on the dataset.
"""
prototypes = []
logging.info("Computing distance matrix...")
logging.info("Finding center of graph set")
#get center graph
distances = np.sum(DM, axis=1)
center_index = np.argmin(distances)
center = D[center_index]
#get graph furthest from center
border_index = np.argmax(DM[center_index])
border = D[border_index]
radius = DM[center_index][border_index]
#define interval along radius
interval = radius / m
proto_indices = []
prototypes += [center, border]
proto_indices += [center_index, border_index]
mask = np.zeros(DM.shape[0])
mask[border_index] = 1
mask[center_index] = 1
center_ref = np.ma.MaskedArray(DM[center_index], mask)
logging.info("Obtaining prototype graphs...")
for i in range(m-2):
border_dist = abs(center_ref - (i*interval))
dist_mask = np.ma.MaskedArray(border_dist, mask)
proto_index = dist_mask.argmin()
proto_indices.append(proto_index)
prototypes.append(D[proto_index])
#mask the prototype we selected
mask[proto_index] = 1
return proto_indices
if heuristic == 'spanning':
return spanning_selection(DM, m)
if heuristic == "k-centers":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment