Skip to content

Commit

Permalink
Code for drawing polygons (slow) + README
Browse files Browse the repository at this point in the history
  • Loading branch information
axanderssonuu committed Jul 3, 2023
1 parent c3cad9f commit 4fa4881
Show file tree
Hide file tree
Showing 11 changed files with 978 additions and 350 deletions.
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
best-model.pt
examples/best-model.pt
is3g/best-model.pt
/.vscode/
/is3g/_version.py
.tissuumaps
plots
run_experiments.py
benchmark.py
simulated_data.csv

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

benchmark.py
iss_simulated.csv
# C extensions
*.so

Expand Down
23 changes: 15 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# issseg
# IS3G

In Situ Sequencing Segmentation

Expand All @@ -10,7 +10,7 @@ Use the package manager [pip](https://pip.pypa.io/en/stable/) to install is3g:

## Usage

Jupyter notebook examples using the `is3g` python API can be found in the [examples](examples) directory.
See [this](example.ipynb) Jupyter Notebook for a minimal example.

You can also use the `is3g` command to execute is3g from the command-line. For example:
```console
Expand All @@ -28,19 +28,26 @@ $ is3g --help
Usage: is3g [OPTIONS] CSV_PATH CSV_OUT

Options:
-x, --x TEXT TODO
-y, --y TEXT TODO
-l, --label TEXT TODO
-r, --radius FLOAT RANGE TODO [x>=0]
-x, --x TEXT Column in the CSV file where the x coordinate is stored for each gene.

-y, --y TEXT Column in the CSV file where the y coordinate is stored for each gene.

-l, --label TEXT Column in the CSV file where the gene labels are stored.
-r, --radius FLOAT RANGE Approximate radius of a cell. [x>=0]
--remove-background / --no-remove-background
TODO
Specify whether to automatically remove genes from low-density regions. Default is true.
--version Show the version and exit.
--help Show this message and exit.
```


## Support

If you find a bug, please [raise an issue](https://github.com/wahlby-lab/is3g/issues/new).
If (when) you find a bug, please [raise an issue](https://github.com/wahlby-lab/is3g/issues/new).

## Contributing

Expand Down
File renamed without changes.
299 changes: 0 additions & 299 deletions examples/minimal_example.ipynb

This file was deleted.

2 changes: 2 additions & 0 deletions is3g/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .is3g import is3g as is3g
from .is3g import make_binary_cell_boundary_image as make_binary_cell_boundary_image
from .is3g import replace_low_freq_with_zero as replace_low_freq_with_zero

try:
from is3g._version import version as __version__
Expand Down
150 changes: 150 additions & 0 deletions is3g/_knn_tools/draw_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import numpy as np
import scipy.sparse as sp
from scipy.sparse.csgraph import minimum_spanning_tree, dijkstra
from .linalg import spatial_binning_matrix, attribute_matrix, connectivity_matrix
from skimage.transform import rescale
from typing import Tuple

def _find_minimum_spanning_trees(label_mask):
    """Build a spanning tree and a closed longest path for every positive label.

    Args:
        label_mask: 2-D array of labels. Entries that are not ``> 0`` are
            skipped.

    Returns:
        A tuple ``(paths, trees)`` of dictionaries keyed by label:

        * ``paths[label]`` is an array with one ``(row, col)`` pixel per row,
          following the longest path through the label's spanning tree and
          ending with a repeat of the first pixel so the path is closed.
        * ``trees[label]`` is a list of ``(pixel_a, pixel_b)`` pairs, one per
          edge of the spanning tree.
    """
    from tqdm.auto import tqdm

    paths = {}
    trees = {}
    for label in tqdm(np.unique(label_mask)):
        if not label > 0:
            continue

        # (row, col) coordinates of every pixel that carries this label.
        pixel_rows_cols = np.where(label_mask == label)
        rc = np.vstack(pixel_rows_cols).T

        # Pixel graph; presumably links pixels at most sqrt(2) apart, i.e.
        # 8-connectivity -- confirm against `connectivity_matrix`.
        adjacency = connectivity_matrix(rc, method='radius', r=np.sqrt(2))
        spanning_tree = minimum_spanning_tree(adjacency)

        # Node indices along the longest path, closed by repeating the start.
        node_order = _extract_longest_path(spanning_tree)
        closed_path = [rc[node, :] for node in node_order]
        closed_path.append(closed_path[0])

        # Translate tree edges from node indices back to pixel coordinates.
        tree_edges = [
            (rc[edge[0], :], rc[edge[1], :])
            for edge in np.vstack(spanning_tree.nonzero()).T
        ]

        paths[label] = np.array(closed_path)
        trees[label] = tree_edges
    return paths, trees

def _extract_longest_path(mst):
import warnings

# Find the shortest paths from a starting node to all other nodes
with warnings.catch_warnings():
warnings.simplefilter("ignore")
dist_matrix, predecessors = dijkstra(-mst, directed=False, return_predecessors=True)

# Find node to start with
start_node = dist_matrix.min(axis=1).argmin()

# Find the node with the maximum distance (longest path)
end_node = np.argmin(dist_matrix[start_node])
longest_path = [end_node]

# Trace back the longest path
while end_node != start_node:
end_node = predecessors[start_node,end_node]
longest_path.append(end_node)

# Reverse the longest path to obtain the correct order
longest_path = longest_path[::-1]

return longest_path


def _finite_difference_matrix(w):
return np.eye(w,k=-1) - np.eye(w,k=1)

def _gaussian_blur_matrix(w, sigma):
x = np.arange(w).reshape((1,-1))
d = abs(x-x.T)
mask = d < 3*sigma
g = np.exp(-0.5*d**2/sigma**2) * mask
g = g / g.max(axis=1,keepdims=True)
return sp.csr_matrix(g)

def _compute_cell_label_mask(xy, A, gridstep, threshold, dapi_shape):
    """Rasterise per-point attributes onto a grid and pick one label per bin.

    Points are binned on a regular grid, the binned attribute columns are
    blurred with a separable Gaussian (sigma 2.0 grid cells), values not above
    ``threshold`` are zeroed, and each bin is assigned the index of its
    strongest attribute column.

    NOTE(review): a bin with nothing above ``threshold`` gets label 0 from the
    arg-max, which is the same value as attribute column 0. Confirm that
    column 0 of ``A`` is reserved for background.

    Args:
        xy: Point coordinates, passed to ``spatial_binning_matrix``.
        A: Matrix with one row per point and one column per label
            (presumably sparse -- TODO confirm against ``attribute_matrix``).
        gridstep: Bin width passed to ``spatial_binning_matrix``.
        threshold: Density values must exceed this to be kept.
        dapi_shape: ``(height, width)`` of the reference image; used as the
            upper bound of the binning extent.

    Returns:
        Tuple ``(label_mask, grid_props)`` where ``label_mask`` has the shape
        ``grid_props['grid_size']`` and ``grid_props`` is the dictionary
        returned by ``spatial_binning_matrix``.
    """
    # Bin the points; only occupied bins are returned as rows.
    image_extent = [dapi_shape[1], dapi_shape[0]]
    occupied_bins, grid_props = spatial_binning_matrix(
        xy, gridstep, return_grid_props=True, xy_min=(0, 0), xy_max=image_extent
    )
    n_points = occupied_bins.shape[1]

    # Size of the full grid, including the empty bins.
    grid_shape = grid_props['grid_size']
    n_pixels = np.prod(grid_shape)

    # Flat (C-order) grid index of each occupied bin.
    flat_bin_index = np.ravel_multi_index(
        grid_props['grid_coords'], grid_shape, order='C'
    )

    # Expand to a complete binning matrix (num pixels x num points).
    bin_rows, point_cols = occupied_bins.nonzero()
    B = sp.csr_matrix(
        (np.ones(len(point_cols)), (flat_bin_index[bin_rows], point_cols)),
        shape=(n_pixels, n_points),
    )

    # Accumulate the attributes of the points falling in each bin.
    BA = B @ A

    # One blur operator per grid axis.
    h, w = grid_shape
    Gy = _gaussian_blur_matrix(h, 2.0)
    Gx = _gaussian_blur_matrix(w, 2.0)

    flat_shape = BA.shape
    c = flat_shape[1]

    # Blur along y: fold the (x, label) axes into columns, then unfold.
    blurred_y = Gy @ BA.reshape((h, c * w))
    kde = blurred_y.reshape(flat_shape).tocsr()

    # Blur along x: fold the (label, y) axes into rows, then unfold.
    blurred_xy = kde.T.reshape((h * c, w)) @ Gx.T
    kde = blurred_xy.reshape((c, h * w)).T.tocsr()

    # Suppress low-density background.
    kde = kde.multiply(kde > threshold)

    # Strongest label per bin, laid out on the grid.
    label_mask = kde.tocoo().argmax(axis=1).A.flatten().reshape(grid_shape)

    return label_mask, grid_props




def _compute_boundary_label_mask(im):
Dy, Dx = tuple(_finite_difference_matrix(w) for w in im.shape)
not_bg = im > 0
y = Dy @ im
x = im @ Dx.T
return im * (((y != 0) | (x != 0)) & not_bg)

def _create_raster(trees, shape):
raster = np.zeros(shape, dtype='bool')
ind = np.vstack(list(trees.values()))
raster[ind[:,0],ind[:,1]] = True
return raster


def create_binary_edges(
    xy: np.ndarray,
    cell: np.ndarray,
    gridstep: float,
    threshold: float,
    dapi_shape: Tuple[int, int],
) -> np.ndarray:
    """Draw cell outlines as a binary image.

    The points are rasterised into a label mask, the label boundaries are
    extracted, a closed path is traced along each boundary and the paths are
    burnt into an image that is finally rescaled by ``gridstep`` and
    transposed. Progress messages are printed to stdout.

    Args:
        xy: Point coordinates.
        cell: Per-point cell assignment, passed to ``attribute_matrix``.
        gridstep: Grid bin width; also the factor used to rescale the result.
        threshold: Density threshold below which bins are background.
        dapi_shape: ``(height, width)`` of the reference image.

    Returns:
        The rescaled, transposed outline image.
    """
    A, _ = attribute_matrix(cell)

    print("Binning ...")
    cell_label_mask, grid_props = _compute_cell_label_mask(
        xy, A, gridstep, threshold, dapi_shape
    )

    print("Compute label mask ...")
    boundary_label_mask = _compute_boundary_label_mask(cell_label_mask)

    print("Finding MST ...")
    # The first returned mapping holds the closed boundary paths.
    boundary_paths, _ = _find_minimum_spanning_trees(boundary_label_mask)

    print("Creating rasters ...")
    outline = _create_raster(boundary_paths, grid_props['grid_size'])

    print("Rescale ...")
    return rescale(outline, gridstep).T
145 changes: 144 additions & 1 deletion is3g/_knn_tools/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,35 @@
from scipy.spatial import cKDTree

from .linalg import connectivity_matrix, spatial_binning_matrix

from scipy.spatial import Delaunay


def delauny_edges(xy: np.ndarray) -> List[Tuple[int, int]]:
    """Return the unique undirected edges of the Delaunay triangulation.

    Every triangle contributes its three sides. Each side is first normalised
    to ``(smaller index, larger index)`` and only then de-duplicated, so a
    side shared by two triangles is reported exactly once regardless of the
    orientation in which each triangle lists it.

    Args:
        xy (np.ndarray): Array of shape (n, 2) containing the (x, y)
            coordinates of the points.

    Returns:
        List of ``(i, j)`` index pairs with ``i < j``, sorted
        lexicographically.

    Raises:
        Whatever ``scipy.spatial.Delaunay`` raises for input it cannot
        triangulate (for example too few or degenerate points).
    """
    triangles = Delaunay(xy).simplices

    # The three sides of every triangle, as rows of vertex-index pairs.
    sides = np.vstack(
        [triangles[:, [0, 1]], triangles[:, [1, 2]], triangles[:, [2, 0]]]
    )

    # Normalise direction, drop degenerate self-edges, then de-duplicate.
    sides.sort(axis=1)
    sides = sides[sides[:, 0] != sides[:, 1]]
    unique_sides = np.unique(sides, axis=0)

    return [(int(a), int(b)) for a, b in unique_sides]

def knn_undirected_edges(xy: np.ndarray, k: int) -> List[Tuple[int, int]]:
"""
Expand Down Expand Up @@ -255,3 +283,118 @@ def _pick_point(self):
output = self._queue[self._counter]
self._counter += 1
return output


class SimCLRSampler:
    """Draw an anchor point, one nearby positive and a set of far negatives.

    Each call to :meth:`sample` returns an anchor, a shuffled array of
    candidate points made of the drawn negatives plus the positive, and the
    position of the positive inside that array.
    """

    # Number of attempts made to draw a negative for every anchor.
    _N_NEGATIVE_DRAWS = 50

    def __init__(
        self,
        xy: np.ndarray,
        neighbor_max_distance: float,
        non_neighbor_distance_interval: Tuple[float, float],
    ):
        """
        Initializes the SimCLRSampler object.
        Args:
            xy (np.ndarray): Array of shape (n, 2) containing the (x, y) coordinates of
                the points.
            neighbor_max_distance (float): Maximum distance between two points for them
                to be considered neighbors.
            non_neighbor_distance_interval (Tuple[float, float]): Tuple containing the
                minimum and maximum distance between two
                points for them to be considered non-neighbors.
        Returns:
            None
        """
        r_min, r_max = (
            non_neighbor_distance_interval[0],
            non_neighbor_distance_interval[1],
        )
        self.r_min = r_min

        # Positive neighbors: every other point within neighbor_max_distance.
        within_radius = cKDTree(xy).query_ball_point(xy, neighbor_max_distance)
        # No self loops
        self._positive_neighbors = [
            [j for j in n if i != j] for i, n in enumerate(within_radius)
        ]
        self.xy = xy

        # Bin the points on a grid three times finer than r_min.
        bin_width = r_min / 3.0
        bin_matrix = spatial_binning_matrix(xy, bin_width)

        # Keep only the bins that contain at least one point.
        non_empty_bins = bin_matrix.sum(axis=1).A.flatten() > 0
        bin_matrix = bin_matrix[non_empty_bins]
        # Index of the (non-empty) bin each point falls in.
        self._bin_ids = bin_matrix.argmax(axis=0).A.flatten()
        self._points = list(np.arange(len(xy)))

        # Mean position of the points in each bin.
        bin_locations = (bin_matrix @ xy) / bin_matrix.sum(axis=1)
        bin_locations = bin_locations.A

        # Candidate bin pairs: centres at most r_max apart ...
        close_bins = cKDTree(bin_locations / bin_width).query_ball_point(
            bin_locations / bin_width, r_max / bin_width
        )
        p = np.array([i for i, n in enumerate(close_bins) for _ in n])
        q = np.array([j for i, n in enumerate(close_bins) for j in n])
        dist = np.linalg.norm(bin_locations[p] - bin_locations[q], axis=1)
        adj = sp.csr_matrix(
            (dist, (p, q)), shape=(len(bin_locations), len(bin_locations))
        )
        # ... and strictly more than r_min apart.
        adj = adj > r_min
        # For every bin: the bins negatives may be drawn from.
        self._neighboring_bins = adj.tolil().rows
        # For every bin: the indices of the points it contains.
        self._bin_matrix = bin_matrix.tolil().rows

        # Shuffled visiting order for the anchors.
        self._queue = list(range(len(self._points)))
        self._counter = 0
        random.shuffle(self._queue)

    def sample(self, neighbor: bool) -> Tuple[int, np.ndarray, np.ndarray]:
        """
        Sample an anchor together with one positive and several negatives.
        Args:
            neighbor (bool): Unused; kept so existing callers keep working.
        Returns:
            A tuple ``(anchor, gt, negatives)``. ``anchor`` is the index of
            the anchor point. ``negatives`` is a shuffled array of point
            indices holding up to 50 points drawn from bins whose centres are
            more than ``r_min`` (and at most the configured maximum) away
            from the anchor's bin, plus one point within
            `neighbor_max_distance` of the anchor. ``gt`` is a length-1 array
            with the position of that nearby point inside ``negatives``.
        Raises:
            RuntimeError: If no point with a positive neighbor can be found.
        """
        # Two passes over the queue always include one complete shuffled
        # pass, so failing this many times means no usable anchor exists.
        for _ in range(2 * len(self._queue)):
            anchor = self._pick_point()
            if len(self._positive_neighbors[anchor]):
                break
        else:
            raise RuntimeError(
                "No point has a neighbor within neighbor_max_distance."
            )

        positive = _choose_one(self._positive_neighbors[anchor])

        # Draw negatives from bins far enough from the anchor's bin.
        anchor_bin = self._bin_ids[anchor]
        far_bins = self._neighboring_bins[anchor_bin]
        candidates = []
        if len(far_bins):
            for _ in range(self._N_NEGATIVE_DRAWS):
                members = self._bin_matrix[_choose_one(far_bins)]
                if len(members):
                    candidates.append(_choose_one(members))

        # The positive goes last so its pre-shuffle position is known even
        # when fewer negatives than requested could be drawn.
        candidates.append(positive)
        candidates = np.array(candidates)
        positive_slot = len(candidates) - 1

        ind = np.arange(len(candidates))
        random.shuffle(ind)
        negatives = candidates[ind]
        gt = np.where(ind == positive_slot)[0]
        return anchor, gt, negatives

    def _pick_point(self):
        """Return the next anchor from the queue, reshuffling when exhausted."""
        if self._counter == len(self._queue):
            self._queue = list(range(len(self._points)))
            random.shuffle(self._queue)
            self._counter = 0
        output = self._queue[self._counter]
        self._counter += 1
        return output
Loading

0 comments on commit 4fa4881

Please sign in to comment.